<a href="https://colab.research.google.com/github/abel-bernabeu/autoencoder/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [109]:
# Start from a clean slate: delete the event files left behind by earlier runs.
# The options come before the operand so the command also works with `rm`
# implementations that do not permute arguments (BSD/macOS, POSIXLY_CORRECT).
!rm -rf runs
# Load the TensorBoard notebook extension and open it on the `runs` directory.
%load_ext tensorboard
%tensorboard --logdir runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 4874), started 0:26:29 ago. (Use '!kill 4874' to kill it.)

In [110]:
# Local import: this configuration cell runs before the notebook's import cell.
import torch

# Run configuration read by the cells below.
hparams = {
    # Samples per batch for both data loaders.
    'batch_size': 4,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs on
    # machines (or Colab runtimes) without CUDA.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Passed as `subset_size` when the crops dataset is built.
    'max_dataset_size': 100,
    # Sizes handed to the random train/test split.
    'train_dataset_size': 50,
    'test_dataset_size': 50,
    # NOTE(review): not read by any cell visible in this notebook.
    'log_interval': 2,
    # Number of train/eval passes in the training loop.
    'num_epochs': 100,
    # Worker processes used by each DataLoader.
    'num_workers': 4,
}

In [111]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils
import torchvision.transforms as transforms
import numpy as np
import autoencoder.datasets
import autoencoder.transforms
import autoencoder.models
import datetime

In [112]:
# Summary writer shared by every logging cell below.
# NOTE(review): presumably an empty log_dir makes SummaryWriter choose its
# default run directory — confirm it lands under `runs`.
writer = SummaryWriter(log_dir='')

In [113]:
# Crops dataset built from the part-a image folder. The two 224 values are
# passed positionally (presumably the crop dimensions — confirm against
# CropsDataset).
crops = autoencoder.datasets.CropsDataset(
    "./data/image_dataset_part-a",
    224,
    224,
    subset_size=hparams['max_dataset_size'],
    assume_fixed_size=False,
)

Collecting sizes from images in ./data/image_dataset_part-a: 100%|██████████| 100/100 [00:00<00:00, 16921.39it/s]


In [114]:
# Log the first 16 crops to TensorBoard as a grid with 4 images per row
few_crops = [transforms.ToTensor()(crops[index][0]) for index in range(16)]
grid = torchvision.utils.make_grid(few_crops, nrow=4)
writer.add_image(tag="1) a few crops", img_tensor=grid)
writer.flush()

In [115]:
# Randomly divide the crops into a train subset and a test subset,
# using the sizes configured in hparams
train_crops, test_crops = torch.utils.data.random_split(
    crops,
    [
        hparams['train_dataset_size'],
        hparams['test_dataset_size'],
    ],
)

In [116]:
# Input pipeline: ConvertToGray followed by tensor conversion.
train_input_transform = transforms.Compose(
    [autoencoder.transforms.ConvertToGray(), transforms.ToTensor()]
)

# Output pipeline: tensor conversion only.
train_output_transform = transforms.Compose(
    [transforms.ToTensor()]
)

# Wrap the train crops together with both pipelines.
train_xydims_samples = autoencoder.datasets.XYDimsDataset(
    train_crops,
    train_input_transform,
    train_output_transform,
)

In [117]:
# Log x (element 0) of the first four train samples as a single-row grid
few_train_x = [train_xydims_samples[index][0] for index in range(4)]
grid = torchvision.utils.make_grid(few_train_x, nrow=4)
writer.add_image(tag="2) x from a few train samples", img_tensor=grid)
writer.flush()

In [118]:
# Log y (element 1) of the first four train samples as a single-row grid
few_train_y = [train_xydims_samples[index][1] for index in range(4)]
grid = torchvision.utils.make_grid(few_train_y, nrow=4)
writer.add_image(tag="3) y from a few train samples", img_tensor=grid)
writer.flush()

In [119]:
# Input pipeline: ConvertToGray followed by tensor conversion.
test_input_transform = transforms.Compose(
    [autoencoder.transforms.ConvertToGray(), transforms.ToTensor()]
)

# Output pipeline: tensor conversion only.
test_output_transform = transforms.Compose(
    [transforms.ToTensor()]
)

# Wrap the test crops together with both pipelines.
test_xydims_samples = autoencoder.datasets.XYDimsDataset(
    test_crops,
    test_input_transform,
    test_output_transform,
)

In [120]:
# Log x (element 0) of the first four test samples as a single-row grid
few_test_x = [test_xydims_samples[index][0] for index in range(4)]
grid = torchvision.utils.make_grid(few_test_x, nrow=4)
writer.add_image(tag="4) x from a few test samples", img_tensor=grid)
writer.flush()

