# Experiment
Adapted from [Training with PyTorch](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html)

In [1]:
from datetime import datetime
from pathlib import Path
import pickle

import torch
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.transforms import v2 as tv_transfroms

from models.fcn_factory import FcnFactory
from preprocess.datasets import make_hou_dataset, make_kasmi_ign_dataset
from preprocess import transforms as custom_transforms

#### Check CUDA is available

In [2]:
print(f'CUDA version: {torch.version.cuda}')
print(f'CUDA available: {torch.cuda.is_available()}')
print(f'Device ID: {torch.cuda.current_device()}')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

CUDA version: 11.8
CUDA available: True
Device ID: 0


#### Config

In [3]:
# Config
RANDOM_STATE = 42
rng = torch.Generator().manual_seed(RANDOM_STATE)
N_BATCHES_FOR_RECORD = 10

# Warm Start (set to `None` to train head from scratch)
WARM_START_MODEL_PATH = None # 'saved_models/model_20240324_152229_4'

# Save Test Dataset (if test_ds is set elsewhere then it will be saved)
test_ds = None

#### Hyperparameters

In [4]:
# Runs
EPOCHS = 5

# Data loading
BATCH_SIZE = 32
SHUFFLE = True

#### Model & Transforms

In [5]:
factory = FcnFactory(n_classes=2)
model = factory.make_fcn('resnet50').to(device)
if WARM_START_MODEL_PATH:
    model.load_state_dict(torch.load(WARM_START_MODEL_PATH))

In [6]:
def get_output_shape(model, in_transform) -> list:
    dummy_input = torch.randn((1, 3, 256, 256)).to(device)
    return list(model(in_transform(dummy_input))['out'].shape[-2:])

In [7]:
input_transforms = tv_transfroms.Compose(
    [
        tv_transfroms.Lambda(custom_transforms.rgba_to_rgb),
        factory.input_transforms,
    ]
)
output_h_w = get_output_shape(model, input_transforms)
mask_transforms = tv_transfroms.Compose(
    [
        tv_transfroms.Resize(output_h_w),
        tv_transfroms.Lambda(custom_transforms.one_hot),
    ]
)
# some checks
assert output_h_w == [520, 520] # this is the size for this case

#### Prepare Datasets

Hou

In [8]:
# Excluding this directory because the masks appear to be wrong
EXCLUDE_DIRS = (Path('data/Hou/PV03_Ground_WaterSurface'),)
hou_ds = make_hou_dataset(
    EXCLUDE_DIRS,
    img_transforms=input_transforms,
    mask_transforms=mask_transforms,
)

# Split datasets
train_ds, val_ds, test_ds = random_split(hou_ds, [0.6, 0.2, 0.2], generator=rng)

Kasmi ign

In [9]:
# ign_ds = make_kasmi_ign_dataset(
#     img_transforms=input_transforms,
#     mask_transforms=mask_transforms,
# )

# # Split datasets
# # Will test on the same dataset which was pickled last time
# train_ds, val_ds = random_split(ign_ds, [0.8, 0.2], generator=rng)

In [10]:
# Loading the data
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, generator=rng)
val_loader = DataLoader(val_ds, BATCH_SIZE)

In [11]:
print(len(train_loader))

44


#### Training

In [12]:
loss_fn = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters())

In [13]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss_mean = 0.

    # Here, we use enumerate(train_loader) instead of
    # iter(train_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, batch in enumerate(train_loader):
        inputs, labels = batch
        # move data to GPU
        inputs = inputs.to(device)
        labels = labels.to(device)


        # Zero your gradients for every batch!
        optimiser.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)['out']

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimiser.step()

        # Gather data and report
        running_loss += loss.item()
        if i % N_BATCHES_FOR_RECORD == N_BATCHES_FOR_RECORD - 1:
            last_loss_mean = running_loss / N_BATCHES_FOR_RECORD # loss per batch
            print(f'\tBatches {i-N_BATCHES_FOR_RECORD+2}-{i+1} mean loss: {last_loss_mean}')
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss_mean, tb_x)
            running_loss = 0.

    return last_loss_mean

In [14]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/ResNet-50{}'.format(timestamp))
epoch_number = 0

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(val_loader):
            vinputs, vlabels = vdata
            # move to GPU
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)['out']
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'saved_models/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
	Batches 1-10 mean loss: 0.47139993906021116
	Batches 11-20 mean loss: 0.302275624871254
	Batches 21-30 mean loss: 0.2746844723820686
	Batches 31-40 mean loss: 0.23926840275526046
LOSS train 0.23926840275526046 valid 0.229450985789299
EPOCH 2:
	Batches 1-10 mean loss: 0.20199502259492874
	Batches 11-20 mean loss: 0.216423699259758
	Batches 21-30 mean loss: 0.19614199846982955
	Batches 31-40 mean loss: 0.1848228767514229
LOSS train 0.1848228767514229 valid 0.22434663772583008
EPOCH 3:
	Batches 1-10 mean loss: 0.19168491363525392
	Batches 11-20 mean loss: 0.18057887256145477
	Batches 21-30 mean loss: 0.1842053860425949
	Batches 31-40 mean loss: 0.17903000712394715
LOSS train 0.17903000712394715 valid 0.17176111042499542
EPOCH 4:
	Batches 1-10 mean loss: 0.1735008955001831
	Batches 11-20 mean loss: 0.15665979757905008
	Batches 21-30 mean loss: 0.1480630874633789
	Batches 31-40 mean loss: 0.1437992848455906
LOSS train 0.1437992848455906 valid 0.16064895689487457
EPOCH 5:
	Batches 

In [15]:
if test_ds:
    with open(f'saved_models/test_data/test_ds_{timestamp}.pkl', 'wb') as f:
        pickle.dump(test_ds, f)