In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2

In [None]:
# import utils as u
#u.create_data_lists(["C:\\Users\\gusta\\Downloads\\train2014\\train2014"], ["C:\\Users\\gusta\\Downloads\\val2014\\val2014"], 100, "imgs_list")

In [3]:
# Testing training with splus images
#import utils as u
#u.create_data_lists(["H:\\ImagensIA\\splus\\train"], ["H:\\ImagensIA\\splus\\test"], 100, "imgs_list")


Creating data lists... this may take some time.

There are 170 images in the training data.

There are 34 images in the H:\ImagensIA\splus\test test data.

JSONS containing lists of Train and Test images have been saved to imgs_list



In [1]:
import os
import time

import torch
import torch.backends.cudnn as cudnn
from torch import nn

from datasets import SRDataset
from models import SRResNet
from utils import *

In [2]:
# ---- Data ----
data_folder = './imgs_list/'  # directory holding the JSON image lists
crop_size = 96  # side length of the HR target crops
scaling_factor = 4  # generator upscaling factor; LR inputs are the HR targets downsampled by this factor

# ---- Model ----
large_kernel_size = 9  # kernel size of the first and last convolutions (input/output transforms)
small_kernel_size = 3  # kernel size of every convolution in-between (residual and subpixel blocks)
n_channels = 64  # channel width used inside the residual and subpixel convolutional blocks
n_blocks = 16  # how many residual blocks to stack

# ---- Optimisation / training loop ----
checkpoint = None  # checkpoint path to resume from, or None to start fresh
batch_size = 96  # samples per batch
start_epoch = 0  # first epoch index
iterations = 1e6  # total number of training iterations (kept as a float)
workers = 8  # DataLoader worker processes
print_freq = 100  # log training status every this many batches
lr = 1e-4 / 2  # learning rate
grad_clip = None  # gradient clipping value; None disables clipping

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Let cuDNN benchmark and choose its algorithms.
cudnn.benchmark = True

In [3]:
# Display the device selected in the configuration cell.
device

device(type='cuda')

In [4]:
def train(train_loader, model, criterion, optimizer, epoch):
    """
    One epoch's training.

    Besides its arguments, this function reads the notebook-level globals
    ``device``, ``grad_clip`` and ``print_freq``, and uses ``AverageMeter`` and
    ``clip_gradient`` (expected to be provided by ``from utils import *``).

    An empty ``train_loader`` is handled gracefully: the function only prints
    the epoch header and returns.

    :param train_loader: DataLoader for training data, yielding (lr_imgs, hr_imgs) batches
    :param model: model
    :param criterion: content loss function (Mean Squared-Error loss)
    :param optimizer: optimizer
    :param epoch: epoch number (used for logging only)
    :return: None
    """
    print('Training epoch: {}'.format(epoch))
    model.train()  # training mode enables batch normalization

    # NOTE: the timer starts before the batch is fetched, so batch_time covers the
    # whole iteration (data loading + forward + backward), not just the compute part.
    batch_time = AverageMeter()  # full iteration time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    start = time.time()

    # Batches
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        # Shapes below assume the notebook's config (crop_size=96, scaling_factor=4).
        lr_imgs = lr_imgs.to(device, non_blocking=True)  # (batch_size (N), 3, 24, 24), imagenet-normed
        hr_imgs = hr_imgs.to(device, non_blocking=True)  # (batch_size (N), 3, 96, 96), in [-1, 1]

        # Forward prop.
        sr_imgs = model(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]

        # Loss
        loss = criterion(sr_imgs, hr_imgs)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        # Update model
        optimizer.step()

        # Keep track of loss (weighted by the number of samples in this batch)
        losses.update(loss.item(), lr_imgs.size(0))

        # Keep track of batch time
        batch_time.update(time.time() - start)

        # Reset start time
        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0:3d}][{1:4d}/{2:4d}]  ----  '
                    'Batch Time: {batch_time.val:6.3f} ({batch_time.avg:6.3f})  ----  '
                    'Data Time: {data_time.val:6.3f} ({data_time.avg:6.3f})  ----  '
                    'Loss: {loss.val:7.4f} ({loss.avg:7.4f})'.format(
                        epoch, i, len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses
                ))

    # No explicit `del` of the batch tensors: they are locals and are released when
    # the function returns. An unconditional `del` here raised UnboundLocalError
    # whenever the loader yielded no batches.
    print()


