In [None]:
! pip install visdom==0.1.7 wandb rasterio
!pip install einops
!pip install einsum

In [None]:
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import autograd
import torch.optim as optim

import os
from glob import glob
from datetime import datetime
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.model_selection import GroupShuffleSplit
from matplotlib import pyplot as plt
plt.switch_backend('agg')

import visdom
from torchvision.utils import save_image, make_grid
import datetime
import pickle
from tqdm import tqdm
import itertools
import wandb
import math
from einops import rearrange

import argparse
import pandas as pd
import time
import random
from datetime import datetime
import rasterio
from pathlib import Path

from PIL import Image
import torchvision.transforms as transforms
import torch.nn.utils.rnn as rnn
import json
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
from torchvision import models
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("Device: ", DEVICE)

In [None]:
# BAND STATS

# Band-name -> channel-index lookup, keyed by satellite. For 's2' and 'planet'
# there is an extra level keyed by the number of bands ('10' / '4').
# NOTE(review): BANDS is not referenced by any cell visible in this notebook.
BANDS = { 's1': { 'VV': 0, 'VH': 1, 'RATIO': 2},
          's2': { '10': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'RDED1': 3, 'RDED2': 4, 'RDED3': 5, 'NIR': 6, 'RDED4': 7, 'SWIR1': 8, 'SWIR2': 9},
                   '4': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'NIR': 3}},
          'planet': { '4': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'NIR': 3}}}

# Per-band means, keyed by satellite and then by country.
# CropTypeMappingDataset.normalization() reads MEANS[satellite]['ghana'] and
# subtracts the first `num_bands` entries from the grid.
# NOTE(review): 's2_cldfltr' is presumably the cloud-filtered Sentinel-2
# statistics -- confirm; normalization() rejects that key.
MEANS = { 's1': { 'ghana': torch.Tensor([-10.50, -17.24, 1.17]),
                  'southsudan': torch.Tensor([-9.02, -15.26, 1.15])},
          's2': { 'ghana': torch.Tensor([2620.00, 2519.89, 2630.31, 2739.81, 3225.22, 3562.64, 3356.57, 3788.05, 2915.40, 2102.65]),
                  'southsudan': torch.Tensor([2119.15, 2061.95, 2127.71, 2277.60, 2784.21, 3088.40, 2939.33, 3308.03, 2597.14, 1834.81])},
          'planet': { 'ghana': torch.Tensor([1264.81, 1255.25, 1271.10, 2033.22]),
                      'southsudan': torch.Tensor([1091.30, 1092.23, 1029.28, 2137.77])},
          's2_cldfltr': { 'ghana': torch.Tensor([1362.68, 1317.62, 1410.74, 1580.05, 2066.06, 2373.60, 2254.70, 2629.11, 2597.50, 1818.43]),
                  'southsudan': torch.Tensor([1137.58, 1127.62, 1173.28, 1341.70, 1877.70, 2180.27, 2072.11, 2427.68, 2308.98, 1544.26])} }

# Per-band standard deviations with the same layout as MEANS;
# normalization() divides by the first `num_bands` entries.
STDS = { 's1': { 'ghana': torch.Tensor([3.57, 4.86, 5.60]),
                 'southsudan': torch.Tensor([4.49, 6.68, 21.75])},
         's2': { 'ghana': torch.Tensor([2171.62, 2085.69, 2174.37, 2084.56, 2058.97, 2117.31, 1988.70, 2099.78, 1209.48, 918.19]),
                 'southsudan': torch.Tensor([2113.41, 2026.64, 2126.10, 2093.35, 2066.81, 2114.85, 2049.70, 2111.51, 1320.97, 1029.58])},
         'planet': { 'ghana': torch.Tensor([602.51, 598.66, 637.06, 966.27]),
                     'southsudan': torch.Tensor([526.06, 517.05, 543.74, 1022.14])},
         's2_cldfltr': { 'ghana': torch.Tensor([511.19, 495.87, 591.44, 590.27, 745.81, 882.05, 811.14, 959.09, 964.64, 809.53]),
                 'southsudan': torch.Tensor([548.64, 547.45, 660.28, 677.55, 896.28, 1066.91, 1006.01, 1173.19, 1167.74, 865.42])} }

# OTHER PER COUNTRY CONSTANTS
# NOTE(review): none of the three constants below is referenced by the cells
# visible in this notebook; the training code hard-codes 5 classes instead.
NUM_CLASSES = { 'ghana': 4,
                'southsudan': 4}

GRID_SIZE = { 'ghana': 256,
              'southsudan': 256}


# Crop names per country.
# NOTE(review): presumably ordered to match label ids 1..4 -- confirm.
CROPS = { 'ghana': ['groundnut', 'maize', 'rice', 'soya bean'],
          'southsudan': ['sorghum', 'maize', 'rice', 'groundnut']}