In [121]:
# Log y (element 1) of the first four test samples as a single-row grid
few_test_y = [test_xydims_samples[index][1] for index in range(4)]
grid = torchvision.utils.make_grid(few_test_y, nrow=4)
writer.add_image(tag="5) y from a few test samples", img_tensor=grid)
writer.flush()

In [122]:
# Batch both splits with the configured batch size and worker count.
# Only the training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(
    train_xydims_samples,
    batch_size=hparams['batch_size'],
    shuffle=True,
    num_workers=hparams['num_workers'],
)
test_loader = torch.utils.data.DataLoader(
    test_xydims_samples,
    batch_size=hparams['batch_size'],
    shuffle=False,
    num_workers=hparams['num_workers'],
)

In [123]:
def train_epoch(train_loader, model, optimizer, criterion, hparams):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: iterable of 4-item batches; the first two items are the
            input and target tensors, the last two are ignored.
        model: module to train; it is switched to training mode.
        optimizer: optimiser stepped once per batch.
        criterion: callable ``criterion(output, target)`` returning the loss.
        hparams: mapping whose ``'device'`` entry names the device that
            batches are moved to.

    Returns:
        The NumPy mean of the per-batch loss values.
    """
    # Reseed NumPy's global RNG from the clock's microsecond field
    # (presumably so NumPy-driven randomness differs between epochs).
    np.random.seed(datetime.datetime.now().microsecond)
    model.train()
    target_device = hparams['device']

    def run_batch(inputs, expected):
        """Do one forward/backward/update step and return the scalar loss."""
        inputs = inputs.to(target_device)
        expected = expected.to(target_device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), expected)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    batch_losses = [run_batch(inputs, expected) for inputs, expected, _, _ in train_loader]
    return np.mean(batch_losses)

def eval_epoch(val_loader, model, criterion, hparams):
    """Measure the mean batch loss of ``model`` over ``val_loader``.

    Args:
        val_loader: iterable of 4-item batches; the first two items are the
            input and target tensors, the last two are ignored.
        model: module to evaluate; it is switched to evaluation mode.
        criterion: callable ``criterion(output, target)`` returning the loss.
        hparams: mapping whose ``'device'`` entry names the device that
            batches are moved to.

    Returns:
        The NumPy mean of the per-batch loss values.
    """
    # Fixed seed for NumPy's global RNG on every evaluation pass.
    np.random.seed(0)
    model.eval()
    target_device = hparams['device']

    def batch_loss(inputs, expected):
        """Return the scalar loss for a single batch."""
        inputs = inputs.to(target_device)
        expected = expected.to(target_device)
        return criterion(model(inputs), expected).item()

    # No gradients are tracked while scoring.
    with torch.no_grad():
        eval_losses = [batch_loss(inputs, expected) for inputs, expected, _, _ in val_loader]
    return np.mean(eval_losses)

In [124]:
class Autoencoder(nn.Module):
    """Two-stage convolutional autoencoder mapping 3 -> 32 -> 3 channels.

    Both stages use 3x3 kernels with stride 1 and padding 1 and are followed
    by a ReLU, so the output has the same height and width as the input and
    is non-negative.
    """

    def __init__(self):
        super().__init__()

        # Encoding stage: 3 input channels expanded to 32 feature maps,
        # with replicate padding at the borders.
        encoder_conv = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='replicate',
        )
        self.features = nn.Sequential(encoder_conv, nn.ReLU())

        # Decoding stage: transposed convolution from 32 feature maps back to
        # 3 channels. With stride 1 the spatial size is unchanged.
        decoder_conv = nn.ConvTranspose2d(
            in_channels=32,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            output_padding=0,
        )
        self.upsample = nn.Sequential(decoder_conv, nn.ReLU())

    def forward(self, x):
        """Encode ``x`` with ``features`` and decode it with ``upsample``."""
        return self.upsample(self.features(x))

In [125]:
# Model, optimiser and loss used for the whole run.
model = Autoencoder()
model.to(hparams['device'])
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
criterion = nn.MSELoss()

num_epochs = hparams['num_epochs']

# One train pass and one evaluation pass per epoch. Both mean losses are
# logged against the epoch index. Interrupting the kernel stops training
# without raising.
try:
    for epoch in range(num_epochs):
        train_loss = train_epoch(train_loader, model, optimizer, criterion, hparams)
        test_loss = eval_epoch(test_loader, model, criterion, hparams)
        writer.add_scalar(tag="train loss", scalar_value=train_loss, global_step=epoch)
        writer.add_scalar(tag="test loss", scalar_value=test_loss, global_step=epoch)
        writer.flush()
except KeyboardInterrupt:
    print('-' * 89, 'Exiting from training early', sep='\n')

In [126]:
# Close the summary writer now that every logging cell has run.
writer.close()