In [1]:
### Before start, you may need to learn something about super-resolution by collecting resources from Internet
### Super-resolution is an easy task and serves as a low-level task of computer vision
### The basic concept can be summarized in one sentence: reconstructing low resolution images into higher resolution images

## Load data    
You can also use prepare.ipynb to generate training data from original images.

In [2]:
from torch import nn,optim
from torch.backends import cudnn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import random
from math import sqrt

In [3]:
class SRDataset(Dataset):
    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        self.hr_filenames = sorted(os.listdir(self.hr_path))

    def __getitem__(self, index):
        hr_image = cv2.imread(os.path.join(self.hr_path, self.hr_filenames[index]))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = hr_image.shape

        ## make sure same demension
        h -= h % self.upscale_factor
        w -= w % self.upscale_factor
        hr_image = hr_image[:h, :w]

        lr_image = cv2.resize(hr_image, (int(w // self.upscale_factor),int(h // self.upscale_factor)), interpolation=cv2.INTER_LINEAR)


        if random.random() > 0.5:  
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)



        return lr_image, hr_image

    def __len__(self):
        return len(self.hr_filenames)

## Train


In [4]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

upscale= 4
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

device = 'cuda:0'

In [5]:
## This is the original useage of PlantSRv1, you need to adjust them accoarding to your PlantSRv2

## upscale = 2/3/4
from models.PlantSRv2 import PlantSR

outPath = "outputs"
lr = 1e-4  

device = torch.device('cuda:0')
if upscale == 4:
   model = PlantSR(
                 scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=96,
                 depth=[4,4,4,4,4],
                 num_heads=[4,4,4,4,4],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 2:
    model = PlantSR(
                 scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=32,
                 depth=[2,2,2,2],
                 num_heads=[2,2,2,2],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 3:
    model = PlantSR(scale=upscale, num_channels=3, num_features=64,n_resgroups=16,n_resblocks=4,reduction=16,ffn_scale=2,n_blocks=5)
model.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(),lr=lr)

In [6]:
## load the pretrained model (if have one)

# model_path = 'ckpts/PlantSR_x2_best.pth'
# model.load_state_dict(torch.load(model_path), strict=True)

In [7]:
from tqdm import tqdm
import sys


start_epoch = 0
num_epochs =12

for epoch in range(start_epoch,num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())

        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()
        # if(batch_idx%10000) == 0:
        #     torch.save(model.state_dict(), 'outputs/PlantSR_x2_{}_{}.pth'.format(batch_idx,epoch+1))

    print("\n")
    if (epoch+1) % 1 == 0:
        torch.save(model.state_dict(), 'outputs/PlantSRv2_x4_{}.pth'.format(epoch+1))
    

Epoch [1/12], Batch [17148/17148], Loss: 0.0261

Epoch [2/12], Batch [17148/17148], Loss: 0.0158

Epoch [3/12], Batch [17148/17148], Loss: 0.0315

Epoch [4/12], Batch [17148/17148], Loss: 0.0251

Epoch [5/12], Batch [17148/17148], Loss: 0.0283

Epoch [6/12], Batch [17148/17148], Loss: 0.0305

Epoch [7/12], Batch [17148/17148], Loss: 0.0154

Epoch [8/12], Batch [17148/17148], Loss: 0.0169

Epoch [9/12], Batch [17148/17148], Loss: 0.0205

Epoch [10/12], Batch [17148/17148], Loss: 0.0233

Epoch [11/12], Batch [17148/17148], Loss: 0.0313

Epoch [12/12], Batch [17148/17148], Loss: 0.0253



In [4]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

upscale= 4
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

device = 'cuda:0'

In [5]:
from models.PlantSRv2 import PlantSR

outPath = "outputs"
lr = 1e-4  

device = torch.device('cuda:0')
if upscale == 4:
   model = PlantSR(
                 scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=96,
                 depth=[4,4,4,4,4],
                 num_heads=[4,4,4,4,4],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 2:
    model = PlantSR(
                 scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=32,
                 depth=[2,2,2,2],
                 num_heads=[2,2,2,2],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 3:
    model = PlantSR(scale=upscale, num_channels=3, num_features=64,n_resgroups=16,n_resblocks=4,reduction=16,ffn_scale=2,n_blocks=5)
model_path = r'outputs/PlantSRv2_x4_28.pth'
model.load_state_dict(torch.load(model_path), strict=True)
model.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=lr)

  model.load_state_dict(torch.load(model_path), strict=True)


In [6]:
from tqdm import tqdm
import sys

num_epochs = 8  # 继续训练8轮
start_epoch = 28  # 之前已经训练了12轮

# 继续训练的代码
for epoch in range(start_epoch, start_epoch + num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())

        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, start_epoch + num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()
        # if(batch_idx%10000) == 0:
        #     torch.save(model.state_dict(), 'outputs/PlantSR_x2_{}_{}.pth'.format(batch_idx,epoch+1))

    print("\n")
    if (epoch+1) % 1 == 0:
        torch.save(model.state_dict(), 'outputs/PlantSRv2_x4_{}.pth'.format(epoch+1))
    

Epoch [29/36], Batch [17148/17148], Loss: 0.0161

Epoch [30/36], Batch [17148/17148], Loss: 0.0218

Epoch [31/36], Batch [17148/17148], Loss: 0.0231

Epoch [32/36], Batch [17148/17148], Loss: 0.0210

Epoch [33/36], Batch [17148/17148], Loss: 0.0204

Epoch [34/36], Batch [17148/17148], Loss: 0.0132

Epoch [35/36], Batch [17148/17148], Loss: 0.0238

Epoch [36/36], Batch [17148/17148], Loss: 0.0262



## Test

In [10]:
from models.PlantSRv2 import PlantSR
# from   PLtest import PlantSR
import torch

upscale = 4
device = 'cuda'
# model_path = r'outputs/PlantSR_x4_20000_10.pth'
model_path = r'outputs/PlantSRv2_x4_30.pth'
device = torch.device('cuda:0')
if upscale == 4:
    model = PlantSR(scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=96,
                 depth=[4,4,4,4,4],
                 num_heads=[4,4,4,4,4],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 2:
    model = PlantSR(scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=32,
                 depth=[2,2,2,2],
                 num_heads=[2,2,2,2],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 3:
    model = PlantSR(scale=upscale,num_features=64,n_resgroups=16,n_resblocks=4,reduction=16)

model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

  model.load_state_dict(torch.load(model_path), strict=True)


In [11]:
import cv2 as cv2
import numpy as np
import torch.nn.functional as F
from calulate_psnr_ssim import *
import os

test_psnr = 0
test_ssim = 0
image_count = 0

test_path = "./data/PlantSR_dataset/test"

for filename in os.listdir(test_path):
    if filename.endswith((".png",".jpg")):
        image_count+=1
        print(image_count)
        file_path = os.path.join(test_path, filename)
        
        hr_img = cv2.imread(file_path, cv2.IMREAD_COLOR).astype(np.float32) 
        h, w, _ = hr_img.shape

        ## make sure same dimension
        h -= h % upscale
        w -= w % upscale
        hr_img = hr_img[:h, :w]
        
        lr_image = cv2.resize(hr_img, (w // upscale, h // upscale), interpolation=cv2.INTER_LINEAR) 
        lr_image = lr_image/255.
        lr_image = torch.from_numpy(np.transpose(lr_image[:, :, [2, 1, 0]],
                                                (2, 0, 1))).float()
        lr_image = lr_image.unsqueeze(0).to(device)
        
        
        with torch.no_grad():
            output = model(lr_image)

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0)

        # if (output.shape[2] != hr_img.shape[2]) or (output.shape[3] != hr_img.shape[3]):
        #     output = F.interpolate(output, size=(h, w), mode='bilinear', align_corners=False)
            
        psnr = calc_psnr(hr_img, output)
        ssim = calc_ssim(hr_img, output)
        test_psnr += psnr
        test_ssim += ssim

test_psnr = test_psnr/image_count
test_ssim = test_ssim/image_count

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200


In [12]:
print('test psnr: {:.2f}'.format(test_psnr))
print('test ssim: {:.4f}'.format(test_ssim))

test psnr: 33.69
test ssim: 0.9049


## Inference

In [17]:
from models.PlantSRv2 import PlantSR
import torch

upscale = 2
device = 'cuda'
model_path = r'outputs/PlantSRv2_x2_8.pth'

device = torch.device('cuda:0')
if upscale == 4:
    model = PlantSR(scale=upscale,num_features=96,n_resgroups=16,n_resblocks=4,reduction=16)
if upscale == 2:
    model = PlantSR(scale=upscale,
                 img_size=64,
                 num_channels=3,
                 num_features=32,
                 depth=[2,2,2,2,2,2],
                 num_heads=[2,2,2,2,2,2],
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_chk=False,
                 img_range=1.,
                 resi_connection='1conv',
                 split_size=[8,8],
                 c_ratio=0.5,ffn_scale=2,n_blocks=4)
if upscale == 3:
    model = PlantSR(scale=upscale,num_features=64,n_resgroups=16,n_resblocks=4,reduction=16)

model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

  model.load_state_dict(torch.load(model_path), strict=True)


In [1]:
import os
import cv2
import numpy as np

input_folder = "G:/6666/duibizu/PlantSR/data/PlantSR_dataset/test"
output_folder = "G:/6666/duibizu/PlantSR/data/PlantSR_dataset/YourDatax2"

os.makedirs(output_folder, exist_ok=True)

for filename in os.listdir(input_folder):
    if filename.endswith((".jpg", ".jpeg", ".png")):
        img_path = os.path.join(input_folder, filename)

        img = cv2.imread(img_path, cv2.IMREAD_COLOR).astype(np.float32) /255.
        h, w, _ = img.shape
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img)

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)

        save_path = os.path.join(output_folder, filename)

        cv2.imwrite(save_path, output)