In [None]:
#Importing packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from PIL import Image

import torchvision.transforms.functional as TF
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse

In [None]:
# Getting the gifs
import io
import imageio
from ipywidgets import widgets, HBox

final_targets = np.load('target_model_1.npy')
final_outputs = np.load('output_model_1.npy')

fps = 20
num_preview_videos = 3  # only the first few videos are rendered as GIF pairs


def _frames_to_gif_bytes(frames, frames_per_second):
    """Encode a stack of frames as an in-memory GIF and return its raw bytes."""
    with io.BytesIO() as buffer:
        imageio.mimsave(buffer, frames, "GIF", fps = frames_per_second)
        return buffer.getvalue()


preview_pairs = zip(final_targets[:num_preview_videos], final_outputs[:num_preview_videos])
for video_number, (target, output) in enumerate(preview_pairs, start=1):
    target_gif = _frames_to_gif_bytes(np.array(target, dtype = 'uint8').squeeze(), fps)
    output_gif = _frames_to_gif_bytes(np.array(output, dtype = 'uint8').squeeze(), fps)

    print(f"\nTest video: {video_number}")
    # Side by side: target video (left) and model output video (right).
    display(HBox([widgets.Image(value=target_gif),
                  widgets.Image(value=output_gif)]))

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')

In [None]:
# Show which torch.device was selected in the previous cell (cuda or cpu).
device

In [None]:
# Reload the saved model outputs and targets that the DatasetImages class below is built from.

import cv2

final_outputs = np.load('output_model_1.npy')
final_targets = np.load('target_model_1.npy')
# DatasetImages pairs videos with zip(), which silently drops unmatched extras;
# fail loudly here instead if the two files disagree in length.
assert len(final_outputs) == len(final_targets), (
    f"{len(final_outputs)} output videos vs {len(final_targets)} target videos"
)

In [None]:
class DatasetImages(torch.utils.data.Dataset):
    """Paired frame dataset: model outputs (``X_data``) vs targets (``y_data``).

    The first ``frames_per_video`` frames of every video are converted to float
    tensors in [0, 1] with 3 channels. Each frame is reflect-padded on the
    bottom/right so height and width are multiples of ``img_multiple_of``.

    Args:
        outputs: sequence of output videos; each is cast to uint8 and squeezed.
        targets: sequence of target videos, paired with ``outputs`` via ``zip``
            (unmatched trailing videos are ignored).
        num_rows: unused; kept for backward compatibility.
        frames_per_video: leading frames taken from each video (default 10).
        img_multiple_of: padding multiple for height/width (default 8).

    Raises:
        ValueError: if no frames were collected (e.g. empty inputs).
        IndexError: if a video has fewer than ``frames_per_video`` frames.
    """

    def __init__(self, outputs, targets, num_rows=None, frames_per_video=10, img_multiple_of=8):
        x_frames = []
        y_frames = []
        for target, output in zip(targets, outputs):
            output = np.array(output, dtype='uint8').squeeze()
            target = np.array(target, dtype='uint8').squeeze()
            for i in range(frames_per_video):
                x_frames.append(self._frame_to_tensor(output[i], img_multiple_of))
                y_frames.append(self._frame_to_tensor(target[i], img_multiple_of))

        if not x_frames:
            raise ValueError("DatasetImages needs at least one (output, target) frame pair")

        # Concatenate once: calling torch.cat inside the loop copied the whole
        # accumulated tensor for every frame (quadratic in the number of frames).
        self.X_data = torch.cat(x_frames, dim=0)
        self.y_data = torch.cat(y_frames, dim=0)

    @staticmethod
    def _frame_to_tensor(frame, img_multiple_of=8):
        """Turn one uint8 frame into a reflect-padded (1, 3, H', W') float tensor in [0, 1].

        Performs in memory the conversion the former PNG save + ``cv2.imread`` +
        BGR->RGB round trip did, without writing files to the working directory:
        2-D grayscale frames are replicated to 3 channels, a 4th (alpha) channel
        is dropped, and 3-channel frames are kept as-is.
        """
        frame = np.asarray(frame, dtype=np.uint8)
        if frame.ndim == 2:
            frame = np.stack([frame, frame, frame], axis=-1)
        elif frame.ndim == 3 and frame.shape[-1] == 4:
            frame = frame[..., :3]
        tensor = torch.from_numpy(np.ascontiguousarray(frame)).float().div(255.)
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)

        # Pad bottom/right up to the next multiple (0 when already aligned).
        h, w = tensor.shape[2], tensor.shape[3]
        padh = (-h) % img_multiple_of
        padw = (-w) % img_multiple_of
        return F.pad(tensor, (0, padw, 0, padh), 'reflect')

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        # DataLoader samplers may hand over a tensor index; convert for plain indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Samples are moved to the notebook-global `device` here, so batches
        # come out of the DataLoader already on that device.
        return self.X_data[idx].to(device), self.y_data[idx].to(device)

