<a href="https://colab.research.google.com/github/abel-bernabeu/autoencoder/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Launch Tensorboard
# Loads the TensorBoard notebook extension and opens it on the local "runs" directory.
#!rm runs -rf # Uncomment to delete all the previous Tensorboard runs
%load_ext tensorboard
%tensorboard --logdir runs
# NOTE(review): the extension is reloaded after TensorBoard has already been started;
# presumably a workaround for re-running this cell in a live kernel — confirm it is still needed.
%reload_ext tensorboard

In [None]:
# Create directory for training checkpoints
# ("params" is the directory part of the checkpoint path stored in hparams['params'];
# -p makes the command succeed when the directory already exists)
#!rm params -rf # Uncomment to delete all the checkpoints
%mkdir -p params

In [None]:
import torch  # imported here as well because this cell runs before the main imports cell

# Hyper-parameters and run configuration shared by every cell of the notebook
hparams = {
    'batch_size': 8,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs on a
    # runtime without CUDA instead of failing on the first .to('cuda')
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'max_dataset_size': 20,    # passed as subset_size when the crops dataset is built
    'train_dataset_size': 10,  # train and test sizes are the lengths handed to random_split
    'test_dataset_size': 10,
    'num_epochs': 1000,
    'num_workers': 4,          # worker processes used by both data loaders
    'params': "./params/colorizer.pt",   # checkpoint file
    'continue_with_best_model': False,   # on restore: load 'best_model' instead of 'last_model'
    'checkpointing_freq': 20   # a checkpoint is written when epoch % checkpointing_freq == 0
}

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils
import torchvision.transforms as transforms
import numpy as np
import autoencoder.datasets
import autoencoder.transforms
import autoencoder.models
import datetime
import os

In [None]:
# TensorBoard writer used by every logging cell of the notebook.
# NOTE(review): the log directory is given as an empty string; SummaryWriter is expected to
# treat that like "not provided" and pick its default run directory — confirm against the docs.
writer = SummaryWriter(log_dir='')

In [None]:
# Dataset of image crops read from ./data/image_dataset_part-a, limited to
# hparams['max_dataset_size'] entries.
# NOTE(review): the two 224 arguments are presumably the crop dimensions — confirm in CropsDataset.
crops = autoencoder.datasets.CropsDataset(
    "./data/image_dataset_part-a",
    224,
    224,
    subset_size=hparams['max_dataset_size'],
    assume_fixed_size=True,
)

In [None]:
# Log the first 16 crops to TensorBoard as a grid with 4 images per row
few_crops = [transforms.ToTensor()(crops[index][0]) for index in range(16)]
grid = torchvision.utils.make_grid(few_crops, nrow=4)
writer.add_image("1) a few crops", grid)
writer.flush()

