In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import glob
import cv2
import numpy as np
import PIL.Image as Image
import cv2
import copy
from PIL import ImageFilter
import matplotlib.pyplot as plt

In [None]:
# Super-resolution factor: HR patches are downscaled by SCALE and bicubically
# upscaled back to form the network's LR input.
SCALE = 3

# Fixed seed so weight initialisation and DataLoader shuffling repeat from run
# to run (GPU kernels may still introduce small nondeterminism).
SEED = 0
torch.manual_seed(SEED)

In [None]:
class HeadDataset(Dataset):
    """Super-resolution training pairs cut from a list of image files.

    Each image is split into non-overlapping ``stride`` x ``stride`` RGB
    patches, which serve as the HR targets. ``__getitem__`` builds the matching
    LR input: it blurs the patch, downscales it by ``scale`` and bicubically
    upscales it back, so input and target have the same size.

    Args:
        files: paths of images readable by OpenCV.
        scale: downscaling factor used to synthesise the LR input.
        stride: patch side length and step between patches, in pixels.

    Raises:
        ValueError: if a file cannot be decoded as an image.
    """

    def __init__(self, files, scale=SCALE, stride=30):
        self.scale = scale
        image_windows = []
        for path in files:
            bgr = cv2.imread(path)
            if bgr is None:  # cv2.imread returns None instead of raising
                raise ValueError(f"could not read image file: {path!r}")
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]
            # "+ 1" keeps the last window when it ends exactly on the image
            # border (e.g. a 60 px side yields offsets 0 and 30, not just 0).
            for i in range(0, h - stride + 1, stride):
                for j in range(0, w - stride + 1, stride):
                    image_windows.append(img[i:i + stride, j:j + stride, :])
        # List of (stride, stride, 3) RGB patches.
        self.files = image_windows

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` float32 arrays of shape (3, H, W)."""
        scale = self.scale
        hr = Image.fromarray(self.files[idx])
        hr_width = (hr.width // scale) * scale
        hr_height = (hr.height // scale) * scale
        # Crop HR to a multiple of `scale` so it is exactly the size of the
        # re-upscaled LR image; otherwise the pair's shapes would disagree.
        hr = hr.crop((0, 0, hr_width, hr_height))
        lr = hr.filter(ImageFilter.GaussianBlur(radius=2))
        lr = lr.resize((hr_width // scale, hr_height // scale), resample=Image.BICUBIC)
        lr = lr.resize((lr.width * scale, lr.height * scale), resample=Image.BICUBIC)
        hr = np.moveaxis(np.array(hr).astype(np.float32), 2, 0)
        lr = np.moveaxis(np.array(lr).astype(np.float32), 2, 0)
        return lr, hr

In [None]:
# Glob pattern (not a regex) for the training images. Sorting makes the file
# order, and hence dataset indexing, independent of the filesystem.
train_folder_regex = './Train/*'
train_files = sorted(glob.glob(train_folder_regex))
if not train_files:
    raise FileNotFoundError(f"no training images match {train_folder_regex!r}")
train_dataset = HeadDataset(train_files)
# Pinned memory only speeds up host->GPU copies; without CUDA it has no effect.
train_data = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True,
                        num_workers=8, pin_memory=torch.cuda.is_available())

In [None]:
# Full-size RGB test images, each stored as a float32 tensor of shape
# (1, 3, H, W) with values in [0, 255] (i.e. a batch of one).
test_data = []
for image in sorted(glob.glob('./Test/*/*.bmp')):
    bgr = cv2.imread(image)
    if bgr is None:  # cv2.imread returns None instead of raising
        raise ValueError(f"could not read test image: {image!r}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    test_data.append(torch.tensor(np.moveaxis(rgb, 2, 0)).unsqueeze(0))
if not test_data:
    # Fail here rather than after a full training epoch.
    raise FileNotFoundError("no test images match './Test/*/*.bmp'")

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

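## Network architectures

Two fully convolutional candidates follow: `Net` (3 layers) and `Net2` (6 layers). Only `Net` is trained below.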

In [None]:
class Net(nn.Module):
    """Three-layer fully convolutional network: 3 -> 128 -> 64 -> 3 channels.

    Every conv has a 5x5 kernel with padding 2, so spatial size is preserved.
    A ReLU follows every layer, including the last, so outputs are
    non-negative.
    """

    def __init__(self):
        super().__init__()
        # Creation order fixes both parameter-initialisation order and
        # state_dict key order (conv1, conv2, conv3).
        self.conv1 = nn.Conv2d(3, 128, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=5, padding=2)

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2, self.conv3):
            out = F.relu(conv(out))
        return out


class Net2(nn.Module):
    """Six-layer fully convolutional network.

    Channels go 3 -> 128 -> 256 -> 512 -> 256 -> 128 -> 3. Every conv has a
    5x5 kernel with padding 2, so spatial size is preserved. A ReLU follows
    each layer, including the output layer.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 128, 256, 512, 256, 128, 3)
        # setattr on an nn.Module registers the submodule, so this creates
        # conv1 ... conv6 in order.
        for n, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{n}", nn.Conv2d(c_in, c_out, kernel_size=5, padding=2))

    def forward(self, x):
        # children() yields conv1 ... conv6 in registration order.
        for conv in self.children():
            x = F.relu(conv(x))
        return x

In [None]:
def PSNR(original, compressed, max_pixel=255.0):
    """Peak signal-to-noise ratio, in dB, between two arrays.

    Args:
        original: reference array (any numeric dtype).
        compressed: array to compare; must broadcast against ``original``.
        max_pixel: largest possible pixel value (255 for 8-bit images).

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))``, or 100 when the inputs are
        identical (MSE == 0), where the true PSNR would be infinite.
    """
    # Work in float64: subtracting/squaring uint8 arrays would wrap modulo
    # 256 and silently corrupt the MSE.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(max_pixel / np.sqrt(mse))

In [None]:
# Optimisation settings.
num_epochs, lr = 20, 1e-4

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
criterion = nn.MSELoss()

# conv1/conv2 use the base rate `lr`; the final layer conv3 uses a 10x smaller one.
optimizer = optim.Adam(
    [
        {"params": model.conv1.parameters()},
        {"params": model.conv2.parameters()},
        {"params": model.conv3.parameters(), "lr": lr * 0.1},
    ],
    lr=lr,
)

# Trackers for the best-scoring epoch and per-epoch history
# (the names suggest PSNR-based model selection).
best_epoch, best_psnr = 0, 0.0
best_weights = copy.deepcopy(model.state_dict())

total_loss_list, avg_psnr_list = [], []
device

In [None]:
def _to_batch(pil_img):
    """Convert an RGB PIL image to a float32 tensor of shape (1, 3, H, W)."""
    chw = np.moveaxis(np.asarray(pil_img, dtype=np.float32), 2, 0)
    return torch.tensor(chw).unsqueeze(0)


def _to_display(batch):
    """Convert a (1, 3, H, W) CPU tensor to an H x W x 3 uint8 array.

    Values are clipped to [0, 255] first so they don't wrap when cast.
    """
    return np.moveaxis(batch[0].numpy(), 0, 2).clip(0, 255).astype(np.uint8)


def _degrade_for_eval(hr_batch, scale=SCALE):
    """Build an ``(lr, hr)`` evaluation pair from a full-size test image.

    Applies the same degradation as ``HeadDataset.__getitem__``: Gaussian
    blur (radius 2), bicubic downscale by ``scale``, then bicubic upscale
    back. The model is therefore scored on the kind of input it was trained
    on. H and W are first cropped to multiples of ``scale``; both returned
    tensors have shape (1, 3, H', W').
    """
    hr_np = _to_display(hr_batch)
    h = (hr_np.shape[0] // scale) * scale
    w = (hr_np.shape[1] // scale) * scale
    hr_img = Image.fromarray(np.ascontiguousarray(hr_np[:h, :w]))
    lr_img = hr_img.filter(ImageFilter.GaussianBlur(radius=2))
    lr_img = lr_img.resize((w // scale, h // scale), resample=Image.BICUBIC)
    lr_img = lr_img.resize((w, h), resample=Image.BICUBIC)
    return _to_batch(lr_img), _to_batch(hr_img)


# Degrade the test images once, up front, instead of on every epoch.
eval_pairs = [_degrade_for_eval(img) for img in test_data]

for epoch in range(num_epochs):
    # ---- train ----
    model.train()
    total_error = 0.0
    for inputs, labels in train_data:
        inputs = inputs.to(device)
        labels = labels.to(device)

        preds = model(inputs)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_error += loss.item()

    # ---- evaluate ----
    model.eval()
    total_psnr = 0.0
    with torch.no_grad():  # no autograd graphs during evaluation
        for k, (lr_batch, hr_batch) in enumerate(eval_pairs):
            # Clamp to the valid pixel range before scoring/displaying.
            preds = model(lr_batch.to(device)).clamp(0, 255).cpu()
            if k == 0:
                panels = [_to_display(t) for t in (lr_batch, preds, hr_batch)]
                plt.imshow(np.hstack(panels))
                plt.title(f"epoch {epoch}: input | prediction | target")
                plt.axis('off')
                plt.show()
            total_psnr += PSNR(hr_batch.numpy(), preds.numpy())

    avg_psnr = total_psnr / len(eval_pairs) if eval_pairs else float('nan')
    total_loss_list.append(total_error)
    avg_psnr_list.append(avg_psnr)
    if avg_psnr > best_psnr:
        best_psnr = avg_psnr
        best_epoch = epoch
        best_weights = copy.deepcopy(model.state_dict())

    print(epoch, avg_psnr, total_error)

print(f"best epoch: {best_epoch} (avg PSNR {best_psnr:.2f} dB)")