In [1]:
# Mount Google Drive at /content/drive; later cells read the dataset from and write checkpoints to paths under this mount.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## Dependencies:

In [2]:
!pip install torch torchvision numpy torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collectin

In [3]:
import torch
import torch.nn as nn
import torch.utils.data as utils
import torch.optim.lr_scheduler as lr_scheduler
from torchvision.transforms import v2
import torchinfo

import os
import numpy as np
import shutil

## Model Initialization:

## Light U-Net:

In [4]:
class FeatureConvs(nn.Module):
  """Two conv -> batch-norm -> ReLU stages separated by dropout.

  The first stage keeps the channel count at ``in_channels``; the second stage
  maps ``in_channels`` to ``out_channels``. Both convolutions share the same
  ``kernel_size`` and ``padding`` and carry no bias term.
  """

  def __init__(self, in_channels, out_channels, kernel_size, padding, dropout=0.0):
    super().__init__()

    # Options shared by both convolutions.
    conv_opts = dict(kernel_size=kernel_size, padding=padding, bias=False)

    stages = [
        # Stage 1: channel-preserving convolution.
        nn.Conv2d(in_channels, in_channels, **conv_opts),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        # Dropout sits between the two stages.
        nn.Dropout(p=dropout),
        # Stage 2: project to the requested output width.
        nn.Conv2d(in_channels, out_channels, **conv_opts),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]
    self.feature_convs = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through both stages and return the result."""
    return self.feature_convs(x)


class MLPBlock(nn.Module):
  """Bias-free linear projection followed by layer normalization and ReLU."""

  def __init__(self, in_dim, out_dim):
    super().__init__()

    projection = nn.Linear(in_dim, out_dim, bias=False)
    norm = nn.LayerNorm(out_dim)
    self.mlp_block = nn.Sequential(projection, norm, nn.ReLU())

  def forward(self, x):
    """Apply the projection, normalization and activation to ``x``."""
    return self.mlp_block(x)


class Encoder(nn.Module):
  """Encoder stage: optional max-pool downsampling followed by FeatureConvs.

  ``layer`` picks this stage's entry from the per-layer ``in_channels`` and
  ``out_channels`` lists in ``kwargs``. Stage 0 does not pool its input.
  """

  def __init__(self, layer, kwargs):
    super().__init__()
    self.layer = layer

    # Per-layer channel widths.
    self.in_channels = kwargs.get('in_channels')[layer]
    self.out_channels = kwargs.get('out_channels')[layer]

    # Settings shared by every stage.
    self.kernel_size = kwargs.get('kernel_size')
    self.padding = kwargs.get('padding')
    self.dropout = kwargs.get('dropout')
    self.pool_stride = kwargs.get('pool_stride')

    self.feature_convs = FeatureConvs(
        self.in_channels,
        self.out_channels,
        self.kernel_size,
        self.padding,
        self.dropout,
    )

  def forward(self, x):
    """Downsample (all stages except the first) and extract features."""
    if self.layer == 0:
      return self.feature_convs(x)

    pooled = nn.functional.max_pool2d(x, kernel_size=self.pool_stride)
    return self.feature_convs(pooled)


class Decoder(nn.Module):
  """Decoder stage: upsample, 1x1 convolution, add the skip tensor, then FeatureConvs.

  Channel widths mirror the encoder stage with the same ``layer`` index: the
  input width is ``out_channels[layer]`` and the output width is
  ``in_channels[layer]``.
  """

  def __init__(self, layer, kwargs):
    super().__init__()
    self.layer = layer
    self.layers = kwargs.get('layers')

    self.in_channels = kwargs.get('out_channels')[layer]
    self.out_channels = kwargs.get('in_channels')[layer]
    self.kernel_size = kwargs.get('kernel_size')
    self.padding = kwargs.get('padding')

    self.pool_stride = kwargs.get('pool_stride')
    self.pointwise = nn.Conv2d(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1)

    self.feature_convs = FeatureConvs(self.in_channels, self.out_channels, self.kernel_size, self.padding)

  def forward(self, x, _x=None):
    """Upsample ``x`` and fuse it with the optional skip tensor ``_x``.

    Args:
      x: Feature map from the deeper stage.
      _x: Skip tensor from the matching encoder stage. When omitted, the
        stage upsamples by ``pool_stride`` and skips the addition.

    Returns:
      The output of the FeatureConvs block.
    """
    if _x is None:
      # No skip connection: plain fixed-factor upsampling.
      upsampled = nn.functional.interpolate(x, scale_factor=self.pool_stride)
      return self.feature_convs(self.pointwise(upsampled))

    # Resize to the skip tensor's spatial size instead of a fixed factor, so the
    # addition also works when pooling floored an odd height/width on the way down.
    upsampled = nn.functional.interpolate(x, size=_x.shape[-2:])

    # Out-of-place add; the skip tensor is left untouched.
    fused = self.pointwise(upsampled) + _x

    return self.feature_convs(fused)


class Bottleneck(nn.Module):
  """Channel-wise MLP stack followed by a conv -> batch-norm -> ReLU block.

  Consecutive entries of ``kwargs['mlp_dims']`` define the input/output width
  of each MLPBlock. The closing convolution maps ``out_channels[-1]`` channels
  to ``in_channels[-1]`` channels.
  """

  def __init__(self, kwargs):
    super().__init__()

    self.dim_list = kwargs.get('mlp_dims')

    # One MLPBlock per adjacent (input width, output width) pair.
    dim_pairs = zip(self.dim_list[:-1], self.dim_list[1:])
    self.mlp = nn.ModuleList([MLPBlock(d_in, d_out) for d_in, d_out in dim_pairs])

    conv_in = kwargs.get('out_channels')[-1]
    conv_out = kwargs.get('in_channels')[-1]
    self.process_conv = nn.Sequential(
        nn.Conv2d(
            in_channels=conv_in,
            out_channels=conv_out,
            kernel_size=kwargs.get('kernel_size'),
            padding=kwargs.get('padding'),
            bias=False,
        ),
        nn.BatchNorm2d(num_features=conv_out),
        nn.ReLU(),
    )

  def forward(self, x):
    """Apply the MLP stack over the channel axis, then the closing convolution."""
    # Move dim 1 to the end so the linear layers act on it.
    tokens = x.permute(0, 2, 3, 1)

    for block in self.mlp:
      tokens = block(tokens)

    # Restore the original axis order before the convolution.
    return self.process_conv(tokens.permute(0, 3, 1, 2))

In [5]:
class Light_UNet(nn.Module):
  """U-Net style network with four encoder stages, a bottleneck and three decoder stages.

  Keyword arguments are forwarded as a single dict to every sub-module. This
  class itself reads ``feature_labels`` (input channels first, output channels
  last), ``in_channels``, ``out_channels``, ``kernel_size`` and ``padding``.
  """

  def __init__(self, **kwargs):
    super().__init__()

    # Contracting path.
    self.enc_1 = Encoder(0, kwargs)
    self.enc_2 = Encoder(1, kwargs)
    self.enc_3 = Encoder(2, kwargs)
    self.enc_4 = Encoder(3, kwargs)

    # Expanding path, deepest stage first.
    self.dec_3 = Decoder(2, kwargs)
    self.dec_2 = Decoder(1, kwargs)
    self.dec_1 = Decoder(0, kwargs)

    self.bottleneck = Bottleneck(kwargs)

    feature_labels = kwargs.get('feature_labels')
    stem_width = kwargs.get('in_channels')[0]

    # Stem: lifts the raw input channels to the first encoder width.
    self.first_conv = nn.Sequential(
        nn.Conv2d(
            in_channels=feature_labels[0],
            out_channels=stem_width,
            kernel_size=kwargs.get('kernel_size'),
            padding=kwargs.get('padding'),
            bias=False,
        ),
        nn.BatchNorm2d(num_features=stem_width),
        nn.ReLU(),
    )

    # Head: 1x1 convolution producing one channel per output label.
    self.last_conv = nn.Conv2d(
        in_channels=kwargs.get('out_channels')[0],
        out_channels=feature_labels[-1],
        kernel_size=1,
    )

  def forward(self, x):
    """Encode ``x``, pass it through the bottleneck and decode with skip connections."""
    skip_1 = self.enc_1(self.first_conv(x))
    skip_2 = self.enc_2(skip_1)
    skip_3 = self.enc_3(skip_2)
    deepest = self.bottleneck(self.enc_4(skip_3))

    up_3 = self.dec_3(deepest, skip_3)
    up_2 = self.dec_2(up_3, skip_2)
    up_1 = self.dec_1(up_2, skip_1)

    return self.last_conv(up_1)

  def init_weights(self):
    """Re-initialize every Conv2d weight with Kaiming-normal (fan-in, ReLU) values."""
    conv_layers = [module for module in self.modules() if isinstance(module, nn.Conv2d)]
    for conv in conv_layers:
      nn.init.kaiming_normal_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')

In [6]:
# Keyword arguments for Light_UNet; the same dict is handed to every Encoder, Decoder and Bottleneck.
model_config = {
    # Stored by Decoder as `self.layers`; the network itself is built with a fixed four encoder stages.
    'layers': 4,
    # Kernel size and padding used by every non-1x1 convolution.
    'kernel_size': 3,
    'padding': 'same',
    # Dropout probability between the two convolutions of each encoder FeatureConvs block.
    'dropout': 0.1,
    # Max-pool kernel size in the encoders and upsampling factor in the decoders.
    'pool_stride': 2,
    # [input channels of the first convolution, output channels of the final 1x1 convolution].
    'feature_labels': [12, 20], # Adjust element 1 based on number of valid labels for task.
    # Per-stage channel widths, indexed by encoder layer (decoders use them mirrored).
    'in_channels': [64, 64, 128, 256],
    'out_channels': [64, 128, 256, 512],
    # Widths of the bottleneck MLP; each adjacent pair defines one MLPBlock.
    'mlp_dims': [512, 1024, 1024, 512]
}

In [7]:
# Sanity check: print per-layer output shapes and parameter counts for a single 12-channel 120x120 input.
torchinfo.summary(Light_UNet(**model_config), input_size=((1, 12, 120, 120)))

Layer (type:depth-idx)                   Output Shape              Param #
Light_UNet                               [1, 20, 120, 120]         --
├─Sequential: 1-1                        [1, 64, 120, 120]         --
│    └─Conv2d: 2-1                       [1, 64, 120, 120]         6,912
│    └─BatchNorm2d: 2-2                  [1, 64, 120, 120]         128
│    └─ReLU: 2-3                         [1, 64, 120, 120]         --
├─Encoder: 1-2                           [1, 64, 120, 120]         --
│    └─FeatureConvs: 2-4                 [1, 64, 120, 120]         --
│    │    └─Sequential: 3-1              [1, 64, 120, 120]         73,984
├─Encoder: 1-3                           [1, 128, 60, 60]          --
│    └─FeatureConvs: 2-5                 [1, 128, 60, 60]          --
│    │    └─Sequential: 3-2              [1, 128, 60, 60]          110,976
├─Encoder: 1-4                           [1, 256, 30, 30]          --
│    └─FeatureConvs: 2-6                 [1, 256, 30, 30]          --
│ 

## Configuration:

In [8]:
# Run configuration, kept separate from the model hyper-parameters in `model_config`.
# The whole dict is passed as keyword arguments to create_dataset and train_model.

config_vars = {
    # 'kaggle_token': [], # In [USER, PASS] format.
    # Dataset folder, relative to the /content/drive mount; must contain train/validation/test .npy files.
    'drive_path': 'MyDrive/Semantic Segmentation/data/new_natural_ecosystems',
    'save_path': 'MyDrive/Semantic Segmentation/light u-net chpts', # Google Drive directory to save the file.
    # Batch sizes for the three DataLoaders.
    'train_bs': 256,
    'val_bs': 256,
    'test_bs': 512,
    # Device the model and batches are moved to; also passed to torch.autocast.
    'device': 'cuda',
    'loss_fn': torch.nn.CrossEntropyLoss(),
    'epochs': 100,
    # NOTE: the model is instantiated here, when this cell runs, not inside train_model.
    'model': Light_UNet(**model_config), # Add the initialized model.
    # Optimizer class (not an instance); train_model constructs it with `optimizer_params`.
    'optimizer': torch.optim.AdamW,
    'optimizer_params': {'lr': 1e-3, 'weight_decay': 1e-2},
    # (scheduler class, kwargs) pairs chained by SequentialLR; `milestones` gives the step at which to switch.
    'scheduler_params': {
        'schedulers': [
            # Linear warm-up over the first 5 scheduler steps, starting at 5e-5 for the 1e-3 base LR.
            (lr_scheduler.LinearLR, {'start_factor': 5e-5/1e-3, 'total_iters': 5}),
            # Cosine annealing for the remaining 95 of the 100 epochs.
            (lr_scheduler.CosineAnnealingLR, {'T_max': 95})
        ],
        'milestones': [5]
    },
    # Number of non-improving validation epochs tolerated by EarlyStopping; a falsy value disables it.
    'patience': 10
}

## Dataset:

In [9]:
def create_dataset(**kwargs):
  """Copy the train/validation/test arrays from Drive and wrap them in DataLoaders.

  Expected kwargs:
    drive_path: Dataset folder (relative to /content/drive) holding '<split>.npy' files.
    train_bs, val_bs, test_bs: Batch sizes of the three loaders.

  Returns:
    Tuple ``(train_loader, validation_loader, test_loader)``.
  """
  drive_path = kwargs.get('drive_path') # Path to dataset folder.

  # Natural land cover codes; a code's position in this list is its class index.
  unique_classes = [0, 311, 312, 313, 321, 322, 324, 333, 411, 412, 511, 512]
  # All other known codes are folded into class 0.
  ignored_classes = [111, 112, 121, 122, 123, 124, 131, 132, 133, 141, 142, 211, 212, 213, 221, 222, 223, 231, 241, 242, 243, 244, 323, 331, 332, 334, 421, 422, 423, 521, 522, 523, 999]

  class DatasetGenerator(utils.Dataset):
    """In-memory dataset over '<split>.npy'; the last channel of each sample is the label map."""

    def __init__(self, split):

      self.split = split

      shutil.copyfile(f'/content/drive/{drive_path}/{split}.npy', f'/content/{split}.npy')
      self.data = np.load(f'/content/{split}.npy')

      # Raw code -> class index table, built once instead of remapping with np.isin per sample.
      # Codes that appear in neither list stay at -1. int64 is the target dtype CrossEntropyLoss expects.
      self.lookup = np.full(max(unique_classes + ignored_classes) + 1, -1, dtype=np.int64)
      self.lookup[ignored_classes] = 0
      for class_idx, code in enumerate(unique_classes):
        self.lookup[code] = class_idx


    def __len__(self):
      return len(self.data)

    def __getitem__(self, idx):
      data = self.data[idx]

      if self.split == 'train':
        # Geometric augmentation acts on features and labels together so they stay aligned.
        data = self.augment_data(data)

      features = data[:-1, :, :].astype(np.float32) / 10000 # Correctly scale the data according to the raw reflectance values.
      labels = self.lookup[data[-1, :, :].astype(np.uint16)]

      # Noise is applied to the scaled features only, never to the label channel.
      if self.split == 'train' and torch.rand(1).item() <= 0.15:
        features = self.add_noise(features)

      return features, labels

    def augment_data(self, x):
      """Randomly flip and/or rotate the full (features + label) stack.

      The torchvision v2 transforms only act on tensors / PIL images and pass other
      objects (such as NumPy arrays) through unchanged, so the array is converted
      to a tensor first. Random decisions use the torch RNG, which DataLoader seeds
      differently in every worker; NumPy's global RNG state is copied into each
      forked worker and would repeat the same draws.
      """
      transforms_list = []

      if torch.rand(1).item() <= 0.5:
        transforms_list.append(v2.RandomHorizontalFlip(p=1.0))

      if torch.rand(1).item() <= 0.5:
        transforms_list.append(v2.RandomVerticalFlip(p=1.0))

      if torch.rand(1).item() <= 0.25:
        # Default nearest-neighbour interpolation keeps label codes intact; exposed corners are filled with 0.
        transforms_list.append(v2.RandomRotation(degrees=135))

      if not transforms_list:
        return x

      # float32 copy: leaves the stored array untouched and is accepted by every transform above.
      stacked = torch.from_numpy(x.astype(np.float32))
      return v2.Compose(transforms_list)(stacked).numpy()

    def add_noise(self, features):
      """Add Gaussian noise to already-scaled float32 features; clip=True clamps the result to [0, 1]."""
      # Reduce sigma to 0.05 first if underfitting, or try different sigmas based on band mean (ie. B01-02 has lower values for some samples).
      noisy = v2.GaussianNoise(sigma=0.15, clip=True)(torch.from_numpy(features))
      return noisy.numpy()


  train_set = DatasetGenerator('train')
  print('Train Copied')
  validation_set = DatasetGenerator('validation')
  print('Validation Copied')
  test_set = DatasetGenerator('test')
  print('Test Copied')

  return (utils.DataLoader(train_set, batch_size=kwargs.get('train_bs'), shuffle=True, drop_last=True, num_workers=8, persistent_workers=True, pin_memory=True),
          utils.DataLoader(validation_set, batch_size=kwargs.get('val_bs'), shuffle=False, drop_last=True, num_workers=8, persistent_workers=True, pin_memory=True),
          utils.DataLoader(test_set, batch_size=kwargs.get('test_bs'), shuffle=False))

## Training:

In [10]:
class EarlyStopping:
  """Tracks validation loss and returns the best weights once it stops improving.

  Args:
    patience: Number of consecutive non-improving calls after which the best
      weights are returned.
    tol: Minimum absolute decrease in validation loss that counts as an improvement.
  """

  def __init__(self, patience, tol=0.01):
    self.patience = patience
    self.tol = tol

    self.cur_epoch = 0 # Consecutive calls without improvement.
    self.lowest_val_loss = float('inf') # inf = nothing recorded yet, so the first finite loss always counts.
    self.best_model = None


  def __call__(self, val_loss, model):
    """Record ``val_loss`` and return the best state dict when patience runs out.

    Args:
      val_loss: Validation loss of the epoch that just finished.
      model: Module whose ``state_dict()`` is snapshotted on improvement.

    Returns:
      The snapshot taken at the best epoch once ``patience`` non-improving
      calls have accumulated, otherwise ``None``.
    """
    import copy

    if val_loss <= self.lowest_val_loss - self.tol:
      self.cur_epoch = 0
      # state_dict() hands back references to the live parameter tensors; without a
      # copy, later optimizer steps would silently overwrite the "best" weights.
      self.best_model = copy.deepcopy(model.state_dict())
      self.lowest_val_loss = val_loss
    else:
      self.cur_epoch += 1 # Increment counter if no improvement

    if self.cur_epoch >= self.patience:
      return self.best_model

    return None

In [11]:
def train_model(train_loader, val_loader, **kwargs):
  """Train the configured model with periodic checkpoints and optional early stopping.

  Args:
    train_loader: Loader yielding ``(input, label)`` training batches.
    val_loader: Loader yielding ``(input, label)`` validation batches.

  Expected kwargs:
    device, model, optimizer, optimizer_params, loss_fn, epochs, save_path,
    and optionally scheduler_params and patience.

  Side effects:
    Writes ``epoch_<n>.pth`` every fifth epoch (1, 6, 11, ...) and a
    ``BEST_epoch_<n>.pth`` file under ``/content/drive/<save_path>``, and
    prints the losses of every completed epoch.
  """
  device = kwargs.get('device')
  patience = kwargs.get('patience')
  save_path = kwargs.get('save_path')

  # Make sure the checkpoint folder exists before the first torch.save.
  checkpoint_dir = f'/content/drive/{save_path}'
  os.makedirs(checkpoint_dir, exist_ok=True)

  # autocast takes a bare device type ('cuda'), not an indexed device such as 'cuda:0'.
  autocast_device = torch.device(device).type

  model = kwargs.get('model').to(device)
  model.init_weights()


  optimizer = kwargs.get('optimizer')(model.parameters(), **kwargs.get('optimizer_params'))
  scheduler = None
  if kwargs.get('scheduler_params'):
    scheduler_details = kwargs.get('scheduler_params')
    schedulers_list = [scheduler_class(optimizer, **params) for scheduler_class, params in scheduler_details.get('schedulers')]
    scheduler = lr_scheduler.SequentialLR(optimizer, schedulers=schedulers_list, milestones=scheduler_details.get('milestones'))

  loss_fn = kwargs.get('loss_fn')

  # Always defined, so later checks do not rely on short-circuiting around an unbound name.
  early_stop = EarlyStopping(patience=patience) if patience else None

  for epoch in range(1, kwargs.get('epochs')+1):
    train_loss = float(0)

    model.train(mode=True)
    for input_data, label in train_loader:
      input_data = input_data.to(device)
      label = label.to(device)

      optimizer.zero_grad()

      with torch.autocast(autocast_device, dtype=torch.bfloat16):
        pred_label = model(input_data)
        loss = loss_fn(pred_label, label)

      train_loss += loss.item()

      loss.backward()
      optimizer.step()

    if (epoch - 1) % 5 == 0:
      torch.save(model.state_dict(), f'{checkpoint_dir}/epoch_{epoch}.pth')


    model.train(mode=False)
    val_loss = float(0)
    with torch.no_grad():
      for input_data, label in val_loader:
        input_data_val = input_data.to(device)
        label_val = label.to(device)

        pred_label_val = model(input_data_val)
        val_loss += loss_fn(pred_label_val, label_val).item()

    # Average the summed batch losses.
    train_loss /= len(train_loader)
    val_loss /= len(val_loader)

    if early_stop is not None:
      es_result = early_stop(val_loss, model)
      if es_result:
        torch.save(es_result, f'{checkpoint_dir}/BEST_epoch_{epoch-patience}.pth')
        break

    if scheduler:
      scheduler.step()

    print(f'Epoch {epoch}:\n Train Loss: {train_loss}\n Val Loss: {val_loss}')
  else:
    # All epochs ran without early stopping: still persist the best weights seen.
    # `cur_epoch` counts epochs since the last improvement, which locates the best epoch.
    if early_stop is not None and early_stop.best_model is not None:
      torch.save(early_stop.best_model, f'{checkpoint_dir}/BEST_epoch_{epoch - early_stop.cur_epoch}.pth')

## Main:

In [12]:
# Show the dataset folder (relative to the Drive mount) that the run will read from.
print(config_vars.get('drive_path'))

MyDrive/Semantic Segmentation/data/new_natural_ecosystems


In [13]:
def training_pipeline(**kwargs):
  """Seed the random number generators, build the data loaders and train the model.

  Expected kwargs:
    seed: Optional integer seed (default 0). Pass ``None`` to skip seeding.
    All remaining keys are forwarded to ``create_dataset`` and ``train_model``.

  Note: seeding happens here, so anything created before this call (for example
  a model instance already stored in the kwargs) is not covered by it.
  """
  import random

  # Set seeds for reproducibility.
  seed = kwargs.get('seed', 0)
  if seed is not None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

  train_data, val_data, test_data = create_dataset(**kwargs)
  train_model(train_data, val_data, **kwargs)

In [None]:
# Entry point: run the full pipeline with the notebook-level configuration.
if __name__ == '__main__':
  training_pipeline(**config_vars)

Train Copied
Validation Copied
Test Copied
Epoch 1:
 Train Loss: 1.8568679365781273
 Val Loss: 1.975659940553748
Epoch 2:
 Train Loss: 0.9423453129164063
 Val Loss: 0.8115009833937106
Epoch 3:
 Train Loss: 0.7131560102547749
 Val Loss: 0.7660000751847806
Epoch 4:
 Train Loss: 0.6290817803675586
 Val Loss: 0.7815427650576052




Epoch 5:
 Train Loss: 0.5819325718549219
 Val Loss: 0.8112700095643168
