In [None]:
import sys
import os
import datetime as dt
import pickle
import torch
from torch.utils.data import random_split, DataLoader

# TODO: change path name
sys.path.append("/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/preprocessing/")
from preprocessing import EuroSATDataset

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# TODO: change path name
# setting paths to EuroSAT data and preprocessing statistics
# The project root is defined once; set the EUROSAT_PROJECT_ROOT environment
# variable to run on another machine without editing this cell. When it is
# unset, the paths below are the same strings that were hard-coded before.
project_root = os.environ.get(
    'EUROSAT_PROJECT_ROOT',
    '/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project',
)
data_path = os.path.join(project_root, 'EuroSAT_RGB')
preprocessing_stats_path = os.path.join(project_root, 'preprocessing', 'preprocessing_stats.pkl')
checkpoint_path = os.path.join(project_root, 'checkpoints')

In [None]:
# Build the EuroSAT dataset from the configured paths (transform flag enabled)
# and keep its sorted class-name list for later use.
eurosat = EuroSATDataset(
    data_path,
    preprocessing_stats_path,
    transform=True,
)
classes = eurosat.sorted_class_names

In [None]:
# Seeded two-stage split: first 80/20 into (train+val)/test, then the 80% part
# is split 80/20 again into train/val (roughly 64% / 16% / 20% overall).
# One generator drives both splits so the partition is reproducible.
generator = torch.Generator()
generator.manual_seed(0)
train_val_set, test_set = random_split(eurosat, [0.8, 0.2], generator=generator)
train_set, val_set = random_split(train_val_set, [0.8, 0.2], generator=generator)

In [None]:
# TODO: setting hyperparameters
# (epochs, optimizer and loss_fn are placeholders that must be filled in
# before training can run)
batch_size = 64
epochs = None
optimizer = None
loss_fn = None

# containers for storing loss data: each *_idx list is paired with the
# matching loss list; four distinct list objects are created
train_loss_idx, train_loss = [], []
val_loss_idx, val_loss = [], []

# TODO: initializing model
# model_name is used as the prefix of checkpoint file names
model_name = 'cnn'
model = None

