In [1]:
import sys
sys.path.append('../')
from models.SRCNN import SRCNN

## Load data

In [2]:
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import random

class SRDataset(Dataset):
    def __init__(self, root, upscale_factor):
        super(SRDataset, self).__init__()
        self.hr_path = os.path.join(root, 'train_64')
        self.upscale_factor = upscale_factor
        self.hr_filenames = sorted(os.listdir(self.hr_path))

    def __getitem__(self, index):
        hr_image = cv2.imread(os.path.join(self.hr_path, self.hr_filenames[index]))
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        h, w, _ = hr_image.shape

        ## make sure same demension
        h -= h % self.upscale_factor
        w -= w % self.upscale_factor
        hr_image = hr_image[:h, :w]

        lr_image = cv2.resize(hr_image, (int(w // self.upscale_factor),int(h // self.upscale_factor)), interpolation=cv2.INTER_LINEAR)

        ## data enhancement
        if random.random() > 0.5:  
            lr_image = cv2.flip(lr_image, 1)
            hr_image = cv2.flip(hr_image, 1)
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)

        return lr_image, hr_image

    def __len__(self):
        return len(self.hr_filenames)

In [3]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

upscale= 4
train_dataset = SRDataset(root='./data/PlantSR_dataset/', upscale_factor=upscale)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

device = 'cuda:0'

##  upscale = 2/3/4
model = SRCNN().to(device)

In [4]:
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train

In [1]:
# import sys
# ## ignore warnings

# class HiddenPrints:
#     def write(self, msg):
#         pass

# try:
#     sys.stderr = HiddenPrints()
# except:
#     pass


In [None]:
from tqdm import tqdm
import sys


start_epoch = 0
num_epochs = 10

for epoch in range(start_epoch,num_epochs):
    model.train()
    for batch_idx, (lr_images, hr_images) in enumerate(train_loader):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        sr_images = model(lr_images.float())

        loss = criterion(sr_images, hr_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 1 == 0:
            sys.stdout.write('\rEpoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'
                             .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
            sys.stdout.flush()

    print("\n")
    if (epoch+1) % 1 == 0:
        torch.save(model.state_dict(), 'outputs/SRCNN_x4_{}.pth'.format(epoch+1))


Epoch [1/10], Batch [38997/38997], Loss: 0.0343

Epoch [2/10], Batch [38997/38997], Loss: 0.0525

Epoch [3/10], Batch [38997/38997], Loss: 0.0490

Epoch [4/10], Batch [31049/38997], Loss: 0.0184

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [5/10], Batch [15059/38997], Loss: 0.0237

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [5/10], Batch [23039/38997], Loss: 0.0250

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [5/10], Batch [30784/38997], Loss: 0.0241

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [5/10], Batch [38614/38997], Loss: 0.0156

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [6/10], Batch [15335/38997], Loss: 0.0286

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [6/10], Batch [23544/38997], Loss: 0.0204

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [6/10], Batch [38997/38997], Loss: 0.0078

Epoch [7/10], Batch [1671/38997], Loss: 0.0328

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [7/10], Batch [10178/38997], Loss: 0.0263

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [7/10], Batch [19084/38997], Loss: 0.0274

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [7/10], Batch [28103/38997], Loss: 0.0247

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [7/10], Batch [37311/38997], Loss: 0.0244

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [8/10], Batch [7589/38997], Loss: 0.0290

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [8/10], Batch [16929/38997], Loss: 0.0193

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [8/10], Batch [26313/38997], Loss: 0.0174

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch [8/10], Batch [35224/38997], Loss: 0.0266