In [None]:
class CropTypeMappingDataset():
    """Per-location crop-type segmentation samples for Ghana.

    Every sample is a pair ``(x, y)``:

    * ``x`` is a dict with the keys ``'s1'``, ``'s2'`` and ``'planet'``. Each
      value is a float32 tensor obtained by normalising a
      ``(bands, rows, cols, time)`` grid per band and averaging it over the
      acquisitions that fall inside the satellite's date window.
    * ``y`` is the label grid loaded from the ``truth`` directory, with every
      id greater than 4 mapped to 0.

    Args:
        data_dir: root directory of the dataset. Defaults to the Colab path
            the notebook has always used.
    """

    # Country sub-directory and file-name prefix used for every path below.
    COUNTRY = 'ghana'

    # Inclusive (first day, last day) acquisition window for each satellite.
    DATE_WINDOWS = {'s1': ('2016-05-01', '2016-11-30'),
                    's2': ('2016-05-01', '2016-11-30'),
                    'planet': ('2017-05-01', '2017-11-30')}

    def __init__(self, data_dir='/content/data/africa_crop_type_mapping'):

        self.data_dir = data_dir # 'gs://data_ctm/data/africa_crop_type_mapping'

        self.split_dict = {'train': 0, 'val': 1, 'test': 2}
        self.split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'}

        # Extract splits
        split_df = pd.read_csv(os.path.join(self.data_dir, self.COUNTRY, 'list_eval_partition.csv'))

        self.split_array = split_df['partition'].values

        # y_array stores idx ids corresponding to location. Actual y labels are
        # tensors that are loaded separately.
        self.y_array = torch.from_numpy(split_df['id'].values)

        self.y_size = (64, 64)

    def __len__(self):
        """Number of locations listed in the partition file."""
        return len(self.y_array)

    def __getitem__(self, idx):
        # Any transformations are handled by the SustainBenchSubset
        # since different subsets (e.g., train vs test) might have different transforms
        x = self.get_input(idx)
        y = self.get_label(idx)
        return x, y

    def get_input(self, idx):
        """
        Returns X for a given idx.

        Returns:
          dict with float32 tensors under 's1', 's2' and 'planet', each being
          the normalised grid averaged over its in-window acquisitions.
        """
        loc_id = f'{self.y_array[idx]:06d}'

        npz_path = os.path.join(self.data_dir, self.COUNTRY, 'npy', f'{self.COUNTRY}_{loc_id}.npz')
        # NpzFile keeps the archive open until closed; close it explicitly so
        # DataLoader workers do not accumulate open file handles.
        with np.load(npz_path) as images:
            s1 = torch.from_numpy(images['s1'])
            s2 = torch.from_numpy(images['s2'].astype(np.int32))
            planet = torch.from_numpy(images['planet'].astype(np.int32))

        dates_idx = self.get_dates(loc_id)

        # CenterCrop works on the last two axes, so move time to the front,
        # crop rows/cols to 128x128 and restore the original axis order.
        planet = planet.permute(3, 0, 1, 2)
        planet = transforms.CenterCrop(128)(planet)
        planet = planet.permute(1, 2, 3, 0)

        # Normalization
        s1 = self.normalization(s1, 's1')
        s2 = self.normalization(s2, 's2')
        planet = self.normalization(planet, 'planet')

        # get_dates() returns the index of the LAST in-window acquisition, so
        # the slice end must be `max + 1`; slicing to `max` silently dropped
        # that acquisition and produced an empty slice (NaN mean) whenever a
        # single acquisition fell inside the window.
        s1 = s1[:, :, :, dates_idx["s1_min"]:dates_idx["s1_max"] + 1]
        s2 = s2[:, :, :, dates_idx["s2_min"]:dates_idx["s2_max"] + 1]
        planet = planet[:, :, :, dates_idx["planet_min"]:dates_idx["planet_max"] + 1]

        # Collapse the time axis.
        planet = torch.mean(planet, axis=-1)
        s1 = torch.mean(s1, axis=-1)
        s2 = torch.mean(s2, axis=-1)

        # .to() casts without the copy-construct warning torch.tensor(tensor) emits.
        return {'s1': s1.to(torch.float32), 's2': s2.to(torch.float32), 'planet': planet.to(torch.float32)}

    def get_label(self, idx):
        """
        Returns y for a given idx.
        """
        loc_id = f'{self.y_array[idx]:06d}'
        truth_path = os.path.join(self.data_dir, self.COUNTRY, 'truth', f'{self.COUNTRY}_{loc_id}.npz')
        with np.load(truth_path) as archive:
            label = torch.from_numpy(archive['truth'])
        # Ids above 4 are folded into 0.
        label[label>4]=0
        return label

    def _date_index_range(self, satellite, loc_id):
        """Return (first, last) indices of the acquisitions inside the
        satellite's date window, both inclusive.

        Raises:
          IndexError: when no acquisition falls inside the window.
        """
        json_path = os.path.join(self.data_dir, self.COUNTRY, satellite,
                                 f"{satellite}_{self.COUNTRY}_{loc_id}.json")
        with open(json_path, 'r') as handle:
            raw_dates = json.load(handle)['dates']

        dates = np.array([datetime.strptime(date, "%Y-%m-%d") for date in raw_dates], dtype=object)
        window_start, window_end = (datetime.strptime(day, "%Y-%m-%d")
                                    for day in self.DATE_WINDOWS[satellite])

        after_start = dates[dates >= window_start]
        in_window = after_start[after_start <= window_end]
        if in_window.size == 0:
            raise IndexError(
                f"no {satellite} acquisition between {self.DATE_WINDOWS[satellite][0]} and "
                f"{self.DATE_WINDOWS[satellite][1]} for location {loc_id}")

        idx_min = np.where(dates == in_window[0])[0][0]
        idx_max = np.where(dates == in_window[-1])[0][0]
        return idx_min, idx_max

    def get_dates(self, loc_id):
        """
        Reads the per-satellite date lists and returns, for each satellite,
        the index of the first and of the last acquisition inside its date
        window (see DATE_WINDOWS). Both indices are inclusive.

        Returns:
          dict with keys "s1_min", "s1_max", "s2_min", "s2_max",
          "planet_min", "planet_max".
        """
        s1_idx_min, s1_idx_max = self._date_index_range('s1', loc_id)
        s2_idx_min, s2_idx_max = self._date_index_range('s2', loc_id)
        planet_idx_min, planet_idx_max = self._date_index_range('planet', loc_id)

        return {"s1_min":s1_idx_min, "s1_max":s1_idx_max, "s2_min":s2_idx_min,"s2_max":s2_idx_max,"planet_min":planet_idx_min,"planet_max":planet_idx_max}

    def normalization(self, grid, satellite):
        """ Normalization based on the module-level MEANS / STDS tables
        Args:
          grid - (tensor) grid to be normalized, bands on the first axis
          satellite - (str) describes source that grid is from ("s1", "s2" or "planet")
        Returns:
          grid - (tensor) a normalized version of the input grid
        Raises:
          ValueError - if satellite is not one of "s1", "s2", "planet"
        """
        # Validate before touching the tables: previously an unknown key
        # surfaced as a KeyError and this check was only reached afterwards.
        if satellite not in ['s1', 's2', 'planet']:
            raise ValueError("Incorrect normalization parameters")

        num_bands = grid.shape[0]
        means = MEANS[satellite][self.COUNTRY][:num_bands].reshape(num_bands, 1, 1, 1)
        stds = STDS[satellite][self.COUNTRY][:num_bands].reshape(num_bands, 1, 1, 1)
        return (grid - means) / stds

