## Setup

In [None]:
#Download Dataset
# The wget line is commented out, so UCMerced_LandUse.zip has to be present in the
# working directory already (or uncomment the line to fetch it).
#!wget http://weegee.vision.ucmerced.edu/datasets/UCMerced_LandUse.zip
# Extract quietly (-q) and overwrite existing files without prompting (-o).
!unzip -o -q UCMerced_LandUse.zip
# Remove the archive after extraction.
!rm UCMerced_LandUse.zip
# Rename the extracted folder to "data", the path passed to UCMerced("data") later on.
# NOTE(review): this cell is not re-runnable as-is -- the archive has been deleted, and if
# "data" already exists, mv would move the folder inside it instead of renaming it.
!mv UCMerced_LandUse data

In [1]:
from src.ds import UCMerced, HQLQ
from src.util import random_split_ratio, psnr
from src.srresnet import SRResNet
from torch.utils.data import DataLoader
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import random

# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value for repeatable runs.
seed = 31415
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(seed)



In [2]:
#Data
dataset = UCMerced("data")

scaling_factor = 4
valratio, testratio = 0.1, 0.1

# One list of augmentation transforms per split, in the order train, validation, test.
# Only the training split is augmented; the other two lists are empty.
augments = [
    [
        transforms.RandomCrop((96, 96)),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
    ],
    [],
    [],
]

# Split the dataset, then wrap each part together with its augmentation list.
splits = random_split_ratio(dataset, valratio, testratio)
trainds, valds, testds = [
    HQLQ(split, split_augments, scalingfactor=scaling_factor)
    for split, split_augments in zip(splits, augments)
]

## Parameters

In [3]:
# Model parameters -- taken from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution
large_kernel_size = 9  # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3  # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels = 64  # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks = 16  # number of residual blocks

# Learning parameters
batch_size = 64  # batch size
try:
    start_epoch = epoch  # start at old epoch if already defined from earlier execution
except NameError:
    # `epoch` is only bound once the training loop has run in this kernel. Catching
    # NameError alone keeps KeyboardInterrupt and genuine errors from being swallowed.
    start_epoch = 0
iterations = 5e4  # number of training iterations (a float; turned into an epoch count before training)
workers = 4  # number of workers for loading data in the DataLoader
lr = 1e-4  # learning rate
device = "cuda" if torch.cuda.is_available() else "cpu"


# Both loaders use the same batch size, worker count and pinned memory.
train_loader = DataLoader(trainds, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
# NOTE(review): the validation loader is shuffled as well; validation only averages
# over batches, so the order is presumably irrelevant -- confirm before changing it.
val_loader = DataLoader(valds, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

In [4]:
import time


def train(train_loader, model, criterion, optimizer):
    """
    Train the model for one epoch.

    Batches are moved to the notebook-level `device` before the forward pass.

    :param train_loader: iterable yielding (lq, hq) batches, e.g. a DataLoader
    :param model: model called as model(lq) to produce the super-resolved batch
    :param criterion: loss function called as criterion(sr, hq)
    :param optimizer: optimizer, zeroed and stepped once per batch
    returns (losses, data_time, run_time):
        losses    -- list with one float loss value per batch
        data_time -- seconds spent waiting for batches and moving them to the device
        run_time  -- total seconds per batch, including data_time
    An empty loader yields ([], 0, 0).
    """
    model.train()
    losses = []
    data_time = 0
    run_time = 0
    start = time.time()
    for lq, hq in train_loader:
        # Move to device
        lq = lq.to(device)
        hq = hq.to(device)

        # everything since the end of the previous batch counts as data time
        data_time += time.time() - start

        # Forward
        sr = model(lq)

        # Loss
        loss = criterion(sr, hq)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Update model
        optimizer.step()

        losses.append(loss.item())
        run_time += time.time() - start
        start = time.time()

    # No explicit `del lq, hq, sr` here: the batch tensors are locals that are released
    # on return anyway, and deleting them raised UnboundLocalError for an empty loader.
    return (losses, data_time, run_time)


In [5]:
def validate(val_loader, model, criterion):
    """
    Run one pass over the validation data without gradient tracking.
    :param val_loader: DataLoader yielding (lq, hq) batches
    :param model: model, switched to eval mode here
    :param criterion: loss function
    returns (loss, psnr), each averaged over the batches of val_loader
    """
    model.eval()
    mean_loss, mean_psnr = 0., 0.

    with torch.no_grad():
        for low_res, high_res in val_loader:
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            # forward
            restored = model(low_res)

            # every batch contributes 1/len(val_loader) of the reported averages
            batch_loss = criterion(restored, high_res).item()
            mean_loss += batch_loss / len(val_loader)
            mean_psnr += psnr(high_res, restored) / len(val_loader)

    return (mean_loss, mean_psnr)

In [12]:
#initialize model without overwriting
try:
    model
except NameError:
    # No model exists in this kernel yet: build one and load the pretrained weights.
    # `model` is bound only after the weights have loaded, so a failed load cannot leave
    # a randomly initialised network behind that a re-run of this cell would keep.
    pretrained = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                          n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)

    #load pretrained weights, extracted from checkpoint at https://drive.google.com/drive/folders/12OG-KawSFFs6Pah89V4a_Td-VcwMBE5i?usp=sharing
    # map_location lets weights that were saved from a GPU load on a CPU-only machine too
    pretrained.load_state_dict(torch.load('weights/sgrvinod.zip', map_location=device))
    model = pretrained
else:
    print('model not changed!')

model = model.to(device)
# The optimizer and the loss are rebuilt on every run of this cell, so re-running it
# resets the optimizer state even when the model itself is kept.
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                              lr=lr)
criterion = torch.nn.MSELoss().to(device)

