In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
import torch.nn.functional as F
import math
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as PSNR
from SSIM_PIL import compare_ssim
from PIL import Image
data_pth = r"Path"


In [2]:
device= "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device:{device}")

Device:cuda


In [3]:
lr_pth, hr_pth = os.path.join(data_pth, os.listdir(data_pth)[2]),os.path.join(data_pth, os.listdir(data_pth)[1]) 

In [4]:
def lr_hr_data(lr_pth, hr_pth):
    lr_samples= os.listdir(lr_pth)
    hr_samples= os.listdir(hr_pth)
    data= []
    for i, j in zip(lr_samples, hr_samples):
        data.append([os.path.join(lr_pth, i),os.path.join(hr_pth, j) ])
    return data
lr_hr_data= lr_hr_data(lr_pth, hr_pth)

In [5]:
def normalize(x):
        x_min, x_max= x.min(), x.max()
        x_norm= (x-x_min)/(x_max-x_min)
        return x_norm

class dataset:
    def __init__(self, data, lr_dim, hr_dim):
        self.data= data
        self.lr_dim = lr_dim
        self.hr_dim = hr_dim 
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        lr_img, hr_img= np.load(self.data[idx][0]),np.load(self.data[idx][1])
        lr_img= np.reshape(lr_img, (self.lr_dim, self.lr_dim, 1))
        hr_img= np.reshape(hr_img, (self.hr_dim, self.hr_dim, 1))
        lr_img= cv2.equalizeHist(((lr_img+1)*127.5).astype(np.uint8))
        hr_img= cv2.equalizeHist(((hr_img+1)*127.5).astype(np.uint8))
        lr_img= normalize(lr_img)
        hr_img= normalize(hr_img)
        lr_img =  np.reshape(lr_img, (1, self.lr_dim, self.lr_dim))
        lr_img = torch.tensor(lr_img, dtype = torch.float32)
        hr_img =  np.reshape(hr_img, (1, self.hr_dim, self.hr_dim))
        hr_img = torch.tensor(hr_img, dtype = torch.float32)
        return lr_img, hr_img

In [6]:
class SRCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net= nn.Sequential(
            nn.Conv2d(1, 64,kernel_size= 9, padding = 4, stride= 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size= 1, padding = 0, stride= 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size= 5, padding = 2, stride= 1)
        )
    def forward(self, x):
        return self.net(x)

In [7]:
def upscale_img(img):
    img= F.interpolate(img, img.shape[2]*2, mode='bilinear', align_corners=False)
    return img

In [8]:
model = SRCNN().to(device)
epochs = 20
lr1= 1e-4
lr2= 1e-5
loss_fn = nn.MSELoss().to(device)
optimizer= torch.optim.SGD(
    [
        {'params': model.net[0].parameters(), 'lr': lr1},
        {'params': model.net[2].parameters(), "lr":lr1},
        {'params': model.net[4].parameters(), "lr":lr2}    
    ]
)
batch_size= 32

In [9]:
data= dataset(lr_hr_data, 75, 150)
train_data, val_data= torch.utils.data.random_split(data,[8000,2000])
train_loader= DataLoader(train_data, batch_size= batch_size, shuffle=True)
val_loader= DataLoader(val_data, batch_size, shuffle= True)

In [10]:
for epoch in range(epochs):
    print(f"epoch: {epoch+1}")
    for batch, (lr, hr) in enumerate(train_loader):
        lr= lr.to(device)
        hr= hr.to(device)
        upscaled= upscale_img(lr)
        y_cap= model(upscaled)
        loss= loss_fn(y_cap, hr)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch%100==0:
            print(f"Batch:{batch} | Loss:{loss.item()}")

epoch: 1
Batch:0 | Loss:0.4336123764514923
Batch:100 | Loss:0.41895008087158203
Batch:200 | Loss:0.4006974995136261
epoch: 2
Batch:0 | Loss:0.39290887117385864
Batch:100 | Loss:0.37684518098831177
Batch:200 | Loss:0.3556709587574005
epoch: 3
Batch:0 | Loss:0.3494839370250702
Batch:100 | Loss:0.332652747631073
Batch:200 | Loss:0.31448403000831604
epoch: 4
Batch:0 | Loss:0.30587702989578247
Batch:100 | Loss:0.2906138002872467
Batch:200 | Loss:0.2727082371711731
epoch: 5
Batch:0 | Loss:0.2697928547859192
Batch:100 | Loss:0.25124868750572205
Batch:200 | Loss:0.23619206249713898
epoch: 6
Batch:0 | Loss:0.23225603997707367
Batch:100 | Loss:0.21591514348983765
Batch:200 | Loss:0.20129628479480743
epoch: 7
Batch:0 | Loss:0.1911613494157791
Batch:100 | Loss:0.17715024948120117
Batch:200 | Loss:0.16523927450180054
epoch: 8
Batch:0 | Loss:0.15516677498817444
Batch:100 | Loss:0.1416490525007248
Batch:200 | Loss:0.12740465998649597
epoch: 9
Batch:0 | Loss:0.11900666356086731
Batch:100 | Loss:0.1062

In [11]:
def cvtToImage(img):
    image= torch.permute(img, (1,2,0))
    image= image.squeeze(-1)
    image= np.asarray((image+1)*127.5,dtype= np.uint8)
    return image
def Psnr(mse):
    return 10*torch.log10(255**2/ torch.tensor(mse))

In [12]:
size= len(val_loader)
test_loss= 0
psnr=0
final_img= []
hr_img=[]
with torch.no_grad():
    for lr,hr in val_loader:
        lr= lr.to(device)
        hr= hr.to(device)
        upscaled= upscale_img(lr)
        sr= model(upscaled)
        test_loss+=loss_fn(sr, hr).item()
        psnr+= Psnr(test_loss)
        final_img.append(sr)
        hr_img.append(hr)
psnr/= size
test_loss/=size
print(f"Loss: {test_loss} | PSNR: {psnr}")

Loss: 0.02073234266468457 | PSNR: 51.09543991088867


In [13]:
torch.save(model.state_dict(), "Srcnn_weights.pth")

In [14]:
sr_img= []
GT_img= []
for i in final_img:
    for j in i:
       sr_img.append(j)

for i in hr_img:
    for j in i:
      GT_img.append(j)

In [15]:
img_a= sr_img[1].cpu().squeeze(0)
img_b= GT_img[1].cpu().squeeze(0)

In [18]:
ssim =0
for sr, gt in zip(sr_img, GT_img):
    sr= sr.cpu()
    hr= gt.cpu()
    sr_pil= Image.fromarray(cvtToImage(sr))
    hr_pil = Image.fromarray(cvtToImage(hr))
    ssim+= compare_ssim(sr_pil, hr_pil, GPU= False)
ssim/= len(sr_img)
print(f"Loss: {test_loss} | PSNR: {psnr} | Average SSIM: {ssim}")

Loss: 0.02073234266468457 | PSNR: 51.09543991088867 | Average SSIM: 0.43656901973842915