In [None]:
class SustainBenchSubset(CropTypeMappingDataset):
    """View onto the samples of one split ('train', 'val' or 'test') of a dataset."""

    # Metadata taken over from the wrapped dataset so that the subset
    # describes the same data the wrapped dataset actually serves.
    _SHARED_ATTRS = ('data_dir', 'split_dict', 'split_names', 'split_array', 'y_array', 'y_size')

    def __init__(self, dataset, split, transform=None):
        """
        This acts like torch.utils.data.Subset, but on SustainBenchDatasets.
        We pass in transform explicitly because it can potentially vary at
        training vs. test time, if we're using data augmentation.

        Args:
          dataset - dataset to draw samples from; indexed with absolute ids
          split - key of split_dict selecting the partition ('train', 'val', 'test')
          transform - optional callable applied to x only
        """
        self.dataset = dataset
        self.transform = transform

        # Reuse the wrapped dataset's split metadata instead of re-reading the
        # partition CSV from disk for every subset. Fall back to the parent
        # initialiser only when the wrapped object does not carry it.
        try:
            shared = {name: getattr(dataset, name) for name in self._SHARED_ATTRS}
        except AttributeError:
            super().__init__()
        else:
            self.__dict__.update(shared)

        split_mask = self.split_array == self.split_dict[split]
        self.indices = np.where(split_mask)[0]

    def __getitem__(self, idx):
        # Translate the subset-relative position into the dataset's own index.
        x, y = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.indices)

In [None]:
# Build the full dataset once, then create one view per split on top of it.
dataset = CropTypeMappingDataset()

train_dataset, val_dataset = (
    SustainBenchSubset(dataset, split_name) for split_name in ('train', 'val')
)


