In [1]:
import sys
import os
import datetime as dt
import time
import pickle
import torch
from torch import nn
from torch.utils.data import random_split, DataLoader
import torch.nn.functional as F

# TODO: change path name
sys.path.append("/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/preprocessing/")
from preprocessing import EuroSATDataset

%load_ext autoreload
%autoreload 2

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# TODO: change path name
# setting paths to EuroSAT data and preprocessing statistics
data_path = '/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/EuroSAT_RGB'
preprocessing_stats_path = '/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/preprocessing/preprocessing_stats.pkl'
checkpoint_path = '/mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/checkpoints'

In [4]:
# getting eurosat dataset
eurosat = EuroSATDataset(data_path, preprocessing_stats_path, transform=True)
classes = eurosat.sorted_class_names

In [5]:
# splitting dataset into train, validation, and test
generator = torch.Generator().manual_seed(0)
train_val_set, test_set = random_split(eurosat, [0.8, 0.2], generator = generator)
train_set, val_set = random_split(train_val_set, [0.8, 0.2], generator = generator)

In [6]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.pooling = nn.MaxPool2d(2, 2)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x) 
        x = self.pooling(x)
        x = self.batch_norm(x) 
        x = self.relu(x) 
        return x

In [7]:
class FullyConnectedBlock(nn.Module):
    def __init__(self, in_channels, out_channels, is_output=False):
        super().__init__()
        self.is_output = is_output
        self.conv = nn.Linear(in_channels, out_channels)
        if not self.is_output:
            self.batch_norm = nn.BatchNorm1d(out_channels)
            self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        if not self.is_output: 
            x = self.batch_norm(x)
            x = self.relu(x)
        return x

In [8]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_block1 = ConvBlock(3, 8)
        self.conv_block2 = ConvBlock(8, 12)
        self.conv_block3 = ConvBlock(12, 16)
        flatten_channels = 16 * 6 * 6
        self.fc_block1 = FullyConnectedBlock(flatten_channels, flatten_channels // 2)
        self.fc_block2 = FullyConnectedBlock(flatten_channels // 2, 10, is_output=True)
        self.conv_blocks = nn.Sequential(
            self.conv_block1, 
            self.conv_block2, 
            self.conv_block3
        )
        self.fc_blocks = nn.Sequential(
            self.fc_block1, 
            self.fc_block2
        )

    def forward(self, x):
        x = self.conv_blocks(x)
        #  flatten all dimensions except batch
        x = torch.flatten(x, 1) 
        x = self.fc_blocks(x)
        return x
    
test_model = Net()
test_img = torch.rand((10, 3, 64, 64))
test_model(test_img).shape

torch.Size([10, 10])

In [9]:
# Initializing model
model_name = 'cnn' # name of model (for checkpoint file name)
model = Net().to(device)

In [10]:
# TODO: setting hyperparameters
batch_size = 64
epochs = 20 
optimizer = torch.optim.SGD(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()

# containers for storing loss data and epoch time
train_loss = []
train_loss_idx = []
val_loss = []
val_loss_idx = []
epoch_times = []

# tracking when to checkpoint model
checkpoint_after_epochs = 5

In [11]:
# creating dataloaders
train_loader = DataLoader(train_set, batch_size = batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size = batch_size, shuffle=True)

In [12]:
# Testing loss
for data in train_loader:
    imgs = data['image'].to(device)
    labels = data['land_use'].to(device)
    print(imgs.shape)
    print(imgs.dtype)
    print(labels.shape)
    print(labels.dtype)
    test_output = model(imgs)
    print(test_output.dtype)
    print(torch.nn.CrossEntropyLoss()(test_output, labels))
    break

torch.Size([64, 3, 64, 64])
torch.float32
torch.Size([64])
torch.int64
torch.float32
tensor(2.3804, device='cuda:0', grad_fn=<NllLossBackward0>)


In [None]:
def train_one_epoch(epoch_index, train_loss, train_loss_idx, optimizer, loss_fn, train_loader, model):
    running_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        # Every data instance is an input + label pair
        inputs = data['image'].to(device)
        labels = data['land_use'].to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 90 == 89:
            last_loss = running_loss / 89 # loss per batch
            timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
            print('{} batch {} loss: {}'.format(timestamp, i + 1, last_loss))
            train_loss_idx.append(epoch_index * len(train_loader) + i + 1)
            train_loss.append(last_loss)
            running_loss = 0.

    return last_loss

In [14]:
def save_model(epoch, optimizer, loss_fn, model, 
               train_loss, train_loss_idx, val_loss, val_loss_idx, status):
    model_path = os.path.join(checkpoint_path, f'{status}_{model_name}_e{epoch}')
    result = {
        'epoch': epoch, 
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_fn': loss_fn, 
        'model_state_dict': model.state_dict(),
        'train_loss': train_loss, 
        'train_loss_idx': train_loss_idx, 
        'val_loss': val_loss, 
        'val_loss_idx': val_loss_idx
    }
    timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    print(f'{timestamp} Saving results at {checkpoint_path}')
    torch.save(result, model_path)

In [None]:
def train_model(epochs, train_loss, train_loss_idx, val_loss, val_loss_idx, 
                optimizer, loss_fn, train_loader, model):
    best_vloss = torch.inf 
    for epoch in range(epochs):
        timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
        print(f"{timestamp} Epoch {epoch}/{epochs}")

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        epoch_start_time = time.time()
        avg_loss = train_one_epoch(epoch, train_loss, train_loss_idx, optimizer, loss_fn, train_loader, model)
        epoch_end_time = time.time()
        epoch_times.append(epoch_end_time - epoch_start_time)
        timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
        print(f"{timestamp} Finished training in {str(dt.timedelta(seconds = epoch_times[-1]))}")

        running_vloss = 0.0
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(val_loader):
                vinputs = vdata['image'].to(device)
                vlabels = vdata['land_use'].to(device)
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss

        avg_vloss = running_vloss / (i + 1)
        timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
        print('{} LOSS train {} valid {}'.format(timestamp, avg_loss, avg_vloss))

        # Log the validation running loss averaged per batch
        val_loss_idx.append(epoch * len(train_loader) + 1)
        val_loss.append(avg_vloss)

        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            timestamp = dt.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
            print(f"{timestamp} New best validation loss: {best_vloss}")
            save_model(epoch, optimizer, loss_fn, model, 
                       train_loss, train_loss_idx, val_loss, val_loss_idx, 'best')
        elif epoch % checkpoint_after_epochs == 0:
            save_model(epoch, optimizer, loss_fn, model, 
                       train_loss, train_loss_idx, val_loss, val_loss_idx, 'latest')
        
        print('=================================')

    save_model(epoch, optimizer, loss_fn, model, train_loss, train_loss_idx, val_loss, val_loss_idx, 'latest')

In [16]:
train_model(epochs, train_loss, train_loss_idx, val_loss, val_loss_idx, optimizer, loss_fn, train_loader, model)

2025-03-07 17-13-02 Epoch 0/20
2025-03-07 17-23-21 Finished training in 0:10:19.697917
2025-03-07 17-25-55 LOSS train 0.0 valid 6.4099507331848145
2025-03-07 17-25-55 New best validation loss: inf
2025-03-07 17-25-55 Saving results at /mnt/c/Users/brian/Documents/UCLA/2024-2025/Winter/Math_156/Final_Project/checkpoints
2025-03-07 17-25-55 Epoch 1/20


KeyboardInterrupt: 

In [None]:
# TODO: grid search specific to CNN, googlenet, mobilenet