In [None]:
# Build the paired training set: model outputs are the inputs, targets are the labels.
train_ds = DatasetImages(outputs=final_outputs, targets=final_targets)

In [None]:
# Creating the dataloader
batch_size = 1
SEED = 0
# A seeded generator makes the shuffle order reproducible across Restart & Run All.
# Keep the default num_workers=0: DatasetImages.__getitem__ moves samples to
# `device`, and CUDA cannot be used safely from forked DataLoader workers.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# # Clone the Restormer repo (uncomment and run once; needs network access)

# import os
# !pip install einops

# if os.path.isdir('Restormer'):
#   !rm -r Restormer

# # Clone Restormer
# !git clone https://github.com/swz30/Restormer.git

In [None]:
%cd Restormer

In [None]:
# Restoration task whose pretrained Restormer weights are loaded below.
# Alternatives: 'Single_Image_Defocus_Deblurring', 'Motion_Deblurring', 'Deraining'.
task = 'Real_Denoising'

# Download the pre-trained models (uncomment to run once).
# Compare strings with `==`: `is` tests object identity, which is not
# guaranteed for equal strings (and emits a SyntaxWarning with literals).
# if task == 'Real_Denoising':
#   !wget https://github.com/swz30/Restormer/releases/download/v1.0/real_denoising.pth -P Denoising/pretrained_models
# if task == 'Single_Image_Defocus_Deblurring':
#   !wget https://github.com/swz30/Restormer/releases/download/v1.0/single_image_defocus_deblurring.pth -P Defocus_Deblurring/pretrained_models
# if task == 'Motion_Deblurring':
#   !wget https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth -P Motion_Deblurring/pretrained_models
# if task == 'Deraining':
#   !wget https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth -P Deraining/pretrained_models

In [None]:
# %cd Restormer
def get_weights_and_parameters(task, parameters):
    if task == 'Motion_Deblurring':
        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
    elif task == 'Single_Image_Defocus_Deblurring':
        weights = os.path.join('Defocus_Deblurring', 'pretrained_models', 'single_image_defocus_deblurring.pth')
    elif task == 'Deraining':
        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
    elif task == 'Real_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] = 'BiasFree'
    return weights, parameters


# Get model weights and parameters
parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
weights, parameters = get_weights_and_parameters(task, parameters)

# Load the architecture definition from the cloned repo (cwd must be Restormer/).
load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)
model.to(device)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint['params'])

In [None]:
!nvidia-smi

In [None]:
%cd ..

In [None]:
# Image restoration is per-pixel regression on [0, 1] images, so use an L1
# pixel loss; CrossEntropyLoss is a classification loss and treats the RGB
# channels as class scores, which is not meaningful for reconstructing targets.
criterion = nn.L1Loss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum = 0.95)

In [None]:
num_epochs = 50

since = time.time()

for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    print('-' * 10)

    model.train()
    cumm_loss = 0.0
    # `with` guarantees the progress bar is closed even if an iteration raises.
    with tqdm(total=len(train_dl), desc="Training", position=0, leave=True,
              bar_format='{l_bar}{bar:60}{r_bar}{bar:-10b}') as pbar:
        for inputs, targets in train_dl:
            # DatasetImages.__getitem__ already moves samples to `device`.
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cumm_loss += loss.item()
            pbar.update(1)

    mean_loss = cumm_loss / max(len(train_dl), 1)
    print(f"Loss: {cumm_loss} (mean per batch: {mean_loss:.6f})\n")

time_elapsed = time.time() - since
print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")

In [None]:
# Visual sanity check: show input / model output / target for random samples.
num_previews = 100
model.eval()
with torch.no_grad():  # inference only: don't build autograd graphs
    for _ in range(num_previews):
        # train_dl shuffles, so each fresh iterator yields a different first batch.
        inputs, targets = next(iter(train_dl))
        out = model(inputs).cpu()
        print(out.shape)

        panels = (inputs[0].cpu(), out[0], targets[0].cpu())
        fig, axes = plt.subplots(1, 3, figsize=(10, 5))
        for ax, img, title in zip(axes, panels, ("Input", "Output", "Target")):
            # CHW -> HWC for imshow; clamp because the model output may leave [0, 1].
            ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show()
        plt.close(fig)

In [None]:
# Persist only the learned weights (state_dict) as denoising.pth in the current directory.
model_path = "denoising.pth"
torch.save(model.state_dict(), model_path)