In [1]:
import sys
sys.path.append('../')
from models.VDSR import VDSR

## Data preparation

In [2]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import random

class SRDataset(Dataset):
    """Super-resolution training set that synthesises its own low-res inputs.

    Every file found in ``<root>/train_64`` is treated as a high-resolution
    (HR) target. The matching low-resolution (LR) input is produced on the fly
    by shrinking the HR image by ``upscale_factor`` and resizing it back to
    the HR size, so the LR and HR tensors returned for a sample always share
    the same spatial dimensions.

    Args:
        root: Dataset directory that contains the ``train_64`` sub-folder.
        upscale_factor: Positive down-sampling factor used to degrade the HR
            image (e.g. 2, 3 or 4).

    Raises:
        ValueError: If ``upscale_factor`` is not strictly positive.
    """

    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        # Reject bad factors here instead of failing with a division error
        # (or an OpenCV size error) on the first sample.
        if upscale_factor <= 0:
            raise ValueError(
                'upscale_factor must be positive, got {!r}'.format(upscale_factor))
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        # Sorted so that an index always refers to the same file.
        self.hr_filenames = sorted(os.listdir(self.hr_path))
        # Built once rather than once per sample: converts an HxWxC uint8
        # array to a CxHxW float tensor, then applies (x - 0.5) / 0.5 per channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` tensor pair for ``index``.

        Raises:
            IOError: If the image file cannot be read or decoded.
        """
        image_path = os.path.join(self.hr_path, self.hr_filenames[index])
        hr_image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising;
        # report the offending file rather than crashing inside cvtColor.
        if hr_image is None:
            raise IOError('Failed to read image: {}'.format(image_path))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w = hr_image.shape[:2]

        # Degrade: shrink by the upscale factor, then resize back to the HR
        # size. The low-res size is clamped to one pixel so that images
        # smaller than the factor do not request an empty resize target.
        lr_w = max(1, int(w // self.upscale_factor))
        lr_h = max(1, int(h // self.upscale_factor))
        lr_image = cv2.resize(hr_image, (lr_w, lr_h), interpolation=cv2.INTER_LINEAR)
        lr_image = cv2.resize(lr_image, (w, h), interpolation=cv2.INTER_LINEAR)

        # Augmentation: horizontal flip with probability 0.5, applied to both
        # images so the pair stays aligned.
        if random.random() > 0.5:
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)

        return self.transform(lr_image), self.transform(hr_image)

    def __len__(self):
        """Number of HR files found in the training folder."""
        return len(self.hr_filenames)

In [3]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

# Degradation factor used by the dataset to synthesise low-res inputs (2, 3 or 4).
upscale = 3
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without one.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## The model is constructed without a scale argument: the dataset already
## resizes the low-res input back to the target size, so only `upscale`
## above needs changing to train for x2 / x3 / x4.
model = VDSR().to(device)

In [4]:
## Load a pretrained checkpoint (if you have one) to resume training

# model_path = 'outputs/VDSR_x3_1.pth'
# model.load_state_dict(torch.load(model_path), strict=True)

In [5]:
# Training objective: L1 loss between the network output and the HR target.
criterion = nn.L1Loss()
# Adam over every model parameter, with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

## Train

In [None]:
import sys
## Silence warnings by discarding everything written to stderr.
## NOTE: this hides *all* stderr output for the rest of the session,
## not only warnings.

class HiddenPrints:
    """Write-only stream stand-in that throws away whatever it receives."""

    def write(self, msg):
        """Discard ``msg`` and report it as fully written, as a text stream would."""
        return len(msg)

    def flush(self):
        """No-op; present because stream users commonly call ``flush()``."""
        pass


# A plain attribute assignment needs no try/except; the previous bare
# `except: pass` could only have hidden unrelated errors.
sys.stderr = HiddenPrints()




In [None]:
from tqdm import tqdm
import sys


start_epoch = 0
num_epochs = 10
log_every = 1    # refresh the progress line every N batches
save_every = 1   # write a checkpoint every N epochs
checkpoint_dir = 'outputs'

# Create the checkpoint folder before training starts, so a missing directory
# is not discovered only when the first epoch tries to save.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(start_epoch, num_epochs):
    model.train()
    num_batches = len(train_loader)
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        # Forward pass and loss against the high-resolution target.
        sr_images = model(lr_images.float())
        loss = criterion(sr_images, hr_images)

        # Clear old gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_every == 0:
            # '\r' rewrites the same console line instead of emitting one line per batch.
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch + 1, num_epochs, batch_idx + 1, num_batches, loss.item()))
            sys.stdout.flush()

    print("\n")
    if (epoch + 1) % save_every == 0:
        # Name the checkpoint after the configured scale rather than a hard-coded
        # "x3", so runs at a different `upscale` are not mislabelled.
        checkpoint_path = os.path.join(
            checkpoint_dir, 'VDSR_x{}_{}.pth'.format(upscale, epoch + 1))
        torch.save(model.state_dict(), checkpoint_path)


Epoch [1/10], Batch [38997/38997], Loss: 0.0176

Epoch [2/10], Batch [36549/38997], Loss: 0.0280