## Train

In [13]:
#keep log for plot in the end
try:
    logs
except NameError:
    # First run in this kernel. On a re-run the existing log is kept and appended to.
    logs = defaultdict(list)

# Number of epochs needed to cover the requested number of training iterations.
epochs = int(iterations // len(train_loader) + 1)
# start_epoch is set in the parameters cell; re-run that cell before this one to resume
# from the epoch reached by an interrupted run.
for epoch in range(start_epoch, epochs):
    train_losses, data_time, batch_time = train(train_loader=train_loader, model=model, criterion=criterion, optimizer=optimizer)
    train_loss = np.mean(train_losses)
    logs['train'].append((epoch, train_loss))

    summary = f"[{epoch}/{epochs}] -- Time: Data {data_time:.2f}s, Epoch {batch_time:.2f}s -- TrainLoss {train_loss:.4f}"
    # validate on every second epoch only
    if epoch % 2 == 0:
        val_loss, val_psnr = validate(val_loader=val_loader, model=model, criterion=criterion)
        logs['val'].append((epoch, val_loss, val_psnr))
        summary += f" -- ValLoss {val_loss:.4f} -- ValPSNR {val_psnr:.4f}"
    print(summary)

[0/1852] -- Time: Data 0.46s, Epoch 4.68s -- TrainLoss 0.0151 -- ValLoss 0.0128 -- ValPSNR 18.9427
[1/1852] -- Time: Data 0.41s, Epoch 4.26s -- TrainLoss 0.0126
[2/1852] -- Time: Data 0.41s, Epoch 4.29s -- TrainLoss 0.0125 -- ValLoss 0.0112 -- ValPSNR 19.5784
[3/1852] -- Time: Data 0.39s, Epoch 4.34s -- TrainLoss 0.0121
[4/1852] -- Time: Data 0.41s, Epoch 4.27s -- TrainLoss 0.0123 -- ValLoss 0.0120 -- ValPSNR 19.2520
[5/1852] -- Time: Data 0.41s, Epoch 4.36s -- TrainLoss 0.0122
[6/1852] -- Time: Data 0.40s, Epoch 4.30s -- TrainLoss 0.0118 -- ValLoss 0.0124 -- ValPSNR 19.1384
[7/1852] -- Time: Data 0.41s, Epoch 4.29s -- TrainLoss 0.0119
[8/1852] -- Time: Data 0.39s, Epoch 4.27s -- TrainLoss 0.0119 -- ValLoss 0.0115 -- ValPSNR 19.4391
[9/1852] -- Time: Data 0.39s, Epoch 4.30s -- TrainLoss 0.0119
[10/1852] -- Time: Data 0.39s, Epoch 4.38s -- TrainLoss 0.0117 -- ValLoss 0.0107 -- ValPSNR 19.8052
[11/1852] -- Time: Data 0.38s, Epoch 4.26s -- TrainLoss 0.0118
[12/1852] -- Time: Data 0.39s, E

In [14]:
#save results

import pandas as pd

torch.save(model.state_dict(),'weights/transfer.zip')

# logs['train'] holds (epoch, train_loss) tuples, logs['val'] holds
# (epoch, val_loss, val_psnr) tuples; both frames are indexed by epoch.
train_frame = pd.DataFrame(
    [loss for epoch_idx, loss in logs['train']],
    index=[epoch_idx for epoch_idx, loss in logs['train']],
    columns=['train_loss'],
)
val_frame = pd.DataFrame(
    [(loss, score) for epoch_idx, loss, score in logs['val']],
    index=[epoch_idx for epoch_idx, loss, score in logs['val']],
    columns=['val_loss','val_psnr'],
)
# Left join on the epoch index: epochs without a validation run get empty validation columns.
df = train_frame.join(val_frame).rename_axis('epoch')
df.to_csv('log_transfer.csv')

In [20]:
#save results
# NOTE(review): this cell does the same work as the save cell directly above it
# (same weights file, same CSV); one of the two is probably redundant -- confirm.

import pandas as pd

torch.save(model.state_dict(),'weights/transfer.zip')

# logs['train'] holds (epoch, train_loss) tuples, logs['val'] holds
# (epoch, val_loss, val_psnr) tuples; both frames are indexed by epoch.
train_frame = pd.DataFrame(
    [loss for epoch_idx, loss in logs['train']],
    index=[epoch_idx for epoch_idx, loss in logs['train']],
    columns=['train_loss'],
)
val_frame = pd.DataFrame(
    [(loss, score) for epoch_idx, loss, score in logs['val']],
    index=[epoch_idx for epoch_idx, loss, score in logs['val']],
    columns=['val_loss','val_psnr'],
)
# Left join on the epoch index: epochs without a validation run get empty validation columns.
df = train_frame.join(val_frame).rename_axis('epoch')
df.to_csv('log_transfer.csv')