In [None]:
# Settings shared by both loaders; only the shuffling differs.
_loader_settings = dict(sampler=None, num_workers=2, batch_size=9)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_settings)

# Evaluation batches keep the dataset order.
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_settings)

In [None]:
# Smoke test: fetch a single training batch and show the planet tensor's shape.
for x, y in itertools.islice(train_loader, 1):
  print(x['planet'].shape)

## 2D_UNET Model

In [None]:
def conv(ch_in, ch_out, kernel_size=3,stride=1, padding=1):
  """Build a three-stage convolution block.

  Each stage is Conv2d -> BatchNorm2d -> LeakyReLU(inplace). The channel
  counts of the three stages are ch_in->ch_in, ch_in->ch_out, ch_out->ch_out;
  all three convolutions share kernel_size, stride and padding.

  Returns:
    nn.Sequential holding the nine layers in order.
  """
  shared = dict(kernel_size=kernel_size, stride=stride, padding=padding)
  channel_plan = ((ch_in, ch_in), (ch_in, ch_out), (ch_out, ch_out))

  stages = []
  for width_in, width_out in channel_plan:
    stages += [
        nn.Conv2d(in_channels=width_in, out_channels=width_out, **shared),
        nn.BatchNorm2d(width_out),
        nn.LeakyReLU(inplace=True),
    ]
  return nn.Sequential(*stages)

def upconv(ch_in, ch_out, kernel_size=3,stride=1, padding=1):
  """Build a single Conv2d -> BatchNorm2d -> LeakyReLU(inplace) stage that
  maps ch_in channels to ch_out channels.

  Returns:
    nn.Sequential holding the three layers in order.
  """
  projection = nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding)
  norm = nn.BatchNorm2d(ch_out)
  activation = nn.LeakyReLU(inplace=True)
  return nn.Sequential(projection, norm, activation)