In [None]:
# Random split in train and test sets.
# The split is drawn from a dedicated, fixed-seed generator so that it is the same after every
# kernel restart: training can be resumed from a checkpoint, and a fresh random split would
# then move images that were already trained on into the test set.
train_crops, test_crops = torch.utils.data.random_split(
    crops,
    [hparams['train_dataset_size'], hparams['test_dataset_size']],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Corruption applied to the network input; it is the first step of both the train and the
# test input transforms.
# NOTE(review): judging by its name, ConvertToGray removes the colour information — confirm
# in autoencoder.transforms.
corruption = autoencoder.transforms.ConvertToGray()

In [None]:
# x side of a training sample: corrupted crop converted to a tensor
train_input_transform = transforms.Compose([corruption, transforms.ToTensor()])

# y side of a training sample: the crop converted to a tensor, without corruption
train_output_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training dataset built on top of the train split
train_xydims_samples = autoencoder.datasets.XYDimsDataset(
    train_input_transform,
    train_output_transform,
    dataset=train_crops,
)

In [None]:
# Log the x element (index 0) of the first four train samples as a single-row grid
few_train_x = [train_xydims_samples[index][0] for index in range(4)]
grid = torchvision.utils.make_grid(few_train_x, nrow=4)
writer.add_image("2) x from a few train samples", grid)
writer.flush()

In [None]:
# Log the y element (index 1) of the first four train samples as a single-row grid
few_train_y = [train_xydims_samples[index][1] for index in range(4)]
grid = torchvision.utils.make_grid(few_train_y, nrow=4)
writer.add_image("3) y from a few train samples", grid)
writer.flush()

In [None]:
# x side of a test sample: corrupted crop converted to a tensor
test_input_transform = transforms.Compose([corruption, transforms.ToTensor()])

# y side of a test sample: the crop converted to a tensor, without corruption
test_output_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Test dataset built on top of the test split
test_xydims_samples = autoencoder.datasets.XYDimsDataset(
    test_input_transform,
    test_output_transform,
    dataset=test_crops,
)

In [None]:
# Log the x element (index 0) of the first four test samples as a single-row grid
few_test_x = [test_xydims_samples[index][0] for index in range(4)]
grid = torchvision.utils.make_grid(few_test_x, nrow=4)
writer.add_image("4) x from a few test samples", grid)
writer.flush()

In [None]:
# Log the y element (index 1) of the first four test samples as a single-row grid
few_test_y = [test_xydims_samples[index][1] for index in range(4)]
grid = torchvision.utils.make_grid(few_test_y, nrow=4)
writer.add_image("5) y from a few test samples", grid)
writer.flush()

In [None]:
# Batched loaders over the two datasets; only the training loader shuffles its samples
train_loader = torch.utils.data.DataLoader(
    train_xydims_samples,
    batch_size=hparams['batch_size'],
    shuffle=True,
    num_workers=hparams['num_workers'],
)
test_loader = torch.utils.data.DataLoader(
    test_xydims_samples,
    batch_size=hparams['batch_size'],
    shuffle=False,
    num_workers=hparams['num_workers'],
)

In [None]:
def train_epoch(train_loader, model, optimizer, criterion, hparams):
    """
    Run one optimisation pass over every batch of train_loader.

    The NumPy global RNG is re-seeded from the microsecond field of the current time
    (test_epoch pins that seed to 0, so it has to be varied again before training).
    The model is put in training mode and each batch goes through forward pass, loss,
    backward pass and one optimizer step.

    Args:
        train_loader: iterable of 4-element batches; only the first two elements
            (input, target) are used.
        model: module being trained (modified in place).
        optimizer: optimizer stepping the parameters of model.
        criterion: callable computing the loss from (output, target).
        hparams: dictionary; only hparams['device'] is read.

    Returns:
        The mean of the per-batch loss values, as computed by np.mean.
    """
    np.random.seed(datetime.datetime.now().microsecond)
    model.train()
    target_device = hparams['device']
    batch_losses = []
    for inputs, expected, _, _ in train_loader:
        inputs, expected = inputs.to(target_device), expected.to(target_device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), expected)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)

def test_epoch(test_loader, model, criterion, hparams):
    np.random.seed(0)
    model.eval()
    device = hparams['device']
    eval_losses = []
    with torch.no_grad():
        for data, target, _, _ in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            eval_losses.append(criterion(output, target).item())
    return np.mean(eval_losses)

def inference(model, inputs_list, device=None):
    """
    Do an inference with the model for each input tensor from the provided list and
    return a list with the inference results.

    The forward passes run under torch.no_grad(), so no autograd graph is built or
    kept alive by the returned tensors.

    Args:
        model: module to run; it is moved to the device and left in evaluation mode.
        inputs_list: iterable of tensors whose first three dimensions are
            (channels, height, width); each one is run as a single-element batch.
        device: device to run on. When omitted, the notebook-wide hparams['device']
            is used, which is what the original two-argument call did.

    Returns:
        A list with one output tensor per input, reshaped to the (channels, height,
        width) of the corresponding input and living on the inference device.
    """
    if device is None:
        device = hparams['device']

    # Loop-invariant set-up: done once instead of once per input
    model.to(device)
    model.eval()

    result = []
    with torch.no_grad():
        for x in inputs_list:
            num_channels = x.shape[0]
            height = x.shape[1]
            width = x.shape[2]
            # Work on a copy so the caller's tensor is never aliased by the output
            single_element_batch = x.detach().clone().reshape(1, num_channels, height, width)
            single_element_batch = single_element_batch.to(device)
            output = model(single_element_batch)
            result.append(output.reshape(num_channels, height, width))
    return result


