## Create training data

In [None]:
## Crop 64*64 patches, which will be used during training
import os
from PIL import Image

input_folder = "./data/PlantSR_dataset/train"
output_folder = "./data/PlantSR_dataset/train_64"


window_size = (64, 64)
stride = (32, 32) 

os.makedirs(output_folder, exist_ok=True)

for root, _, files in os.walk(input_folder):
    for file_name in files:
        img_path = os.path.join(root, file_name)
        high_res_img = Image.open(img_path)
        
        img_width, img_height = high_res_img.size

        for y in range(0, img_height - window_size[1] + 1, stride[1]):
            for x in range(0, img_width - window_size[0] + 1, stride[0]):
                cropped_img = high_res_img.crop((x, y, x + window_size[0], y + window_size[1]))
                output_file_name = f"{file_name}_crop_{x}_{y}.png" 
                output_file_path = os.path.join(output_folder, output_file_name)
                cropped_img.save(output_file_path)

## load data

In [1]:
import torch
import torch.nn as nn
from torch import nn,optim
from torch.backends import cudnn
from torch.utils.data.dataloader import DataLoader
from math import sqrt
import os
import copy
import numpy as np
from torch.utils.data import Dataset

In [2]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import random

class SRDataset(Dataset):
    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        self.hr_filenames = sorted(os.listdir(self.hr_path))

    def __getitem__(self, index):
        hr_image = cv2.imread(os.path.join(self.hr_path, self.hr_filenames[index]))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = hr_image.shape

        ## make sure same demension
        h -= h % self.upscale_factor
        w -= w % self.upscale_factor
        hr_image = hr_image[:h, :w]

        lr_image = cv2.resize(hr_image, (int(w // self.upscale_factor),int(h // self.upscale_factor)), interpolation=cv2.INTER_LINEAR)


        if random.random() > 0.5:  
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)



        return lr_image, hr_image

    def __len__(self):
        return len(self.hr_filenames)

## train

In [3]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

upscale= 2
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

device = 'cuda:0'

In [5]:
## upscale = 2/3/4
from PlantSR import PlantSR

outPath = "outputs"
lr = 1e-4  

device = torch.device('cuda:0')
if upscale == 4:
    model = PlantSR(scale=upscale,num_features=96,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 2:
    model = PlantSR(scale=upscale,num_features=32,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 3:
    model = PlantSR(scale=upscale,num_features=64,n_resgroups=16,n_resblocks=4,reduction=16)
model.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(),lr=lr)

In [6]:
# load the pretrained model (if have one)
model_path = 'ckpts/PlantSR_x2_best.pth'
model.load_state_dict(torch.load(model_path), strict=True)

<All keys matched successfully>

In [None]:
from tqdm import tqdm
import sys


start_epoch = 0
num_epochs = 15

for epoch in range(start_epoch,num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())

        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()
        # if(batch_idx%10000) == 0:
        #     torch.save(model.state_dict(), 'outputs/PlantSR_x2_{}_{}.pth'.format(batch_idx,epoch+1))

    print("\n")
    if (epoch+1) % 1 == 0:
        torch.save(model.state_dict(), 'outputs/PlantSR_x2_{}.pth'.format(epoch+1))
    

Epoch [1/15], Batch [22934/38997], Loss: 0.0151

## Test

In [None]:
from PlantSR import PlantSR
# from PLtest import PlantSR
import torch

upscale = 2
device = 'cuda'
# model_path = r'outputs/PlantSR_x4_20000_10.pth'
model_path = r'outputs/PlantSR_x2_1.pth'
device = torch.device('cuda:0')
if upscale == 4:
    model = PlantSR(scale=upscale,num_features=96,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 2:
    model = PlantSR(scale=upscale,num_features=32,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 3:
    model = PlantSR(scale=upscale,num_features=64,n_resgroups=16,n_resblocks=4,reduction=16)

model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

In [None]:
import cv2 as cv2
import numpy as np
import torch.nn.functional as F
from calulate_psnr_ssim import *
import os

test_psnr = 0
test_ssim = 0
image_count = 0

test_path = "./data/PlantSR_dataset/test"

for filename in os.listdir(test_path):
    if filename.endswith((".png",".jpg")):
        image_count+=1
        print(image_count)
        file_path = os.path.join(test_path, filename)
        
        hr_img = cv2.imread(file_path, cv2.IMREAD_COLOR).astype(np.float32) 
        h, w, _ = hr_img.shape

        ## make sure same dimension
        h -= h % upscale
        w -= w % upscale
        hr_img = hr_img[:h, :w]
        
        lr_image = cv2.resize(hr_img, (w // upscale, h // upscale), interpolation=cv2.INTER_LINEAR) 
        lr_image = lr_image/255.
        lr_image = torch.from_numpy(np.transpose(lr_image[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        lr_image = lr_image.unsqueeze(0).to(device)
        
        
        with torch.no_grad():
            output = model(lr_image)

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0)

        # if (output.shape[2] != hr_img.shape[2]) or (output.shape[3] != hr_img.shape[3]):
        #     output = F.interpolate(output, size=(h, w), mode='bilinear', align_corners=False)
            
        psnr = calc_psnr(hr_img, output)
        ssim = calc_ssim(hr_img, output)
        test_psnr += psnr
        test_ssim += ssim

test_psnr = test_psnr/image_count
test_ssim = test_ssim/image_count

In [None]:
print('test psnr: {:.2f}'.format(test_psnr))
print('test ssim: {:.4f}'.format(test_ssim))

## Inference

In [None]:
from PlantSR import PlantSR
import torch

upscale = 4
device = 'cuda'
model_path = r'outputs/PlantSR_x4_best.pth'

device = torch.device('cuda:0')
if upscale == 4:
    model = PlantSR(scale=upscale,num_features=96,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 2:
    model = PlantSR(scale=upscale,num_features=32,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 3:
    model = PlantSR(scale=upscale,num_features=64,n_resgroups=16,n_resblocks=4,reduction=16)

model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

In [None]:
import os
import cv2
import numpy as np

input_folder = "data/PlantSR_dataset/YourData"
output_folder = "data/PlantSR_dataset/YourDatax2"

os.makedirs(output_folder, exist_ok=True)

for filename in os.listdir(input_folder):
    if filename.endswith((".jpg", ".jpeg", ".png")):
        img_path = os.path.join(input_folder, filename)

        img = cv2.imread(img_path, cv2.IMREAD_COLOR).astype(np.float32) /255.
        h, w, _ = img.shape
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img)

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)

        save_path = os.path.join(output_folder, filename)

        cv2.imwrite(save_path, output)