## Preparing

In [1]:
import torch
import numpy as np
import cv2
import sys
sys.path.append('../')
from models.network_swinir import SwinIR as net

In [2]:
# def define_model(task,model_path,scale):

#     # 001 classical image sr
#     if task == 'classical_sr':
#         model = net(upscale=scale, in_chans=3, img_size=48, window_size=8,
#                     img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
#                     mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
#         param_key_g = 'params'

#     elif task == 'lightweight_sr':
#         # 002 lightweight image sr
#         # use 'pixelshuffledirect' to save parameters
#         model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
#                     img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
#                     mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
#         param_key_g = 'params'

#     pretrained_model = torch.load(model_path)
#     model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

#     return model

In [3]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import random

class SRDataset(Dataset):
    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        self.hr_filenames = sorted(os.listdir(self.hr_path))

    def __getitem__(self, index):
        hr_image = cv2.imread(os.path.join(self.hr_path, self.hr_filenames[index]))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = hr_image.shape

        ## make sure same demension
        h -= h % self.upscale_factor
        w -= w % self.upscale_factor
        hr_image = hr_image[:h, :w]

        lr_image = cv2.resize(hr_image, (int(w // self.upscale_factor),int(h // self.upscale_factor)), interpolation=cv2.INTER_LINEAR)


        if random.random() > 0.5:  
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)



        return lr_image, hr_image

    def __len__(self):
        return len(self.hr_filenames)

In [4]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

upscale = 2
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)


train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
device = 'cuda'

## scale = 2/3/4
model = net(upscale=upscale, in_chans=3, img_size=64, window_size=8,
            img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
            mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv').to(device)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
## load the pretrained model (if have one)

# model_path = 'outputs/SwinIR_x2_2.pth'
# pretrained_model = torch.load(model_path)
# model.load_state_dict(pretrained_model)

<All keys matched successfully>

In [6]:
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train

In [None]:
import sys
## ignore warnings

class HiddenPrints:
    def write(self, msg):
        pass

try:
    sys.stderr = HiddenPrints()
except:
    pass




In [None]:
from tqdm import tqdm
import sys


start_epoch = 2
num_epochs = 5

for epoch in range(start_epoch,num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())
        
        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()

    print("\n")
    if (epoch+1) % 1 == 0:
        torch.save(model.state_dict(), 'outputs/SwinIR_x2_{}.pth'.format(epoch+1))


Epoch [3/5], Batch [19891/38997], Loss: 0.0107

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [3/5], Batch [38997/38997], Loss: 0.0053

Epoch [4/5], Batch [2837/38997], Loss: 0.0098

## Test

In [10]:
import torch
device = 'cuda'
model_path = r'outputs/SwinIR_1.pth'
scale = 2
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
            img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
            mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv').to(device)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

In [11]:
def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

In [12]:
import cv2 as cv2
import numpy as np
import os

test_psnr = 0
interpolate_psnr = 0
image_count = 0

test_path = "data/Set14/"

for filename in os.listdir(test_path):
    if filename.endswith(".png"):

        image_count+=1
        file_path = os.path.join(test_path, filename)
        hr_img = cv2.imread(file_path, cv2.IMREAD_COLOR).astype(np.float32) 
        h, w, _ = hr_img.shape
        
        lr_image = cv2.resize(hr_img, (w // 2, h // 2), interpolation=cv2.INTER_LINEAR) 
        lr_image = lr_image/255.
        
        lr_image = torch.from_numpy(np.transpose(lr_image[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        lr_image = lr_image.unsqueeze(0).to(device)
        
        
        with torch.no_grad():
            output = model(lr_image)
        
        hr_img = hr_img/255.
        hr_img = torch.from_numpy(np.transpose(hr_img[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        hr_img = hr_img.unsqueeze(0).to(device)
        psnr = calc_psnr(hr_img, output)
        test_psnr += psnr
        
        psnr = calc_psnr(hr_img, lr_image)
        interpolate_psnr+=psnr

test_psnr = test_psnr/image_count
interpolate_psnr = interpolate_psnr/image_count




In [13]:
print('test psnr: {:.2f}'.format(test_psnr))
print('interpolate psnr: {:.2f}'.format(interpolate_psnr))

test psnr: 29.52
interpolate psnr: 27.51