In [None]:
class Autoencoder(nn.Module):
    """
    Small convolutional encoder/decoder mapping a 3-channel image to a 3-channel image.

    The encoder applies two 3x3 convolutions with stride 2 (replicate padding), going
    from 3 to 64 to 128 channels. The decoder applies two 3x3 transposed convolutions
    with stride 2 (128 to 64 to 64 channels) followed by a 1x1 convolution back to
    3 channels. Every layer except the last is followed by a ReLU.
    """

    def __init__(self):
        super().__init__()

        def downsample(in_channels, out_channels):
            # Strided 3x3 convolution, padded by replicating the border pixels
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, padding_mode='replicate')

        def upsample(in_channels, out_channels):
            # Strided 3x3 transposed convolution
            return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.encoder = nn.Sequential(
            downsample(3, 64),
            nn.ReLU(),
            downsample(64, 128),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            upsample(128, 64),
            nn.ReLU(),
            upsample(64, 64),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Encode and decode x, then squash the result into [0, 1] with (tanh + 1) / 2."""
        latent = self.encoder(x)
        decoded = self.decoder(latent)
        return (torch.tanh(decoded) + 1) * 0.5

In [None]:
# Move the sample tensors shown in TensorBoard to the same device where the inferences
# will be left, so they can be combined with the model outputs in a single grid.
# Each list is updated in place, element by element.
for samples in (few_test_x, few_train_y, few_train_x, few_test_y):
    for index, tensor in enumerate(samples):
        samples[index] = tensor.to(hparams['device'])


In [None]:
# Instantiate model, optimizer and loss.
# The model is placed on the target device before the optimizer is built: the torch.optim
# documentation recommends moving a model first and constructing its optimizer afterwards.
model = Autoencoder().to(hparams['device'])
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
criterion = nn.MSELoss()

In [None]:
# Move model to device
# (the device is the one configured in hparams['device']; the result is assigned back to `model`)
model = model.to(hparams['device'])

In [None]:
import copy  # local import: only this cell needs it, to snapshot state dicts

# Restore model and optimizer from previous checkpoint or create new checkpoint from scratch
if os.path.isfile(hparams['params']):
    print("Restoring from previous checkpoint")
    # map_location lets a checkpoint written on one device be restored on another one
    checkpoint = torch.load(hparams['params'], map_location=hparams['device'])
    if hparams['continue_with_best_model']:
        model.load_state_dict(checkpoint['best_model'])
    else:
        model.load_state_dict(checkpoint['last_model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
else:
    checkpoint = {
        'epoch' : 0,
        'best_train_loss': None,
        # state_dict() hands out references to the live parameters, hence the deep copy
        'best_model': copy.deepcopy(model.state_dict()),
        'last_model': model.state_dict(),
        'optimizer' : optimizer.state_dict()
    }

# Run a number of training epochs.
# checkpoint['epoch'] is the last epoch whose weights were saved, so a restored run continues
# with the epoch after it. 'best_train_loss' is None only while nothing has been saved yet,
# in which case training starts at epoch 0.
if checkpoint['best_train_loss'] is None:
    start = 0
else:
    start = checkpoint['epoch'] + 1
end = hparams['num_epochs']

try:
    # The range is empty when every requested epoch has already been trained
    for epoch in range(start, end):

        train_loss = train_epoch(train_loader, model, optimizer, criterion, hparams)

        test_loss = test_epoch(test_loader, model, criterion, hparams)

        if epoch == hparams['num_epochs'] - 1 or epoch % hparams['checkpointing_freq'] == 0:

            print('Saving checkpoint for epoch ' + str(epoch))
            checkpoint['epoch'] = epoch

            if checkpoint['best_train_loss'] is None or train_loss < checkpoint['best_train_loss']:
                print('New best model found!')
                # Stored as a plain Python float so the checkpoint holds no NumPy objects
                checkpoint['best_train_loss'] = float(train_loss)
                # Snapshot the weights: a bare state_dict() would keep following the live
                # parameters and the saved "best" model would always equal the last one
                checkpoint['best_model'] = copy.deepcopy(model.state_dict())

            checkpoint['last_model'] = model.state_dict()

            checkpoint['optimizer'] = optimizer.state_dict()

            torch.save(checkpoint, hparams['params'])

        writer.add_scalar("train loss", train_loss, global_step=epoch)

        writer.add_scalar("test loss", test_loss, global_step=epoch)

        # Show inferences with a few training samples
        few_train_y_hat = inference(model, few_train_x)
        grid = torchvision.utils.make_grid(few_train_y + few_train_x + few_train_y_hat, nrow=4)
        writer.add_image("a few train samples, one column per sample in (y, x, y_hat) format", grid, global_step=epoch)

        # Show inferences with a few test samples
        few_test_y_hat = inference(model, few_test_x)
        grid = torchvision.utils.make_grid(few_test_y + few_test_x + few_test_y_hat, nrow=4)
        writer.add_image("a few test samples, one column per sample in (y, x, y_hat) format", grid, global_step=epoch)

        writer.flush()

except KeyboardInterrupt:

    # Interrupting the kernel stops training but keeps the checkpoints written so far
    print('Exiting from training early')

In [None]:
# Close the TensorBoard writer once training is over
writer.close()