In [5]:
"""
Training.
"""
# Checkpoint to resume from; set to None to train a freshly initialised model.
# NOTE: this name is rebound below from the path string to the loaded checkpoint dict.
checkpoint = "checkpoints/checkpoint_srresnet_37_.pth.tar"

# Initialize model or load checkpoint
if checkpoint is None:
    print('Initializing model...')
    model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                        n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
    # Initialize the optimizer over the trainable parameters only
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=lr)

else:
    print(f'Loading checkpoint: {checkpoint}')
    # SECURITY: the checkpoint stores whole pickled objects (model and optimizer),
    # so loading it executes pickle code -- only load checkpoints you trust.
    # map_location remaps stored tensors onto the current device, so a checkpoint
    # saved on a GPU can still be restored on a CPU-only machine.
    # NOTE(review): newer PyTorch releases may default torch.load to weights_only=True,
    # which would reject a fully pickled model -- confirm against the installed version.
    checkpoint = torch.load(checkpoint, map_location=device)
    start_epoch = checkpoint['epoch'] + 1  # resume at the epoch after the saved one
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']  # restored optimizer; the configured `lr` is not applied here

# Move to default device
model = model.to(device)
criterion = nn.MSELoss().to(device)

# Custom dataloaders
train_dataset = SRDataset(data_folder,
                            split='train',
                            crop_size=crop_size,
                            scaling_factor=scaling_factor,
                            lr_img_type='imagenet-norm',
                            hr_img_type='[-1, 1]')
# Default collation is used (no custom collate function is passed).
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers,
                                            pin_memory=True)

# Total number of epochs to train for: enough epochs to cover `iterations` batches.
epochs = int(iterations // len(train_loader) + 1)

Loading checkpoint: checkpoints/checkpoint_srresnet_37_.pth.tar


In [6]:
# Ask PyTorch to release cached GPU memory it is not currently using before training starts.
torch.cuda.empty_cache()

In [7]:
# Epochs
# Make sure the checkpoint directory exists up front, so a finished epoch's work
# is not lost to a failed save on a fresh run.
os.makedirs('checkpoints', exist_ok=True)

# Epochs
for epoch in range(start_epoch, epochs):
    # One epoch's training
    train(train_loader=train_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            epoch=epoch)

    # Save checkpoint (one file per epoch). The whole model and optimizer objects
    # are stored, matching what the checkpoint-loading cell expects to read back.
    torch.save({'epoch': epoch,
                'model': model,
                'optimizer': optimizer},
                f'checkpoints/checkpoint_srresnet_{epoch}_.pth.tar')


Training epoch: 38
Epoch: [ 38][   0/   2]  ----  Batch Time: 10.737 (10.737)  ----  Data Time:  5.625 ( 5.625)  ----  Loss:  0.0409 ( 0.0409)

Training epoch: 39
Epoch: [ 39][   0/   2]  ----  Batch Time:  5.656 ( 5.656)  ----  Data Time:  5.406 ( 5.406)  ----  Loss:  0.0401 ( 0.0401)

Training epoch: 40
Epoch: [ 40][   0/   2]  ----  Batch Time:  5.534 ( 5.534)  ----  Data Time:  5.297 ( 5.297)  ----  Loss:  0.0423 ( 0.0423)

Training epoch: 41
Epoch: [ 41][   0/   2]  ----  Batch Time:  5.488 ( 5.488)  ----  Data Time:  5.244 ( 5.244)  ----  Loss:  0.0424 ( 0.0424)

Training epoch: 42
Epoch: [ 42][   0/   2]  ----  Batch Time:  5.594 ( 5.594)  ----  Data Time:  5.357 ( 5.357)  ----  Loss:  0.0413 ( 0.0413)

Training epoch: 43
Epoch: [ 43][   0/   2]  ----  Batch Time:  5.674 ( 5.674)  ----  Data Time:  5.438 ( 5.438)  ----  Loss:  0.0420 ( 0.0420)

Training epoch: 44
Epoch: [ 44][   0/   2]  ----  Batch Time:  5.573 ( 5.573)  ----  Data Time:  5.338 ( 5.338)  ----  Loss:  0.0421 ( 0

KeyboardInterrupt: 