In [None]:
class UNet_planet(nn.Module):
    """Three-branch 2D U-Net over Sentinel-2, Sentinel-1 and Planet imagery.

    Each branch is an encoder/decoder with skip connections that ends in a
    5-channel map. The three maps are concatenated (15 channels) and a final
    stride-2 transposed convolution turns them into 5-channel per-pixel
    log-probabilities.

    forward() expects a dict with the keys 's2' (10 channels), 's1'
    (3 channels) and 'planet' (4 channels). The planet grid must have twice
    the spatial size of the other two so that the three branch outputs line
    up; the output has the spatial size of the 's1'/'s2' input.
    """

    def __init__(self):
        super().__init__()
        # Encoder S2
        self.convs20 = conv(ch_in=10, ch_out=64)
        self.maxpool20 = nn.MaxPool2d(kernel_size=2, stride=2,padding=0)
        self.convs21_1_=conv(ch_in=64, ch_out=64)
        self.convs21_2=conv(ch_in=64, ch_out=64)

        self.maxpool21 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convs22_1=conv(ch_in=64, ch_out=128)
        self.convs22_2=conv(ch_in=128, ch_out=128)

        self.maxpool22 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convs23=conv(ch_in=128, ch_out=256)
        # decoder S2
        self.upconvs23 = upconv(ch_in=256, ch_out=256)
        self.transConvs23 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        self.upconvs22_2=upconv(ch_in=256, ch_out=128) # 128+128, 128
        self.upconvs22_1=upconv(ch_in=128, ch_out=128)
        self.transConvs22 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        self.upconvs21_2=upconv(ch_in=128, ch_out=64) # 64+64, 64
        self.upconvs21_1=upconv(ch_in=64, ch_out=32)
        self.upconvs2_nc=upconv(ch_in=32, ch_out=5)

        # Encoder S1
        self.convs10 = conv(ch_in=3, ch_out=64, kernel_size=3,stride=1, padding=1)
        self.maxpool10 = nn.MaxPool2d(kernel_size=2, stride=2,padding=0)
        self.convs11_1_=conv(ch_in=64, ch_out=64, kernel_size=3,stride=1, padding=1)
        self.convs11_2=conv(ch_in=64, ch_out=64, kernel_size=3,stride=1, padding=1)

        self.maxpool11 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convs12_1=conv(ch_in=64, ch_out=128, kernel_size=3,stride=1, padding=1)
        self.convs12_2=conv(ch_in=128, ch_out=128, kernel_size=3,stride=1, padding=1)

        self.maxpool12 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convs13=conv(ch_in=128, ch_out=256, kernel_size=3,stride=1, padding=1)

        # decoder S1
        self.upconvs13 = upconv(ch_in=256, ch_out=256)
        self.transConvs13 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        self.upconvs12_2=upconv(ch_in=256, ch_out=128) # 128+128, 128
        self.upconvs12_1=upconv(ch_in=128, ch_out=128)
        self.transConvs12 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        self.upconvs11_2=upconv(ch_in=128, ch_out=64) # 64+64, 64
        self.upconvs11_1=upconv(ch_in=64, ch_out=32)
        self.upconvs1_nc=upconv(ch_in=32, ch_out=5)

        # Encoder planet
        self.convsp1_1=conv(ch_in=4, ch_out=16, kernel_size=3,stride=1, padding=1)
        self.convsp1_2=conv(ch_in=16, ch_out=16, kernel_size=3,stride=1, padding=1)
        self.maxpoolp1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.maxpoolp1_1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.convsp2_1=conv(ch_in=16, ch_out=32, kernel_size=3,stride=1, padding=1)
        self.convsp2_2=conv(ch_in=32, ch_out=32, kernel_size=3,stride=1, padding=1)
        self.maxpoolp2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.convsp3_1=conv(ch_in=32, ch_out=64, kernel_size=3,stride=1, padding=1)
        self.convsp3_2=conv(ch_in=64, ch_out=64, kernel_size=3,stride=1, padding=1)

        self.maxpoolp3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convsp4_1=conv(ch_in=64, ch_out=128, kernel_size=3,stride=1, padding=1)
        self.convsp4_2=conv(ch_in=128, ch_out=128, kernel_size=3,stride=1, padding=1)

        self.maxpoolp4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.convsp5=conv(ch_in=128, ch_out=256, kernel_size=3,stride=1, padding=1)

        # Decoder planet
        self.upconvsp3 = upconv(ch_in=256, ch_out=256)
        self.transConvsp3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        self.upconvsp2_2=upconv(ch_in=256, ch_out=128) # 128+128, 128
        self.upconvsp2_1=upconv(ch_in=128, ch_out=128)
        self.transConvsp2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        # 64 (decoder) + 64 (layerp5) + 16 (maxp1_1) + 32 (maxp2) = 176
        self.upconvsp1_2=upconv(ch_in=176, ch_out=64) # 112+64, 64
        self.upconvsp1_1=upconv(ch_in=64, ch_out=32)
        self.upconvsp_nc=upconv(ch_in=32, ch_out=5)

        # Fuses the three 5-channel branch outputs and doubles the resolution.
        self.final=nn.ConvTranspose2d(in_channels=15, out_channels=5, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

        # Despite the attribute name this is LogSoftmax: forward() returns
        # log-probabilities, which is what nn.NLLLoss expects.
        self.softmax = nn.LogSoftmax(dim=1)

    @staticmethod
    def _prepare(tensor, device):
        """Move `tensor` to `device` and return a copy with NaNs set to 0.

        The caller's tensor is left untouched.
        """
        tensor = tensor.to(device)
        return torch.where(torch.isnan(tensor), torch.zeros_like(tensor), tensor)

    def forward(self, X):
        """Run the three branches and fuse them.

        Args:
          X - dict with tensors under 's2', 's1' and 'planet'
        Returns:
          tensor of per-pixel log-probabilities with 5 channels on dim 1
        """
        # Follow the model's own device instead of assuming CUDA, so the
        # network also runs on CPU-only machines.
        device = next(self.parameters()).device
        s2_in = self._prepare(X['s2'], device)
        s1_in = self._prepare(X['s1'], device)
        planet_in = self._prepare(X['planet'], device)

        # S2
        # encoder
        layer02 = self.convs20(s2_in)
        layer12 = self.convs21_1_(self.maxpool20(layer02))
        layer22 = self.convs21_2(layer12)
        layer32 = self.convs22_1(self.maxpool21(layer22))
        layer42 = self.convs22_2(layer32)
        layer52 = self.convs23(self.maxpool22(layer42))
        # decoder
        x2 = self.upconvs23(layer52)
        x2 = self.transConvs23(x2)
        x2 = torch.cat([x2, layer42], dim=1)
        x2 = self.upconvs22_2(x2)
        x2 = self.upconvs22_1(x2)
        x2 = self.transConvs22(x2)
        x2 = torch.cat([x2, layer22], dim=1)
        x2 = self.upconvs21_2(x2)
        x2 = self.upconvs21_1(x2)
        x2f = self.upconvs2_nc(x2)

        # S1
        # encoder
        layer01 = self.convs10(s1_in)
        layer11 = self.convs11_1_(self.maxpool10(layer01))
        layer21 = self.convs11_2(layer11)
        layer31 = self.convs12_1(self.maxpool11(layer21))
        layer41 = self.convs12_2(layer31)
        layer51 = self.convs13(self.maxpool12(layer41))
        # decoder
        x1 = self.upconvs13(layer51)
        x1 = self.transConvs13(x1)
        x1 = torch.cat([x1, layer41], dim=1)
        x1 = self.upconvs12_2(x1)
        x1 = self.upconvs12_1(x1)
        x1 = self.transConvs12(x1)
        x1 = torch.cat([x1, layer21], dim=1)
        x1 = self.upconvs11_2(x1)
        x1 = self.upconvs11_1(x1)
        x1f = self.upconvs1_nc(x1)

        # Planet
        # encoder
        layerp0 = self.convsp1_1(planet_in)
        layerp1 = self.convsp1_2(layerp0)
        maxp1 = self.maxpoolp1(layerp1)
        # Second pooling of the shallow features, used as an extra skip
        # connection in the decoder. Uses the dedicated (parameter-free)
        # layer that __init__ registers for it.
        maxp1_1 = self.maxpoolp1_1(maxp1)

        layerp2 = self.convsp2_1(maxp1)
        layerp3 = self.convsp2_2(layerp2)
        maxp2 = self.maxpoolp2(layerp3)

        layerp4 = self.convsp3_1(maxp2)
        layerp5 = self.convsp3_2(layerp4)

        layerp6 = self.convsp4_1(self.maxpoolp3(layerp5))
        layerp7 = self.convsp4_2(layerp6)

        layerp8 = self.convsp5(self.maxpoolp4(layerp7))

        # decoder
        x = self.upconvsp3(layerp8)
        x = self.transConvsp3(x)
        x = torch.cat([x, layerp7], dim=1)
        x = self.upconvsp2_2(x)
        x = self.upconvsp2_1(x)
        x = self.transConvsp2(x)
        x = torch.cat([x, layerp5, maxp1_1, maxp2], dim=1)
        x = self.upconvsp1_2(x)
        x = self.upconvsp1_1(x)
        xpf = self.upconvsp_nc(x)

        out = self.final(torch.cat([x1f, x2f, xpf], dim=1))

        return self.softmax(out)

In [None]:
# Model object
# Place the model on the device selected at the top of the notebook instead of
# assuming that a GPU is present.
model_2d_unet = UNet_planet().to(DEVICE)

In [None]:
# Per-class loss weights for the 5 output classes, computed as one minus the
# listed value per class.
# NOTE(review): the listed values are presumably class frequencies -- confirm.
loss_weights = 1 - np.array([.85, .17, .56, .16, .11])
# Follow DEVICE rather than calling .cuda(), which fails without a GPU.
loss_weights = torch.tensor(loss_weights, dtype=torch.float32).to(DEVICE)

# Hyper-parameters shared by the optimiser, the loss and the training loop.
config = {
    "epochs"           : 100,
    "lr"               : 0.001,
    "label_smoothing"  : 0.2,
    "momentum"         : 0.9,
    "weight_decay"     : 0.0001,
    "loss_weights"     : loss_weights
}



#### Training

In [None]:
# SECURITY: never commit a W&B API key to the notebook. The key is read from
# the WANDB_API_KEY environment variable; when it is unset, wandb falls back
# to its own credential lookup / interactive prompt.
# NOTE(review): the key that used to be hard-coded here must be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Settings for the W&B run.
# To resume an earlier run instead, add "id" (the run id) and "resume": "must"
# and drop "reinit".
wandb_run_settings = {
    "name": "UNET_2D-Ghana",              # W&B picks a random run name when omitted
    "reinit": True,                       # allows re-running this cell
    "project": "Crop-type-segmentation",  # must exist in your W&B account
}
run = wandb.init(**wandb_run_settings)

In [None]:
# Run a full garbage-collection pass; the call returns the number of
# unreachable objects found.
# NOTE(review): `gc` is imported here rather than in the notebook's import cell.
import gc
gc.collect()

In [None]:
# metrics, train and validation functions
def reshapeForLoss(y):
    """ Reshapes labels or preds for loss fn.
    To get them to the correct shape, we permute:
      [batch x classes x rows x cols] --> [batch x rows x cols x classes]
      and then reshape to [N x classes], where N = batch*rows*cols
    """
    num_classes = y.shape[1]
    # reshape() copies when the permuted tensor is not contiguous, which is
    # exactly what .contiguous().view() did.
    return y.permute(0, 2, 3, 1).reshape(-1, num_classes)

def mask_ce_loss(y_true, y_pred, weight_scale=1, class_weights=None):
    """Class-weighted negative log-likelihood averaged over labelled pixels.

    Args:
      y_true - one-hot labels, [batch x classes x rows x cols]; a pixel whose
               class vector is all zeros counts as unlabelled and is ignored
      y_pred - log-probabilities with the same layout as y_true
      weight_scale - exponent applied to the class weights
      class_weights - (tensor, optional) one weight per class; defaults to
               config["loss_weights"]
    Returns:
      scalar tensor: sum of the weighted per-pixel NLL over the labelled
      pixels divided by the number of labelled pixels. A zero scalar is
      returned (after printing a warning) when no pixel is labelled.

    Fixes over the previous version:
      * the masked predictions were computed but the unmasked ones were
        passed to the loss, so unlabelled pixels were trained as class 0;
      * NLLLoss used its default mean reduction and the result was divided
        by the pixel count again, shrinking loss and gradients by that factor;
      * None was returned for an empty batch, which crashed the caller;
      * tensors were forced onto CUDA regardless of where y_pred lives.
    """
    if class_weights is None:
        class_weights = config["loss_weights"]

    y_pred = reshapeForLoss(y_pred)
    y_true = reshapeForLoss(y_true).to(y_pred.device)

    # One-hot rows sum to 1 for labelled pixels and to 0 for unlabelled ones.
    num_examples = torch.sum(y_true, dtype=torch.float32)
    loss_mask = torch.sum(y_true, dim=1).long()

    if num_examples == 0:
        print("WARNING: NUMBER OF EXAMPLES IS 0")
        # Zero loss that still supports .item() and, in training, .backward().
        return torch.zeros((), dtype=y_pred.dtype, device=y_pred.device,
                           requires_grad=torch.is_grad_enabled())

    _, targets = torch.max(y_true, dim=1)
    targets = targets * loss_mask
    # Zeroing the log-probabilities of unlabelled pixels makes their NLL
    # contribution exactly 0.
    masked_pred = y_pred * loss_mask.unsqueeze(1).to(y_pred.dtype)

    weights = (class_weights ** weight_scale).to(device=y_pred.device, dtype=y_pred.dtype)
    loss_fn = nn.NLLLoss(weight=weights, reduction="sum")

    return loss_fn(masked_pred, targets) / num_examples


In [None]:
# loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"], weight=config["loss_weights"])
# SGD with momentum and weight decay, all taken from `config`.
optimizer = torch.optim.SGD(model_2d_unet.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"])

# One cosine cycle over the whole run: T_max follows config["epochs"] so the
# schedule stays correct when the number of epochs is changed.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=0.00005)


In [None]:
# metrics, train and validation functions

# Metrics
def crop_segmentation_metrics(y_true, y_pred):
    """Macro F1, accuracy and confusion matrix over the labelled pixels.

    Both inputs are [batch x classes x rows x cols]; y_true is one-hot and
    y_pred holds per-class scores. Only pixels whose one-hot vector sums to 1
    are scored.

    Returns:
      (f1, acc, cm) as computed by sklearn's f1_score (macro average),
      accuracy_score and confusion_matrix.
    """
    flat_true = reshapeForLoss(y_true.cpu())
    flat_pred = reshapeForLoss(y_pred.cpu())

    # True for every pixel that carries exactly one label.
    labelled = torch.sum(flat_true, dim=1).long() == 1

    true_ids = torch.max(flat_true, dim=1)[1][labelled]
    pred_ids = torch.max(flat_pred, dim=1)[1][labelled]

    assert (true_ids.shape == pred_ids.shape)
    true_ids, pred_ids = true_ids.int(), pred_ids.int()

    return (
        f1_score(true_ids, pred_ids, average='macro'),
        accuracy_score(true_ids, pred_ids),
        confusion_matrix(true_ids, pred_ids),
    )

# Train
def train_step(data_loader, optimizer, accuracy_fn):
  """Train model_2d_unet for one pass over data_loader.

  Args:
    data_loader - yields (img, label) batches
    optimizer - optimiser updating model_2d_unet's parameters
    accuracy_fn - callable(y_true=..., y_pred=...) returning (f1, acc, ...)
  Returns:
    (train_loss, train_acc, train_f1, curr_lr): per-batch means of loss,
    accuracy and F1, and the learning rate in effect during the pass.
  """

  curr_lr = float(optimizer.param_groups[0]['lr'])
  model_2d_unet.train() # Put model into training mode

  train_loss, train_acc, train_f1=0,0,0
  for img,label in tqdm(data_loader):

    # 1. Forward pass (per-pixel log-probabilities)
    y_pred=model_2d_unet(img)

    # 2. Calculate loss and metrics (per batch)
    label = label.long()
    label = torch.nn.functional.one_hot(label, num_classes=5).permute([0,3,1,2])

    loss=mask_ce_loss(label, y_pred)
    if loss is None:
      # mask_ce_loss found no labelled pixel: nothing to learn from here.
      continue

    train_loss+=loss.item() #  accumulate training loss

    # The metrics are expensive: compute them once per batch, on a tensor
    # detached from the autograd graph.
    batch_f1, batch_acc = accuracy_fn(y_true=label, y_pred=y_pred.detach())[:2]
    train_acc += batch_acc
    train_f1 += batch_f1

    # 3. Reset gradients
    optimizer.zero_grad()

    # 4. Loss backward
    loss.backward()

    # 5. Optimizer step
    optimizer.step()

  # Average over the batches; max() keeps an empty loader from dividing by zero.
  num_batches = max(len(data_loader), 1)
  train_loss/=num_batches
  train_acc/=num_batches
  train_f1/=num_batches
  print(f"Train loss {train_loss: .5f}|Train acc : {train_acc:.5f} | Train f1: {train_f1:.5f} | lr: {curr_lr}\n")
  return train_loss, train_acc, train_f1, curr_lr

# Validation
def validation_step(data_loader, accuracy_fn):
  """Evaluate model_2d_unet on data_loader without updating it.

  Args:
    data_loader - yields (img, label) batches
    accuracy_fn - callable(y_true=..., y_pred=...) returning (f1, acc, ...)
  Returns:
    (val_loss, val_acc, val_f1): per-batch means as plain Python floats.
  """

  val_loss,val_acc, val_f1=0,0,0

  model_2d_unet.eval() # put the model in eval mode
  # turn on inference mode context manager
  with torch.inference_mode():
    for img,label in tqdm(data_loader):

      # 1. Forward pass (per-pixel log-probabilities)
      val_pred=model_2d_unet(img)

      # 2. Calculate the loss/metrics
      label = label.long()
      label = torch.nn.functional.one_hot(label, num_classes=5).permute([0,3,1,2])
      batch_loss = mask_ce_loss(label, val_pred)
      if batch_loss is not None:
        # float() keeps the running total a Python number (as in train_step)
        # instead of a device tensor that leaks into the logs.
        val_loss += float(batch_loss)

      # The metrics are expensive: compute them once per batch.
      batch_f1, batch_acc = accuracy_fn(y_true=label, y_pred=val_pred)[:2]
      val_acc += batch_acc
      val_f1 += batch_f1

    # Average over the batches; max() keeps an empty loader from dividing by zero.
    num_batches = max(len(data_loader), 1)
    val_loss/=num_batches
    val_acc/=num_batches
    val_f1/=num_batches
    print(f"Val loss: {val_loss:.5f} | Val acc: {val_acc:.5f} | Val f1: {val_f1:.5f}\n")

    return val_loss, val_acc, val_f1

In [None]:

# Let's train
epochs=config['epochs']

# Per-epoch history of the training and validation metrics.
train_loss_list=[]
train_acc_list=[]
train_f1_list=[]

val_loss_list=[]
val_acc_list=[]
val_f1_list=[]

# F1 is "higher is better": start below every possible score and keep the
# checkpoint only when the validation F1 improves. The previous version
# started at +inf, compared with `<` and never updated the threshold, so
# every epoch overwrote the "best" checkpoint.
best_val_f1 = float('-inf')

checkpoint_path = '/content/models/best.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(epochs):

  print(f"Epoch: {epoch}/{epochs} \n----------------")

  train_loss, train_acc, train_f1, curr_lr = train_step(
                                                data_loader=train_loader,
                                                optimizer=optimizer,
                                                accuracy_fn=crop_segmentation_metrics
                                                )
  train_loss_list.append(train_loss)
  train_acc_list.append(train_acc)
  train_f1_list.append(train_f1)

  val_loss, val_acc, val_f1 = validation_step(
                                            data_loader=val_loader,
                                            accuracy_fn=crop_segmentation_metrics,
                                            )
  # Make sure a plain number (not a device tensor) is stored and logged.
  val_loss = float(val_loss)
  val_loss_list.append(val_loss)
  val_acc_list.append(val_acc)
  val_f1_list.append(val_f1)

  wandb.log({"train_loss": train_loss, 'train_f1': train_f1, 'train_acc': train_acc, 'validation_f1':val_f1,
               'validation_loss': val_loss, 'validation_acc': val_acc, "learning_Rate": curr_lr})

  if val_f1 > best_val_f1:
      best_val_f1 = val_f1

      print("Saving model")
      # Single full checkpoint (the earlier bare state_dict save to the same
      # path was immediately overwritten by this one).
      torch.save({'model_state_dict':model_2d_unet.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'scheduler_state_dict':scheduler.state_dict(),
                'best_val_f1': best_val_f1,
                # Legacy key kept so existing checkpoint readers keep working;
                # it has always held the validation F1, not a loss.
                'best_val_loss': best_val_f1,
                'epoch': epoch}, checkpoint_path)

      wandb.save(checkpoint_path)

  scheduler.step()

run.finish()
