# Experiment
Adapted from [Training with PyTorch](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html)

In [1]:
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms as tv_transfroms

from models.fcn_factory import FcnFactory
from preprocess.datasets import make_hou_dataset
from preprocess import transforms as custom_transforms

#### Check CUDA is available

In [2]:
print(f'CUDA version: {torch.version.cuda}')
print(f'CUDA available: {torch.cuda.is_available()}')
print(f'Device ID: {torch.cuda.current_device()}')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

CUDA version: 11.8
CUDA available: True
Device ID: 0


#### Hyperparameters

In [3]:
# Config
RANDOM_STATE = 42
rng = torch.Generator().manual_seed(RANDOM_STATE)
N_BATCHES_FOR_RECORD = 10
# Runs
EPOCHS = 10
# Data loading
BATCH_SIZE = 32
SHUFFLE = True

In [4]:
factory = FcnFactory(n_classes=2)
model = factory.make_fcn('resnet50').to(device)

#### Model input/output transforms

In [5]:
def get_output_shape(model, in_transform) -> list:
    dummy_input = torch.randn((1, 3, 256, 256)).to(device)
    return list(model(in_transform(dummy_input))['out'].shape[-2:])

In [6]:
input_transforms = factory.input_transforms
output_h_w = get_output_shape(model, input_transforms)
mask_transforms = tv_transfroms.Compose(
    [
        tv_transfroms.Resize(output_h_w),
        tv_transfroms.Lambda(custom_transforms.one_hot),
    ]
)
# some checks
assert output_h_w == [520, 520] # this is the size for this case

#### Prepare Datasets

In [7]:
# Hou
# Excluding this directory because the masks apper to be wrong
EXCLUDE_DIRS = (Path('data/Hou/PV03_Ground_WaterSurface'),)
hou_ds = make_hou_dataset(
    EXCLUDE_DIRS,
    img_transforms=input_transforms,
    mask_transforms=mask_transforms,
)

# Split datasets
train_ds, val_ds, test_ds = random_split(hou_ds, [0.6, 0.2, 0.2], generator=rng)

# Loading the data
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, generator=rng)
val_loader = DataLoader(val_ds, BATCH_SIZE)

In [8]:
loss_fn = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters())

In [9]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss_mean = 0.

    # Here, we use enumerate(train_loader) instead of
    # iter(train_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, batch in enumerate(train_loader):
        inputs, labels = batch
        # move data to GPU
        inputs = inputs.to(device)
        labels = labels.to(device)


        # Zero your gradients for every batch!
        optimiser.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)['out']

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimiser.step()

        # Gather data and report
        running_loss += loss.item()
        if i % N_BATCHES_FOR_RECORD == N_BATCHES_FOR_RECORD - 1:
            last_loss_mean = running_loss / N_BATCHES_FOR_RECORD # loss per batch
            print(f'\tBatches {i-N_BATCHES_FOR_RECORD+2}-{i+1} mean loss: {last_loss_mean}')
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss_mean, tb_x)
            running_loss = 0.

    return last_loss_mean

In [10]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/ResNet-50{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(val_loader):
            vinputs, vlabels = vdata
            # move to GPU
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)['out']
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
	Batches 1-10 mean loss: 0.4820940881967545
	Batches 11-20 mean loss: 0.3197287768125534
	Batches 21-30 mean loss: 0.2774208799004555
	Batches 31-40 mean loss: 0.22921065837144852
	Batches 41-50 mean loss: 0.20369807332754136
LOSS train 0.20369807332754136 valid 0.18774618208408356
EPOCH 2:
	Batches 1-10 mean loss: 0.19680357128381729
	Batches 11-20 mean loss: 0.1758433535695076
	Batches 21-30 mean loss: 0.17454831302165985
	Batches 31-40 mean loss: 0.1855638101696968
	Batches 41-50 mean loss: 0.1983269691467285
LOSS train 0.1983269691467285 valid 0.19255037605762482
EPOCH 3:
	Batches 1-10 mean loss: 0.1682751163840294
	Batches 11-20 mean loss: 0.16786079108715057
	Batches 21-30 mean loss: 0.17362866178154945
	Batches 31-40 mean loss: 0.16725173443555832
	Batches 41-50 mean loss: 0.17151171565055848
LOSS train 0.17151171565055848 valid 0.1489190310239792
EPOCH 4:
	Batches 1-10 mean loss: 0.14808582365512848
	Batches 11-20 mean loss: 0.18569454848766326
	Batches 21-30 mean loss

In [12]:
import pickle

with open(f'saved_models/test_data/test_ds_{timestamp}', 'wb') as f:
    pickle.dump(test_ds, f)