In [None]:
import sys
sys.path.append('../')

## Load data

In [1]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import random

class SRDataset(Dataset):
    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        self.hr_filenames = sorted(os.listdir(self.hr_path))

    def __getitem__(self, index):
        hr_image = cv2.imread(os.path.join(self.hr_path, self.hr_filenames[index]))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = hr_image.shape

        ## make sure same demension
        h -= h % self.upscale_factor
        w -= w % self.upscale_factor
        hr_image = hr_image[:h, :w]

        lr_image = cv2.resize(hr_image, (int(w // self.upscale_factor),int(h // self.upscale_factor)), interpolation=cv2.INTER_LINEAR)

        ## data enhancement
        if random.random() > 0.5:  
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)

        return lr_image, hr_image

    def __len__(self):
        return len(self.hr_filenames)

In [10]:
from models.EDSR import EDSR,Upsample

In [2]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from models.EDSR import EDSR,Upsample

##  upscale = 2/3/4
upscale= 3
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

device = 'cuda:0'

## EDSR_M
model = EDSR(
    num_in_ch=3, num_out_ch=3,num_feat=64,num_block=16,upscale=upscale,res_scale=1).to(device)

In [3]:
## load the pretrained model (if have one)

# model_path = r'models/EDSR/EDSRx2.pth'
# model.load_state_dict(torch.load(model_path)["params"], strict=True)

In [4]:
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train

In [None]:
import sys
## ignore warnings

class HiddenPrints:
    def write(self, msg):
        pass

try:
    sys.stderr = HiddenPrints()
except:
    pass




In [None]:
from tqdm import tqdm
import sys

start_epoch = 0
num_epochs = 10

for epoch in range(start_epoch,num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())

        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()

    print("\n")
    if (epoch+1) % 1 == 0:
        ## x2 x3 x4
        torch.save(model.state_dict(), 'outputs/edsr_x3_{}.pth'.format(epoch+1))


Epoch [1/10], Batch [38997/38997], Loss: 0.0125

Epoch [2/10], Batch [38997/38997], Loss: 0.0204

Epoch [3/10], Batch [38997/38997], Loss: 0.0057

Epoch [4/10], Batch [38997/38997], Loss: 0.0278

Epoch [5/10], Batch [1318/38997], Loss: 0.0203

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [8/10], Batch [38997/38997], Loss: 0.0330

Epoch [9/10], Batch [38997/38997], Loss: 0.0103

Epoch [10/10], Batch [10975/38997], Loss: 0.0205

## Test

In [5]:
from models.EDSR import EDSR
import torch
device = 'cuda'
model_path = r'outputs/edsr_30.pth'
model = EDSR(
    num_in_ch=3, num_out_ch=3,num_feat=64,num_block=16,upscale=2,res_scale=1)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

In [6]:
def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

import torch.nn.functional as F
def calc_ssim(img1, img2):
    return torch.mean(F.mse_loss(img1, img2))

In [9]:
import cv2 as cv2
import numpy as np
import os

test_psnr = 0
interpolate_psnr = 0
image_count = 0

test_path = "Your data Path"

for filename in os.listdir(test_path):
    if filename.endswith((".png",".jpg")):

        image_count+=1
        file_path = os.path.join(test_path, filename)
        hr_img = cv2.imread(file_path, cv2.IMREAD_COLOR).astype(np.float32) 
        h, w, _ = hr_img.shape
        
        lr_image = cv2.resize(hr_img, (w // 2, h // 2), interpolation=cv2.INTER_LINEAR) 
        lr_image = lr_image/255.
        lr_image = torch.from_numpy(np.transpose(lr_image[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        lr_image = lr_image.unsqueeze(0).to(device)
        
        
        with torch.no_grad():
            output = model(lr_image)
        
        hr_img = hr_img/255.
        hr_img = torch.from_numpy(np.transpose(hr_img[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        hr_img = hr_img.unsqueeze(0).to(device)
        psnr = calc_psnr(hr_img, output)
        test_psnr += psnr

test_psnr = test_psnr/image_count

In [10]:
print('test psnr: {:.2f}'.format(test_psnr))
print('interpolate psnr: {:.2f}'.format(interpolate_psnr))

test psnr: 29.78
interpolate psnr: 27.51


## Predict

In [7]:
device = 'cuda:0'
model_path ='outputs/edsr_30.pth'
model = EDSR(
    num_in_ch=3, num_out_ch=3,num_feat=64,num_block=16,upscale=2,res_scale=1)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

In [25]:
import os
import cv2
import numpy as np

input_folder = ""
output_folder = ""

os.makedirs(output_folder, exist_ok=True)

for filename in os.listdir(input_folder):
    if filename.endswith((".jpg", ".jpeg", ".png")):
        img_path = os.path.join(input_folder, filename)

        img = cv2.imread(img_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img)

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)

        save_path = os.path.join(output_folder, filename)

        cv2.imwrite(save_path, output)

True