In [54]:
# creating dataloaders
# Training batches are reshuffled every epoch. Validation keeps a fixed order:
# shuffling it has no benefit and makes the mean of per-batch validation losses
# vary between epochs, because a different set of samples ends up in the final
# (possibly smaller) batch each time.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [None]:
def train_one_epoch(epoch_index, optimizer, loss_fn, train_loader, model, log_every=3000):
    """Run one optimisation pass over `train_loader` and log the running loss.

    Every `log_every` batches the mean loss of that window is printed and
    appended to the notebook-level lists `train_loss` / `train_loss_idx`
    (these are globals, not parameters). Any batches left over at the end of
    the epoch are logged as a final, shorter window, so an epoch with fewer
    than `log_every` batches still produces one entry instead of none.

    Args:
        epoch_index: zero-based epoch number; used to compute the global batch
            index stored in `train_loss_idx`.
        optimizer: optimizer whose `zero_grad()` / `step()` are called per batch.
        loss_fn: callable invoked as `loss_fn(outputs, labels)`.
        train_loader: sized iterable of dict batches with keys 'image' and
            'land_use'.
        model: callable invoked as `model(inputs)`.
        log_every: number of batches per logging window (default 3000, the
            interval previously hard-coded).

    Returns:
        The mean loss of the most recently logged window, or 0. if
        `train_loader` yielded no batches.

    Raises:
        ValueError: if `log_every` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be at least 1')

    running_loss = 0.
    last_loss = 0.
    pending = 0       # batches accumulated since the last log entry
    batches_seen = 0  # 1-based index of the last processed batch

    def _record(window_total, window_size, batch_number):
        # Print and store the mean loss of the window ending at batch_number.
        window_loss = window_total / window_size  # loss per batch
        print('  batch {} loss: {}'.format(batch_number, window_loss))
        train_loss_idx.append(epoch_index * len(train_loader) + batch_number)
        train_loss.append(window_loss)
        return window_loss

    # Here, we use enumerate(train_loader) instead of
    # iter(train_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        # Every data instance is an input + label pair
        inputs = data['image']
        labels = data['land_use']

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        pending += 1
        batches_seen = i + 1
        if pending == log_every:
            last_loss = _record(running_loss, pending, batches_seen)
            running_loss = 0.
            pending = 0

    # Flush the trailing partial window so short epochs are not silently
    # reported as a loss of 0.
    if pending > 0:
        last_loss = _record(running_loss, pending, batches_seen)

    return last_loss

In [None]:
def train_model(epochs, train_loss, train_loss_idx, val_loss, val_loss_idx, 
                optimizer, loss_fn, train_loader, model):
    """Train `model` for `epochs` epochs, validating and checkpointing each epoch.

    Besides its parameters, this function reads notebook-level names:
    `train_one_epoch`, `val_loader`, `checkpoint_path` and `model_name`.

    Args:
        epochs: number of epochs to run.
        train_loss, train_loss_idx: lists stored in checkpoints and in the
            returned dict. This function does not append to them itself;
            presumably they are the same lists `train_one_epoch` fills in,
            verify at the call site.
        val_loss, val_loss_idx: lists that receive one entry per epoch (the
            mean validation loss as a Python float, and its batch index).
        optimizer: optimizer passed to `train_one_epoch`; its state dict is
            checkpointed.
        loss_fn: callable invoked as `loss_fn(outputs, labels)`.
        train_loader: sized iterable of training batches.
        model: model to train; switched between train and eval mode.

    Returns:
        Dict with keys 'train_loss', 'train_loss_idx', 'val_loss' and
        'val_loss_idx'; the same dict is pickled to 'losses.pkl' in the
        current working directory.

    Raises:
        ValueError: if `val_loader` yields no batches.
    """
    best_vloss = torch.inf 
    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        # train_one_epoch takes (epoch_index, optimizer, loss_fn, train_loader, model);
        # passing the loss lists positionally raised a TypeError.
        avg_loss = train_one_epoch(epoch, optimizer, loss_fn, train_loader, model)

        running_vloss = 0.0
        num_vbatches = 0
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for vdata in val_loader:
                vinputs = vdata['image']
                vlabels = vdata['land_use']
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                # .item() keeps plain floats in the logs instead of tensors
                running_vloss += vloss.item()
                num_vbatches += 1

        if num_vbatches == 0:
            raise ValueError('val_loader yielded no batches; cannot compute validation loss')
        avg_vloss = running_vloss / num_vbatches
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        # Log the validation running loss averaged per batch
        # NOTE(review): this index marks the first batch of the epoch although
        # validation runs after the epoch's last batch; confirm this is the
        # intended x-position for plotting.
        val_loss_idx.append(epoch * len(train_loader) + 1)
        val_loss.append(avg_vloss)

        # TODO: Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            # report the new best value, not the one being replaced
            print(f"New best validation loss: {avg_vloss}")
            best_vloss = avg_vloss
            timestamp = dt.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            model_path = os.path.join(checkpoint_path, f'{model_name}_{timestamp}_e{epoch}')
            result = {
                'epoch': epoch, 
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_fn': loss_fn, 
                'model_state_dict': model.state_dict(),
                'train_loss': train_loss, 
                'train_loss_idx': train_loss_idx, 
                'val_loss': val_loss, 
                'val_loss_idx': val_loss_idx
            }
            print(f'Saving results at {checkpoint_path}')
            # torch.save does not create missing parent directories
            os.makedirs(checkpoint_path, exist_ok=True)
            torch.save(result, model_path)

    # save losses
    final_losses = {
        'train_loss': train_loss, 
        'train_loss_idx': train_loss_idx, 
        'val_loss': val_loss,
        'val_loss_idx': val_loss_idx
    }
    with open('losses.pkl', 'wb') as f:
        print(f"Saving losses at {os.path.join(os.getcwd(), 'losses.pkl')}")
        pickle.dump(final_losses, f)

    return final_losses

In [None]:
# TODO: grid search specific to CNN, googlenet, mobilenet