In [None]:
# Mount Google Drive: the dataset is copied from, and results are saved under, this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

## Dependencies:

In [None]:
!pip install torch torchvision numpy

In [None]:
import torch
import torch.nn.functional as F
import torch.utils.data as utils
import torch.optim.lr_scheduler as lr_scheduler
from torchvision.transforms import v2

import os
import numpy as np
import shutil

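## Model Initialization:

**TODO:** add a cell here that defines `model` (the `torch.nn.Module` to train); the configuration cell below reads it.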

## Configuration:

In [None]:
# Separate configuration and model variables.
# Every setting of a run lives in this dict and is passed as **config_vars to the pipeline.

config_vars = {
    # 'kaggle_token': [], # In [USER, PASS] format.
    'drive_path': '', # Directory under /content/drive that contains the dataset folder (not the dataset folder itself).
    'dataset_path': 'kevingeng07/landseg-net-3-country-clc-data/versions/1',
    'save_path': '', # Google Drive directory to save the file.
    'train_bs': 256,
    'val_bs': 256,
    'test_bs': 512,
    'device': 'cuda',
    # The dataset encodes pixels with an unmapped land-cover code as label -1; exclude them from the loss
    # (the default ignore_index is -100, so a -1 target would raise instead).
    'loss_fn': torch.nn.CrossEntropyLoss(ignore_index=-1),
    'epochs': 100,
    # Defined by the "Model Initialization" cell; None if that cell has not been run,
    # instead of this cell failing with a NameError.
    'model': globals().get('model'),
    'optimizer': torch.optim.AdamW,
    'optimizer_params': {'lr': 1e-3, 'weight_decay': 1e-2},
    'scheduler_params': {
        # (scheduler class, constructor kwargs) pairs: the optimizer does not exist yet,
        # so these must be instantiated at training time.
        'schedulers': [
            # Linear warmup from 5e-5 up to the base LR (1e-3) over the first 5 epochs ...
            (lr_scheduler.LinearLR, {'start_factor': 5e-5/1e-3, 'total_iters': 5}),
            # ... then cosine annealing over the remaining 95 of the 100 epochs.
            (lr_scheduler.CosineAnnealingLR, {'T_max': 95})
        ],
        'milestones': [5]  # Epoch at which SequentialLR switches from warmup to cosine.
    },
    'patience': 10  # Early-stopping patience in epochs; falsy disables early stopping.
}

## Dataset:

In [None]:
# NOTE: the "drive_path" config variable is the directory (relative to the /content/drive mount) that contains the dataset folder.

def create_dataset(**kwargs):
  """Copy the dataset out of Google Drive and build the train/validation/test data loaders.

  Keyword arguments (normally ``**config_vars``):
    drive_path: directory, relative to ``drive_root``, that contains the dataset folder
      (not the dataset folder itself).
    dataset_path: dataset folder below ``drive_path``; it must contain ``train``,
      ``validation`` and ``test`` subfolders of ``.npy`` tiles.
    train_bs, val_bs, test_bs: batch sizes of the three loaders.
    drive_root: optional Drive mount point, default ``'/content/drive'``.
    local_root: optional local directory the dataset is copied into, default ``'/content'``.
    num_workers: optional worker count for the train/validation loaders, default 8.

  Each tile is expected to be a ``(bands + 1, H, W)`` array whose last channel holds the
  land-cover code of every pixel; the other channels are raw reflectances (x10000).

  Returns:
    ``(train_loader, val_loader, test_loader)``. Samples are ``(features, labels)`` with
    ``features`` a float32 ``(bands, H, W)`` tensor and ``labels`` an int64 ``(H, W)``
    tensor of class indices, ``-1`` for land-cover codes that are not mapped.
  """
  drive_path = kwargs.get('drive_path')
  dataset_path = kwargs.get('dataset_path')
  drive_root = kwargs.get('drive_root', '/content/drive')
  local_root = kwargs.get('local_root', '/content')
  num_workers = kwargs.get('num_workers', 8)

  # Optional Kaggle download (currently disabled):
  # kaggle_token = kwargs.get('kaggle_token')
  # os.environ['KAGGLE_USERNAME'] = kaggle_token[0]
  # os.environ['KAGGLE_KEY'] = kaggle_token[1]
  # kagglehub.login()
  # kagglehub.dataset_download(dataset_path)

  source_dir = os.path.normpath(f'{drive_root}/{drive_path}/{dataset_path}')
  data_dir = os.path.normpath(f'{local_root}/{dataset_path}')

  # Copy once for all three splits. copytree refuses an existing destination, so skip the copy
  # when a local copy is already present (keeps the cell re-runnable). Delete the local folder
  # to force a fresh copy, e.g. after an interrupted one.
  if not os.path.isdir(data_dir):
    shutil.copytree(source_dir, data_dir)

  # Land-cover codes merged into class 0.
  other_codes = [111, 112, 121, 122, 123, 124, 131, 132, 133, 141, 142, 211, 212, 213, 221, 222, 223, 231, 241, 242, 243, 244, 323, 331, 332, 334, 421, 422, 423, 521, 522, 523, 999]
  # Natural land cover: the position in this list is the training class index.
  natural_codes = [0, 311, 312, 313, 321, 322, 324, 333, 411, 412, 511, 512]
  # Lookup table over the whole uint16 code range, built once; unmapped codes -> -1.
  label_lut = torch.full((2 ** 16,), -1, dtype=torch.long)
  label_lut[other_codes] = 0
  label_lut[natural_codes] = torch.arange(len(natural_codes))

  class DatasetGenerator(utils.Dataset):
    """Dataset over the ``.npy`` tiles in ``<data_dir>/<split>``; augments only the train split."""

    def __init__(self, split):
      self.split = split
      split_dir = os.path.join(data_dir, split)
      # Sorted so that sample indices are deterministic across runs.
      self.data_paths = [os.path.join(split_dir, name) for name in sorted(os.listdir(split_dir))]

    def __len__(self):
      return len(self.data_paths)

    def __getitem__(self, idx):
      # float32 holds integer reflectances and land-cover codes exactly, and lets the
      # geometric augmentation move bands and labels together as one tensor.
      stack = torch.from_numpy(np.load(self.data_paths[idx]).astype(np.float32))

      if self.split == 'train':
        stack = self.augment_geometry(stack)

      features = stack[:-1] / 10000 # Correctly scale the data according to the raw reflectance values.
      codes = stack[-1].round().long()

      if self.split == 'train':
        features = self.augment_noise(features)

      labels = torch.full_like(codes, -1)
      in_range = (codes >= 0) & (codes < label_lut.numel())
      labels[in_range] = label_lut[codes[in_range]]
      return features, labels

    @staticmethod
    def augment_geometry(x):
      """Randomly flip/rotate the band + label stack jointly so that labels stay aligned."""
      transforms_list = []

      if np.random.random() <= 0.5:
        transforms_list.append(v2.RandomHorizontalFlip(p=1.0))

      if np.random.random() <= 0.5:
        transforms_list.append(v2.RandomVerticalFlip(p=1.0))

      if np.random.random() <= 0.25:
        # Default NEAREST interpolation keeps label codes intact; uncovered corners get fill 0.
        transforms_list.append(v2.RandomRotation(degrees=135))

      # Applied one by one: v2.Compose rejects an empty list, which happens when nothing was drawn.
      for transform in transforms_list:
        x = transform(x)
      return x

    @staticmethod
    def augment_noise(features):
      """Occasionally add Gaussian noise to the scaled features (never to the labels)."""
      if np.random.random() <= 0.15:
        # clip=True clamps the noisy features to [0, 1].
        # Reduce sigma to 0.05 first if underfitting, or try different sigmas based on band mean (ie. B01-02 has lower values for some samples).
        features = v2.GaussianNoise(sigma=0.15, clip=True)(features)
      return features

  worker_opts = dict(num_workers=num_workers, persistent_workers=num_workers > 0, pin_memory=True)
  train_loader = utils.DataLoader(DatasetGenerator('train'), batch_size=kwargs.get('train_bs'), shuffle=True, drop_last=True, **worker_opts)
  # drop_last=False so every validation sample contributes to the validation loss.
  val_loader = utils.DataLoader(DatasetGenerator('validation'), batch_size=kwargs.get('val_bs'), shuffle=False, drop_last=False, **worker_opts)
  test_loader = utils.DataLoader(DatasetGenerator('test'), batch_size=kwargs.get('test_bs'), shuffle=False)

  return train_loader, val_loader, test_loader

## Training:

In [None]:
class EarlyStopping:
  """Signal the end of training once the validation loss stops improving.

  Call the instance once per epoch with that epoch's validation loss and the model.
  An epoch is an improvement when its loss is at least ``tol`` below the best loss so far;
  the model weights of that epoch are then snapshotted. After ``patience`` consecutive
  epochs without improvement the call returns the best snapshot (a state dict), otherwise
  it returns ``None``.

  Args:
    patience: number of consecutive non-improving epochs tolerated before stopping.
    tol: minimum absolute decrease of the validation loss that counts as an improvement.
  """

  def __init__(self, patience, tol=0.01):
    self.patience = patience
    self.tol = tol

    self.cur_epoch = 0  # Consecutive epochs without improvement.
    self.lowest_val_loss = float('inf')  # inf: the first epoch is always an improvement.
    self.best_model = None  # Detached copy of the best state dict.
    self.best_epoch = 0  # 1-based number of the call that produced best_model.
    self._num_calls = 0

  def __call__(self, val_loss, model):
    """Record one epoch; return the best state dict when training should stop, else None."""
    self._num_calls += 1

    if val_loss <= self.lowest_val_loss - self.tol:
      self.cur_epoch = 0
      self.lowest_val_loss = val_loss
      self.best_epoch = self._num_calls
      # state_dict() references the live parameters; clone them so later optimizer steps
      # cannot overwrite the snapshot.
      self.best_model = {
        key: (value.detach().clone() if torch.is_tensor(value) else value)
        for key, value in model.state_dict().items()
      }
    else:
      self.cur_epoch += 1

    if self.cur_epoch >= self.patience:
      return self.best_model

    return None

In [None]:
def _build_scheduler(optimizer, scheduler_params):
  """Instantiate the configured ``SequentialLR``; return ``None`` when no schedule is configured.

  ``scheduler_params['schedulers']`` holds ``(scheduler_class, kwargs)`` pairs (they cannot be
  built before the optimizer exists); every other key is forwarded to ``SequentialLR``
  (e.g. ``'milestones'``).
  """
  if not scheduler_params:
    return None
  sequential_kwargs = dict(scheduler_params)  # Copy: do not mutate the caller's config.
  sub_schedulers = [
    scheduler_cls(optimizer, **scheduler_kwargs)
    for scheduler_cls, scheduler_kwargs in sequential_kwargs.pop('schedulers')
  ]
  return lr_scheduler.SequentialLR(optimizer, schedulers=sub_schedulers, **sequential_kwargs)


def _train_one_epoch(model, loader, optimizer, loss_fn, device):
  """Run one optimisation pass over ``loader`` and return the mean per-batch training loss."""
  model.train(mode=True)
  # autocast wants a device *type* ('cuda'), not a device string such as 'cuda:0'.
  autocast_type = torch.device(device).type
  total_loss = 0.0

  for input_data, label in loader:
    input_data = input_data.to(device)
    label = label.to(device)

    optimizer.zero_grad()

    with torch.autocast(autocast_type, dtype=torch.bfloat16):
      pred_label = model(input_data)
      loss = loss_fn(pred_label, label)

    total_loss += loss.detach().item()

    loss.backward()
    optimizer.step()

  return total_loss / max(len(loader), 1)


def _evaluate(model, loader, loss_fn, device):
  """Return the mean per-batch loss of ``model`` on ``loader``, without tracking gradients."""
  model.train(mode=False)
  total_loss = 0.0

  with torch.no_grad():
    for input_data, label in loader:
      pred_label = model(input_data.to(device))
      total_loss += loss_fn(pred_label, label.to(device)).item()

  return total_loss / max(len(loader), 1)


def train_model(train_loader, val_loader, **kwargs):
  """Train ``kwargs['model']`` and checkpoint it under ``kwargs['save_path']``.

  Uses the config keys ``model``, ``optimizer`` (class) and ``optimizer_params``, the optional
  ``scheduler_params`` (see ``_build_scheduler``), ``loss_fn``, ``device``, ``epochs``,
  ``save_path`` and the optional ``patience`` (falsy disables early stopping).

  Saves the weights every 5th epoch (1, 6, 11, ...) as ``epoch_<n>.pth``. When early
  stopping fires, it saves the best weights as ``BEST_epoch_<n>.pth`` and stops.

  Raises:
    ValueError: if no optimizer class or no model is configured.
  """
  # An optimizer is mandatory; the LR scheduler is optional.
  if not kwargs.get('optimizer'):
    raise ValueError("config must provide an 'optimizer' class")

  device = kwargs.get('device')
  patience = kwargs.get('patience')
  save_path = kwargs.get('save_path')

  model = kwargs.get('model')
  if model is None:
    raise ValueError("config 'model' is not set; initialise the model before training")
  model = model.to(device)
  # Re-initialise the weights when the model provides a hook for it.
  if hasattr(model, 'init_weights'):
    model.init_weights()

  optimizer = kwargs.get('optimizer')(model.parameters(), **(kwargs.get('optimizer_params') or {}))
  scheduler = _build_scheduler(optimizer, kwargs.get('scheduler_params'))
  loss_fn = kwargs.get('loss_fn')
  early_stop = EarlyStopping(patience=patience) if patience else None

  if save_path:
    os.makedirs(save_path, exist_ok=True)

  for epoch in range(1, kwargs.get('epochs') + 1):
    train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)

    if (epoch - 1) % 5 == 0:
      torch.save(model.state_dict(), f'{save_path}/epoch_{epoch}.pth')

    val_loss = _evaluate(model, val_loader, loss_fn, device)
    print(f'Epoch {epoch}:\n Train Loss: {train_loss}\n Val Loss: {val_loss}')

    if early_stop is not None:
      best_state = early_stop(val_loss, model)
      if best_state is not None:
        # Assumes EarlyStopping fires `patience` epochs after the best epoch, hence the offset.
        torch.save(best_state, f'{save_path}/BEST_epoch_{epoch - patience}.pth')
        break # Break after early stopping is applied.

    if scheduler is not None:
      scheduler.step()

## Main:

In [None]:
def training_pipeline(**kwargs):
  """Build the data loaders from the configuration and train the model on them.

  The test loader returned by ``create_dataset`` is not used here.
  """
  loaders = create_dataset(**kwargs)
  train_loader, val_loader, _unused_test_loader = loaders
  train_model(train_loader, val_loader, **kwargs)

In [None]:
# Inside a notebook kernel __name__ is always "__main__", so running this cell starts training.
if __name__ == "__main__":
  training_pipeline(**config_vars)