# Imports

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm

from os.path import join
from os import listdir

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, Grayscale, GaussianBlur
from PIL import Image

import matplotlib.pyplot as plt
from torchvision.utils import save_image

# Dataset

In [None]:
class TrainingDataset(Dataset):
    """Pairs of (blurred, sharp) grayscale patches cut from a folder of images.

    Each item is built by taking a random ``crop`` x ``crop`` patch of an
    image and converting it to a 1-channel tensor (the high-resolution
    target). That patch is then Gaussian-blurred to give the degraded input.

    Args:
        folder: directory whose files are all images loadable by PIL. Every
            image must be at least ``crop`` pixels in each dimension, because
            ``RandomCrop`` is used without padding.
        blur_size: kernel size passed to ``GaussianBlur`` (torchvision
            requires it to be odd and positive).
        crop: side length of the square training patches.
    """

    def __init__(self, folder, blur_size, crop):
        super().__init__()
        # Sort for a reproducible index -> file mapping. Skip sub-directories,
        # which Image.open cannot read.
        candidates = (join(folder, name) for name in sorted(listdir(folder)))
        self.image_filenames = [path for path in candidates if os.path.isfile(path)]

        # Sharp target: random square patch -> single-channel float tensor.
        self.hr_transform = Compose([
            RandomCrop(crop),
            Grayscale(num_output_channels=1),
            ToTensor(),
        ])

        # Degraded input: blur the already-cropped target patch. The patch is
        # round-tripped through PIL because GaussianBlur is applied to a PIL
        # image here.
        self.lr_transform = Compose([
            ToPILImage(),
            GaussianBlur(blur_size),
            ToTensor(),
        ])

    def __getitem__(self, i):
        """Return ``(lr_image, hr_image)`` tensors for the i-th file."""
        # Close the file as soon as the crop has been materialised. Without
        # the context manager, PIL keeps the handle open until the image
        # object is garbage-collected.
        with Image.open(self.image_filenames[i]) as img:
            hr_image = self.hr_transform(img)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found in ``folder``."""
        return len(self.image_filenames)


In [None]:

# Training hyper-parameters.
CROP = 32          # side length (px) of the random square training patches
BATCH_SIZE = 100
EPOCHS = 250
BLUR = 3           # GaussianBlur kernel size used to synthesise the degraded input
LR = 1e-3          # Adam learning rate

train_data = TrainingDataset("data/train", blur_size=BLUR, crop=CROP)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)


# val_data = TrainingDataset("data/val", crop=CROP,blur_size=BLUR)
# val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)


# SRCNN Model

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN-style network: 9x9 -> 1x1 -> 5x5 convolutions.

    Every convolution is padded by ``kernel_size // 2``. With stride 1 and odd
    kernels, each layer therefore keeps the spatial size of its input, and
    output pixel (i, j) stays centred on input pixel (i, j).

    Args:
        num_channels: number of image channels in and out. The default of 1
            matches the grayscale patches built by ``TrainingDataset``.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        # Padding is k // 2 per layer. The 1x1 layer must not be padded:
        # padding it only duplicates edge features. Weight shapes are the
        # same as before, so existing state_dicts still load.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9,
                               padding=4, padding_mode='replicate')
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5,
                               padding=2, padding_mode='replicate')

    def forward(self, x):
        """Map an image batch to an output batch of the same shape."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.conv3(x)


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = SRCNN().to(device)

# Pixel-wise mean-squared error, optimised with Adam at learning rate LR.
loss_function = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)


def train(model, dataloader):
    """Run one training epoch and return the mean per-sample loss.

    Uses the module-level ``optimizer``, ``loss_function`` and ``device``.

    Args:
        model: network to optimise; it is switched to training mode here.
        dataloader: iterable of ``(lr_image, hr_image)`` batches with a
            ``len()`` (used by tqdm for the progress-bar total).

    Returns:
        The epoch-average loss. Each batch's loss is weighted by its batch
        size, which equals the per-sample mean when ``loss_function`` uses
        mean reduction (the ``nn.MSELoss()`` default). Returns ``nan`` if the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_samples = 0
    for lr_image, hr_image in tqdm(dataloader):
        lr_image = lr_image.to(device)
        hr_image = hr_image.to(device)

        outputs = model(lr_image)
        loss = loss_function(outputs, hr_image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss.item() is already a per-batch mean. Weight it by the batch size
        # so the final division by the sample count yields a true mean, and
        # so a smaller last batch is not over-weighted.
        batch_size = lr_image.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        return float("nan")
    return running_loss / n_samples


# Train model

In [None]:
# Train for EPOCHS passes over the data. After each pass, print the loss
# value that train() returns. Epoch numbers are shown 1-based.
for epoch in range(EPOCHS):
    epoch_header = f"Epoch {epoch + 1} of {EPOCHS}"
    print(epoch_header)
    train_epoch_loss = train(model, train_loader)
    print(train_epoch_loss)


# Save model

In [None]:
print('Saving model...')
# torch.save does not create missing parent directories. Create ./models
# first so a long training run is not lost to a missing folder.
os.makedirs('./models', exist_ok=True)
torch.save(model.state_dict(), './models/SRCNN2.pth')

# Testing the trained model

In [None]:
test_data_folder = "data/testing"
# Sorted so outputs are produced in a stable order.
image_filenames = [join(test_data_folder, x)
                   for x in sorted(listdir(test_data_folder))]

# Must match the path the "Save model" cell passes to torch.save, so the
# model evaluated here is the one just trained.
checkpoint_path = './models/SRCNN2.pth'

# Put the model on the same device the input tensors are moved to below.
# Otherwise CUDA inputs would meet CPU weights and fail. map_location lets a
# checkpoint saved on a GPU machine load on a CPU-only one.
model = SRCNN().to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# save_image does not create missing directories.
os.makedirs('./outputs', exist_ok=True)

model.eval()
with torch.no_grad():
    for filename in image_filenames:
        image = cv2.imread(filename, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            print(f"Skipping unreadable image: {filename}")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # (H, W) uint8 -> (1, 1, H, W) float32 in [0, 1]. This is the batched
        # single-channel layout that ToTensor() produces for training patches.
        image = (image / 255.0).astype(np.float32)
        image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).to(device)
        save_image(image, f"./outputs/test_{os.path.basename(filename)}")

        outputs = model(image).cpu()
        save_image(outputs, f"./outputs/out_{os.path.basename(filename)}")

In [None]:
# Notebook sanity check: the cell displays True when PyTorch can see a CUDA
# device. This is the same condition used to choose `device` for training.
torch.cuda.is_available() 