In [None]:
!pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
!pip install easydict

In [None]:
import time

import torch
import torch.backends.cudnn as cudnn
from torch import nn
from easydict import EasyDict as edict

from models import SRResNet
from datasets import SRDataset
from utils import *

In [None]:
# Model parameters (kept as module-level names; the model-construction cell reads them directly)
large_kernel_size = 9  # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3  # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels = 64  # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks = 16  # number of residual blocks

# Run configuration, built in one go; every entry is reachable as an attribute (config.<key>)
config = edict({
    # Data locations and patch geometry
    'csv_folder': '/home/mw/project/SRDataset',
    'HR_data_folder': '/home/mw/input/dataset76853/DIV2K_train_HR/DIV2K_train_HR',
    'LR_data_folder': '/home/mw/input/dataset76853/DIV2K_train_LR_bicubic_X4/DIV2K_train_LR_bicubic/X4',
    'crop_size': 96,
    'scaling_factor': 4,  # also handed to SRResNet as its upscaling factor

    # Learning parameters
    'checkpoint': None,  # path to model checkpoint, None if none
    'batch_size': 16,  # batch size
    'start_epoch': 0,  # start at this epoch
    'epochs': 20,  # upper bound (exclusive) of the epoch loop
    'workers': 4,  # passed to the DataLoader as num_workers
    'beta': 1e-3,  # the coefficient to weight the adversarial loss in the perceptual loss
    'print_freq': 50,  # print training status every this many batches
    'lr': 1e-4,  # learning rate handed to Adam
    'grad_clip': None,  # clip if gradients are exploding

    # Default device: the GPU when one is visible, otherwise the CPU
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
})

# Let cuDNN benchmark and pick convolution algorithms
cudnn.benchmark = True

In [None]:
# Build a fresh model/optimizer pair, or restore both from a saved checkpoint.
if config.checkpoint is None:
    # Fresh run: the architecture comes from the module-level model parameters
    model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                     n_channels=n_channels, n_blocks=n_blocks, scaling_factor=config.scaling_factor)
    # Initialize the optimizer over the parameters that require gradients
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=config.lr)

else:
    # Resume: the checkpoint holds the pickled model and optimizer objects plus the
    # index of the last finished epoch (see the save call in the epoch loop).
    # NOTE(review): unpickling executes code from the file -- only load trusted checkpoints.
    # NOTE(review): recent PyTorch releases may require weights_only=False to load
    # pickled objects like these -- confirm against the installed version.
    # map_location lets a checkpoint written on a GPU be restored on whatever
    # device this session uses (including CPU-only hosts).
    checkpoint = torch.load(config.checkpoint, map_location=config.device)
    start_epoch = checkpoint['epoch'] + 1
    # The epoch loop reads config.start_epoch, so the resume point must be stored
    # there; otherwise a resumed run would start again from epoch 0.
    config.start_epoch = start_epoch
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']

In [None]:
# Place the loss module and the network on the configured device
criterion = nn.MSELoss()
criterion = criterion.to(config.device)
model = model.to(config.device)

In [None]:
# Training data: the dataset is constructed from the shared config object
train_dataset = SRDataset(split='train', config=config)

# Shuffled mini-batches, loaded by config.workers worker processes into pinned host memory
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.workers,
    pin_memory=True,
)

In [None]:
def train(train_loader, model, criterion, optimizer, epoch):
    """
    One epoch's training.

    Reads the notebook-level ``config`` for ``device``, ``print_freq`` and the
    optional ``grad_clip`` threshold. An empty ``train_loader`` is a no-op.

    :param train_loader: DataLoader for training data, yielding (lr_imgs, hr_imgs) batches
    :param model: model
    :param criterion: content loss function (Mean Squared-Error loss)
    :param optimizer: optimizer
    :param epoch: epoch number (only used in the progress printout)
    """
    model.train()  # training mode enables batch normalization

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    # Optional clipping threshold; None (the default) leaves gradients untouched
    grad_clip = getattr(config, 'grad_clip', None)

    start = time.time()

    # Batches
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        lr_imgs = lr_imgs.to(config.device)  # (batch_size (N), 3, 24, 24), imagenet-normed
        hr_imgs = hr_imgs.to(config.device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

        # Forward prop.
        sr_imgs = model(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]

        # Loss
        loss = criterion(sr_imgs, hr_imgs)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients element-wise to [-grad_clip, grad_clip], if requested
        if grad_clip is not None:
            nn.utils.clip_grad_value_(model.parameters(), grad_clip)

        # Update model
        optimizer.step()

        # Keep track of loss, weighted by the number of samples in this batch
        losses.update(loss.item(), lr_imgs.size(0))

        # Keep track of batch time
        batch_time.update(time.time() - start)

        # Reset start time
        start = time.time()

        # Print status
        if i % config.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader),
                                                                    batch_time=batch_time,
                                                                    data_time=data_time, loss=losses))
    # No explicit `del` of the batch tensors: they are locals and are released when
    # the function returns, and deleting them failed when the loader yielded no batches.


In [None]:
# Train over the configured epoch range, checkpointing after every epoch
for epoch in range(config.start_epoch, config.epochs):
    # Run a single pass over the training data
    train(train_loader, model, criterion, optimizer, epoch)

    # Store the epoch index together with the model and optimizer objects;
    # the same file is overwritten each epoch
    torch.save(
        dict(epoch=epoch, model=model, optimizer=optimizer),
        'checkpoint_srresnet.pth.tar',
    )