In [76]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import cv2
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.models import resnet18, resnet50
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from SSIM_PIL import compare_ssim
from PIL import Image

In [None]:
pth = r"path to data"

In [28]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device:{device}")

Device:cuda


In [29]:
class dataset:
    def __init__(self, path):
        self.pth = path
        
        self.hr_pth = os.path.join(self.pth, "HR")
        self.lr_pth =os.path.join(self.pth, "LR")
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip()
        ])
    def __len__(self):
        return len(self.hr_pth)
    def __getitem__(self, idx):
        hr_pth = os.path.join(self.hr_pth,os.listdir(self.hr_pth)[idx])
        lr_pth = os.path.join(self.lr_pth, os.listdir(self.lr_pth)[idx])
        hr_img, lr_img= np.load(hr_pth), np.load(lr_pth)
        hr_img = torch.tensor(hr_img, dtype=torch.float32)
        lr_img= torch.tensor(lr_img, dtype=torch.float32)
        hr_img= self.transform(hr_img)
        lr_img= self.transform(lr_img)
        return hr_img, lr_img

In [30]:
data= dataset(pth)
batch_size= 32

In [31]:
train_size = int(0.9 * len(data))
test_size = len(data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])

In [32]:
train_loader= DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader= DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [33]:
class SRCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net= nn.Sequential(
            nn.Conv2d(1, 64,kernel_size= 9, padding = 4, stride= 1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size= 1, padding = 0, stride= 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size= 5, padding = 2, stride= 1)
        )
    def forward(self, x):
        return self.net(x)

In [34]:
def upscale_img(img):
    img= F.interpolate(img, img.shape[2]*2, mode='bilinear', align_corners=False)
    return img

In [43]:
x= train_dataset[0][1]
x.shape

torch.Size([1, 75, 75])

In [35]:
model = SRCNN().to(device)
epochs = 20
lr1= 1e-4
lr2= 1e-5
loss_fn = nn.MSELoss().to(device)
optimizer= torch.optim.SGD(
    [
        {'params': model.net[0].parameters(), 'lr': lr1},
        {'params': model.net[2].parameters(), "lr":lr1},
        {'params': model.net[4].parameters(), "lr":lr2}    
    ]
)

In [None]:
weight_pth= r"path to weights"

In [37]:
model.load_state_dict(torch.load(weight_pth, weights_only=True))

<All keys matched successfully>

In [44]:
for epoch in range(epochs):
    print(f"epoch: {epoch+1}")
    for batch, (hr, lr) in enumerate(train_loader):
        lr= lr.to(device)
        hr= hr.to(device)
        upscaled= upscale_img(lr)
        y_cap= model(upscaled)
        loss= loss_fn(y_cap, hr)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch%100==0:
            print(f"Batch:{batch} | Loss:{loss.item()}")

epoch: 1
Batch:0 | Loss:0.017433471977710724
epoch: 2
Batch:0 | Loss:0.01658250391483307
epoch: 3
Batch:0 | Loss:0.014951993711292744
epoch: 4
Batch:0 | Loss:0.017611544579267502
epoch: 5
Batch:0 | Loss:0.016232216730713844
epoch: 6
Batch:0 | Loss:0.015912851318717003
epoch: 7
Batch:0 | Loss:0.014369502663612366
epoch: 8
Batch:0 | Loss:0.01652359776198864
epoch: 9
Batch:0 | Loss:0.015605289489030838
epoch: 10
Batch:0 | Loss:0.016770128160715103
epoch: 11
Batch:0 | Loss:0.015839798375964165
epoch: 12
Batch:0 | Loss:0.01557568646967411
epoch: 13
Batch:0 | Loss:0.015284288674592972
epoch: 14
Batch:0 | Loss:0.016293812543153763
epoch: 15
Batch:0 | Loss:0.01557729672640562
epoch: 16
Batch:0 | Loss:0.017661597579717636
epoch: 17
Batch:0 | Loss:0.01477002538740635
epoch: 18
Batch:0 | Loss:0.015736624598503113
epoch: 19
Batch:0 | Loss:0.016701694577932358
epoch: 20
Batch:0 | Loss:0.01566552370786667


In [109]:
def cvtToImage(img):
    image= torch.permute(img, (1,2,0))
    image=image.squeeze(-1)
    image= np.asarray((image+1)*127.5, dtype=np.uint8)#.astype(np.uint8)
    return image
def Psnr(mse):
    return 10*torch.log10(255**2/ torch.tensor(mse))

In [47]:
size= len(test_loader)
test_loss= 0
psnr=0
final_img= []
hr_img=[]
with torch.no_grad():
    for hr,lr in test_loader:
        lr= lr.to(device)
        hr= hr.to(device)
        upscaled= upscale_img(lr)
        sr= model(upscaled)
        test_loss+=loss_fn(sr, hr).item()
        psnr+= Psnr(test_loss)
        final_img.append(sr)
        hr_img.append(hr)
psnr/= size
test_loss/=size
print(f"Loss: {test_loss} | PSNR: {psnr}")

Loss: 0.018980421125888824 | PSNR: 65.34774017333984


In [48]:
sr_img= []
GT_img= []
for i in final_img:
    for j in i:
       sr_img.append(j)

for i in hr_img:
    for j in i:
      GT_img.append(j)

img_a= sr_img[1].cpu().squeeze(0)
img_b= GT_img[1].cpu().squeeze(0)

In [118]:
ssim =0
for sr, gt in zip(sr_img, GT_img):
    sr= sr.cpu()
    hr= gt.cpu()
    sr_pil= Image.fromarray(cvtToImage(sr))
    hr_pil = Image.fromarray(cvtToImage(hr))
    ssim+= compare_ssim(sr_pil, hr_pil, GPU= False)
ssim/= len(sr_img)
print(f"Loss: {test_loss} | PSNR: {psnr} | Average SSIM: {ssim}")

Loss: 0.018980421125888824 | PSNR: 65.34774017333984 | Average SSIM: 0.8911372193576739
