<a href="https://colab.research.google.com/github/Diane10/ethicsdataset/blob/main/IDL_Gustave.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install visdom==0.1.7 wandb rasterio
!pip install einops
!pip install einsum



In [2]:
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import autograd
import torch.optim as optim

import os
from glob import glob
from datetime import datetime
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.model_selection import GroupShuffleSplit
from matplotlib import pyplot as plt
plt.switch_backend('agg')

import visdom
from torchvision.utils import save_image, make_grid
import datetime
import pickle
from tqdm import tqdm
import itertools
import wandb
import math
from einops import rearrange

import argparse
import pandas as pd
import time
import random
from datetime import datetime
import rasterio
from pathlib import Path

from PIL import Image
import torchvision.transforms as transforms
import torch.nn.utils.rnn as rnn
import json
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
from torchvision import models
import warnings
warnings.filterwarnings('ignore')

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

Device:  cuda


In [4]:
# !gsutil ls gs://data_ctm/data/data/africa_crop_type_mapping/ghana/

In [5]:
# pd.read_csv('gs://data_ctm/data/data/africa_crop_type_mapping/ghana/.ipynb_checkpoints/list_eval_partition-checkpoint.csv')

In [6]:
# BAND STATS

BANDS = { 's1': { 'VV': 0, 'VH': 1, 'RATIO': 2},
          's2': { '10': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'RDED1': 3, 'RDED2': 4, 'RDED3': 5, 'NIR': 6, 'RDED4': 7, 'SWIR1': 8, 'SWIR2': 9},
                   '4': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'NIR': 3}},
          'planet': { '4': {'BLUE': 0, 'GREEN': 1, 'RED': 2, 'NIR': 3}}}

MEANS = { 's1': { 'ghana': torch.Tensor([-10.50, -17.24, 1.17]),
                  'southsudan': torch.Tensor([-9.02, -15.26, 1.15])},
          's2': { 'ghana': torch.Tensor([2620.00, 2519.89, 2630.31, 2739.81, 3225.22, 3562.64, 3356.57, 3788.05, 2915.40, 2102.65]),
                  'southsudan': torch.Tensor([2119.15, 2061.95, 2127.71, 2277.60, 2784.21, 3088.40, 2939.33, 3308.03, 2597.14, 1834.81])},
          'planet': { 'ghana': torch.Tensor([1264.81, 1255.25, 1271.10, 2033.22]),
                      'southsudan': torch.Tensor([1091.30, 1092.23, 1029.28, 2137.77])},
          's2_cldfltr': { 'ghana': torch.Tensor([1362.68, 1317.62, 1410.74, 1580.05, 2066.06, 2373.60, 2254.70, 2629.11, 2597.50, 1818.43]),
                  'southsudan': torch.Tensor([1137.58, 1127.62, 1173.28, 1341.70, 1877.70, 2180.27, 2072.11, 2427.68, 2308.98, 1544.26])} }

STDS = { 's1': { 'ghana': torch.Tensor([3.57, 4.86, 5.60]),
                 'southsudan': torch.Tensor([4.49, 6.68, 21.75])},
         's2': { 'ghana': torch.Tensor([2171.62, 2085.69, 2174.37, 2084.56, 2058.97, 2117.31, 1988.70, 2099.78, 1209.48, 918.19]),
                 'southsudan': torch.Tensor([2113.41, 2026.64, 2126.10, 2093.35, 2066.81, 2114.85, 2049.70, 2111.51, 1320.97, 1029.58])},
         'planet': { 'ghana': torch.Tensor([602.51, 598.66, 637.06, 966.27]),
                     'southsudan': torch.Tensor([526.06, 517.05, 543.74, 1022.14])},
         's2_cldfltr': { 'ghana': torch.Tensor([511.19, 495.87, 591.44, 590.27, 745.81, 882.05, 811.14, 959.09, 964.64, 809.53]),
                 'southsudan': torch.Tensor([548.64, 547.45, 660.28, 677.55, 896.28, 1066.91, 1006.01, 1173.19, 1167.74, 865.42])} }


CROPS = { 'ghana': ['groundnut', 'maize', 'rice', 'soya bean'],
          'southsudan': ['sorghum', 'maize', 'rice', 'groundnut']}

In [7]:
class CropTypeMappingDataset():

    def __init__(self, ):

        self.data_dir = '/content/data/africa_crop_type_mapping' # 'gs://data_ctm/data/africa_crop_type_mapping'

        self.split_dict = {'train': 0, 'val': 1, 'test': 2}
        self.split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'}

        # Extract splits
        split_df = pd.read_csv(os.path.join(self.data_dir, 'ghana', 'list_eval_partition.csv'))

        self.split_array = split_df['partition'].values


        # y_array stores idx ids corresponding to location. Actual y labels are
        # tensors that are loaded separately.
        self.y_array = torch.from_numpy(split_df['id'].values)

        self.y_size = (64, 64)

        self.metadata_fields = ['y']
        self.metadata_array = torch.from_numpy(split_df['id'].values)


    def __getitem__(self, idx):
        # Any transformations are handled by the SustainBenchSubset
        # since different subsets (e.g., train vs test) might have different transforms
        x = self.get_input(idx)
        y = self.get_label(idx)
        return x, y


    def get_input(self, idx):
        """
        Returns X for a given idx.
        """
        loc_id = f'{self.y_array[idx]:06d}'

        images = np.load(os.path.join(self.data_dir, 'ghana', 'npy', f'{"ghana"}_{loc_id}.npz'))
        dates_idx = self.get_dates(loc_id)

        s1 = images['s1']
        s2 = images['s2']
        planet = images['planet']

        s1 = torch.from_numpy(s1)
        s2 = torch.from_numpy(s2.astype(np.int32))
        planet = torch.from_numpy(planet.astype(np.int32))

        planet = planet.permute(3, 0, 1, 2)
        planet = transforms.CenterCrop(128)(planet)
        planet = planet.permute(1, 2, 3, 0)

        # Normalization
        s1 = self.normalization(s1, 's1')
        s2 = self.normalization(s2, 's2')
        planet = self.normalization(planet, 'planet')

        s1 = s1[:,:,:,dates_idx["s1_min"]:dates_idx["s1_max"]]
        s2 = s2[:,:,:,dates_idx["s2_min"]:dates_idx["s2_max"]]
        planet = planet[:,:,:,dates_idx["planet_min"]:dates_idx["planet_max"]]

        s1_d = random.randint(0, dates_idx["s1_max"] - dates_idx["s1_min"]-1)
        s2_d = random.randint(0, dates_idx["s2_max"] - dates_idx["s2_min"]-1)
        planet_d = random.randint(0, dates_idx["planet_max"] - dates_idx["planet_min"]-1)

        s1 = np.squeeze(s1[:,:,:,s1_d])
        s2 = np.squeeze(s2[:,:,:,s2_d])
        planet = np.squeeze(planet[:,:,:,planet_d])


        return {'s1': torch.tensor(s1, dtype=torch.float32), 's2': torch.tensor(s2, dtype=torch.float32), 'planet': torch.tensor(planet, dtype=torch.float32)}

    def get_label(self, idx):
        """
        Returns y for a given idx.
        """
        loc_id = f'{self.y_array[idx]:06d}'
        label = np.load(os.path.join(self.data_dir, 'ghana', 'truth', f'{"ghana"}_{loc_id}.npz'))['truth']
        label = torch.from_numpy(label)
        label[label>4]=0
        return label

    def get_dates(self, loc_id):
        """
        Converts json dates into tensor containing dates
        """
        s1_json = json.loads(open(os.path.join(self.data_dir, 'ghana', 's1', f"s1_{'ghana'}_{loc_id}.json"), 'r').read())
        s1 = s1_json['dates']

        s1 =np.array([datetime.strptime(date, "%Y-%m-%d") for date in s1])
        s1_date_low = s1[s1 >= datetime.strptime('2016-06-01', "%Y-%m-%d")]
        s1_date_high = s1_date_low[s1_date_low<= datetime.strptime('2016-11-30', "%Y-%m-%d")]
        s1_idx_min = np.where(s1==s1_date_high[0])
        s1_idx_max = np.where(s1==s1_date_high[-1])
        s1_idx_min = s1_idx_min[0][0]
        s1_idx_max = s1_idx_max[0][0]


        s2_json = json.loads(open(os.path.join(self.data_dir, 'ghana', 's2', f"s2_{'ghana'}_{loc_id}.json"), 'r').read())
        s2 = s2_json['dates']

        s2 =np.array([datetime.strptime(date, "%Y-%m-%d") for date in s2])
        s2_date_low = s2[s2 >= datetime.strptime('2016-07-01', "%Y-%m-%d")]
        s2_date_high = s2_date_low[s2_date_low<= datetime.strptime('2016-11-30', "%Y-%m-%d")]
        s2_idx_min = np.where(s2==s2_date_high[0])
        s2_idx_max = np.where(s2==s2_date_high[-1])
        s2_idx_min = s2_idx_min[0][0]
        s2_idx_max = s2_idx_max[0][0]

        planet_json = json.loads(open(os.path.join(self.data_dir, 'ghana', 'planet', f"planet_{'ghana'}_{loc_id}.json"), 'r').read())
        planet = planet_json['dates']

        planet =np.array([datetime.strptime(date, "%Y-%m-%d") for date in planet])
        planet_date_low = planet[planet >= datetime.strptime('2017-07-01', "%Y-%m-%d")]
        planet_date_high = planet_date_low[planet_date_low<= datetime.strptime('2017-11-30', "%Y-%m-%d")]
        planet_idx_min = np.where(planet==planet_date_high[0])
        planet_idx_max = np.where(planet==planet_date_high[-1])
        planet_idx_min = planet_idx_min[0][0]
        planet_idx_max = planet_idx_max[0][0]

        return {"s1_min":s1_idx_min, "s1_max":s1_idx_max, "s2_min":s2_idx_min,"s2_max":s2_idx_max,"planet_min":planet_idx_min,"planet_max":planet_idx_max}

    def normalization(self, grid, satellite):
        """ Normalization based on values defined in constants.py
        Args:
          grid - (tensor) grid to be normalized
          satellite - (str) describes source that grid is from ("s1" or "s2")
        Returns:
          grid - (tensor) a normalized version of the input grid
        """
        num_bands = grid.shape[0]
        means = MEANS[satellite]['ghana']
        stds = STDS[satellite]['ghana']
        grid = (grid-means[:num_bands].reshape(num_bands, 1, 1, 1))/stds[:num_bands].reshape(num_bands, 1, 1, 1)

        if satellite not in ['s1', 's2', 'planet']:
            raise ValueError("Incorrect normalization parameters")
        return grid

In [8]:
class SustainBenchSubset(CropTypeMappingDataset):
    def __init__(self, dataset, split, transform=None):
        """
        This acts like torch.utils.data.Subset, but on SustainBenchDatasets.
        We pass in transform explicitly because it can potentially vary at
        training vs. test time, if we're using data augmentation.

        """
        super().__init__()
        self.dataset = dataset
        self.transform = transform

        split_mask = self.split_array == self.split_dict[split]
        self.indices = np.where(split_mask)[0]

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.indices)

In [9]:
dataset = CropTypeMappingDataset()

train_dataset = SustainBenchSubset(dataset, 'train')
val_dataset = SustainBenchSubset(dataset, 'val')


In [10]:
train_loader = DataLoader(
                train_dataset,
                shuffle=True, # Shuffle training dataset
                sampler=None,
                batch_size=10)

val_loader = DataLoader(
                val_dataset,
                shuffle=False, # Do not shuffle eval datasets
                sampler=None,
                batch_size=10)

In [11]:
for x,y in train_loader:
  print(x['planet'].shape)
  break

torch.Size([10, 4, 128, 128])


##Model

In [12]:
class cyclicShift(nn.Module):
  def __init__(self, displacement):
    super().__init__()
    self.displacement = displacement

  def forward(self, x):
    return torch.roll(x, shifts=(self.displacement, self.displacement), dims = (1,2) )


class Residual(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    return self.fn(x,**kwargs) + x

class PreNorm(nn.Module):
  def __init__(self, dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(dim)
    self.fn = fn

  def forward(self, x, **kwargs):
    return self.norm(self.fn(x, **kwargs))

class FeedForward(nn.Module):
  def __init__(self,dim,hidden_dim):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim)
    )

  def forward(self, x):
    return self.net(x)

In [13]:
def create_mask(window_size, displacement, upper_lower,left_right):
  mask = torch.zeros(window_size**2, window_size**2)

  if upper_lower:
    mask[-displacement*window_size:,:-displacement*window_size]=float("-inf")
    mask[:-displacement*window_size,-displacement*window_size:]=float("-inf")

  if left_right:
    mask = rearrange(mask, "(h1 w1) (h2 w2) -> h1 w1 h2 w2", h1=window_size,h2=window_size)
    mask[:,-displacement:, :, :-displacement] = float("-inf")
    mask[:,:-displacement, :,-displacement:] = float("-inf")
    mask = rearrange(mask, "h1 w1 h2 w2 -> (h1 w1) (h2 w2)")

  return mask

In [14]:
class WindowAttension(nn.Module):
  def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
    super().__init__()
    inner_dim = head_dim * heads
    self.heads = heads
    self.scale = head_dim ** -0.5
    self.window_size = window_size
    self.relative_pos_embedding = relative_pos_embedding
    self.shifted = shifted

    if self.shifted:
      displacement = window_size //2
      self.cyclic_shift = cyclicShift(-displacement)
      self.cyclic_back_shift = cyclicShift(displacement)

      self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,\
                                                       upper_lower=True, left_right=False), requires_grad=False)

      self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,\
                                                       upper_lower=False, left_right=True), requires_grad=False)

    self.to_qkv = nn.Linear(dim, inner_dim*3,bias=False)
    self.pos_embedding = nn.Parameter(torch.randn(window_size**2,window_size**2))
    self.to_out = nn.Linear(inner_dim, dim)

  def forward(self,x):
    if self.shifted:
      x = self.cyclic_shift(x)

    b, n_h, n_w, _, h = *x.shape, self.heads
    qkv = self.to_qkv(x).chunk(3,dim=-1)

    nw_h = n_h // self.window_size
    nw_w = n_w // self.window_size

    q, k, v = map(lambda t: rearrange(t, "b (nw_h w_h) (nw_w w_w) (h d)->b h (nw_h nw_w) (w_h w_w) d",\
                                      h=h,w_h=self.window_size,w_w=self.window_size), qkv)
    # Do product similarity
    dots = torch.einsum("b h w i d, b h w j d->b h w i j", q, k) * self.scale
    dots +=self.pos_embedding
    if self.shifted:
      dots[:,:,-nw_w:] +=self.upper_lower_mask
      dots[:,:,nw_w-1::nw_w] += self.left_right_mask

    attn = dots.softmax(dim=-1)
    out = torch.einsum("b h w i j, b h w j d->b h w i d", attn, v)

    out = rearrange(out, "b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)",\
                    h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)

    out = self.to_out(out)
    if self.shifted:
       out = self.cyclic_back_shift(out)
    return out

In [15]:
class SwinBlock(nn.Module):
  def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding) :
    super().__init__()
    self.attention_block = Residual(PreNorm(dim, WindowAttension(dim=dim,heads=heads,head_dim=head_dim,\
                                                                 shifted=shifted, window_size=window_size,\
                                                                 relative_pos_embedding=relative_pos_embedding)))

    self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim,hidden_dim=mlp_dim)))

  def forward(self, x):
    x = self.attention_block(x)
    x = self.mlp_block(x)
    return x

In [16]:
class PatchExpanding(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
    super().__init__()
    self.patch_expand = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                           stride=stride, padding=padding, output_padding=output_padding)

  def forward(self,x):
    x = self.patch_expand(x).permute(0,2,3,1)
    return x

class PatchMerging_Conv(nn.Module):
  def __init__(self, in_channels, out_channels, downscaling_factor):
    super().__init__()
    self.patch_merge = nn.Conv2d(in_channels, out_channels, kernel_size=downscaling_factor,stride=downscaling_factor,padding=0)

  def forward(self,x):
    x = self.patch_merge(x).permute(0,2,3,1)
    return x

In [17]:
class StageModule(nn.Module):
  def __init__(self,in_channels, hidden_dimension, layers, scaling_factor,num_heads,head_dim, window_size, relative_pos_embedding,
               PatchMerging=True, stride=2, padding=1, output_padding=1):
    super().__init__()
    assert layers % 2 == 0, "Stage layers need to be divisible by 2 for regular and shftrd block"
    if PatchMerging:
      self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dimension,downscaling_factor=scaling_factor)
    else:
      self.patch_partition = PatchExpanding(in_channels=in_channels , out_channels=hidden_dimension, kernel_size=scaling_factor, stride=stride,
                                            padding=padding, output_padding=output_padding)

    self.layers = nn.ModuleList([])
    for _ in range(layers//2):
      self.layers.append(nn.ModuleList([
          SwinBlock(dim=hidden_dimension,heads=num_heads, head_dim = head_dim, mlp_dim=hidden_dimension *4, shifted =False,\
                    window_size = window_size, relative_pos_embedding=relative_pos_embedding),

          SwinBlock(dim=hidden_dimension,heads=num_heads, head_dim = head_dim, mlp_dim=hidden_dimension *4, shifted =True,\
                    window_size = window_size, relative_pos_embedding=relative_pos_embedding)
      ]))

  def forward(self, x):
    x = self.patch_partition(x)

    for regular_block, shifted_block in self.layers:
      x = regular_block (x)
      x = shifted_block (x)
    return x.permute(0,3,1,2)

In [18]:
class SwinTransformerUnet(nn.Module):
  def __init__(self,*,hidden_dim, layers, channels=3, num_classes=5, heads=(3,6,12,24), head_dim=32,window_size=7, \
               downscaling_factors=(4,2,2,2), scaling_factor=(3,3,3,4), relative_pos_embedding=True):
    super().__init__()
    # Encoder
    self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim,layers=layers[0],\
                              scaling_factor=downscaling_factors[0],num_heads=heads[0],head_dim=head_dim,\
                              window_size=window_size, relative_pos_embedding=relative_pos_embedding)

    self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim*2,layers=layers[1],\
                              scaling_factor=downscaling_factors[1],num_heads=heads[1],head_dim=head_dim,\
                              window_size=window_size, relative_pos_embedding=relative_pos_embedding
                              )
    self.stage3 = StageModule(in_channels=hidden_dim*2, hidden_dimension=hidden_dim*4,layers=layers[2],\
                              scaling_factor=downscaling_factors[2],num_heads=heads[2],head_dim=head_dim,\
                              window_size=window_size, relative_pos_embedding=relative_pos_embedding
                              )
    self.stage4 = StageModule(in_channels=hidden_dim*4, hidden_dimension=hidden_dim*8,layers=layers[3],\
                              scaling_factor=downscaling_factors[3],num_heads=heads[3],head_dim=head_dim,\
                              window_size=window_size, relative_pos_embedding=relative_pos_embedding
                              )

    #Decoder
    self.stage11 = StageModule(in_channels=hidden_dim*8, hidden_dimension=hidden_dim*4,layers=layers[0],\
                              scaling_factor=scaling_factor[0],num_heads=heads[3],head_dim=head_dim,\
                              window_size=window_size, PatchMerging=False, relative_pos_embedding=relative_pos_embedding)

    self.stage22 = StageModule(in_channels=hidden_dim*4+hidden_dim*4, hidden_dimension=hidden_dim*2,layers=layers[1],\
                              scaling_factor=scaling_factor[1],num_heads=heads[2],head_dim=head_dim,\
                              window_size=window_size, PatchMerging=False, relative_pos_embedding=relative_pos_embedding
                              )
    self.stage33 = StageModule(in_channels=hidden_dim*2+hidden_dim*2, hidden_dimension=hidden_dim,layers=layers[2],\
                              scaling_factor=scaling_factor[2],num_heads=heads[1],head_dim=head_dim,\
                              window_size=window_size, PatchMerging=False, relative_pos_embedding=relative_pos_embedding
                              )
    self.stage44 = StageModule(in_channels=hidden_dim+hidden_dim, hidden_dimension=hidden_dim,layers=layers[3],\
                              scaling_factor=scaling_factor[3],num_heads=heads[0],head_dim=head_dim,\
                              stride=4, padding=1, output_padding=2, window_size=window_size, PatchMerging=False,\
                              relative_pos_embedding=relative_pos_embedding
                              )

  def forward(self, img):
    #Encoder
    stage1 = self.stage1(img)
    stage2 = self.stage2(stage1)
    stage3 = self.stage3(stage2)
    stage4 = self.stage4(stage3)

    #Decoder
    x = self.stage11(stage4)
    x=torch.cat([x, stage3], dim=1)
    x = self.stage22(x)
    x=torch.cat([x, stage2], dim=1)
    x = self.stage33(x)
    x=torch.cat([x, stage1], dim=1)
    x = self.stage44(x)
    return x

In [19]:
class SwinUnet(nn.Module):
  def __init__(self,):
    super().__init__()
    self.planet = SwinTransformerUnet(hidden_dim = 96, layers=(2,6,2,2),heads=(3,6,12,24),channels=4,num_classes=5,head_dim=32,\
             window_size=2,relative_pos_embedding=True)

    self.final_planet = nn.Conv2d(in_channels=96, out_channels=96, stride= 2, kernel_size=2 ,padding= 0)

    self.s1 = SwinTransformerUnet(hidden_dim = 96, layers=(2,6,2,2),heads=(3,6,12,24),channels=3,num_classes=5,head_dim=32,\
             window_size=2,relative_pos_embedding=True)

    self.s2 = SwinTransformerUnet(hidden_dim = 96, layers=(2,6,2,2),heads=(3,6,12,24),channels=10,num_classes=5,head_dim=32,\
             window_size=2,relative_pos_embedding=True)

    self.final=nn.Conv2d(in_channels=288, out_channels=5, kernel_size=3,
                               stride=1, padding=1)
    self.softmax = nn.LogSoftmax(dim=1)

  def forward(self,X):


    planet = self.final_planet(self.planet(X['planet'].to(DEVICE)))
    s1 = self.s1(X['s1'].to(DEVICE))
    s2 = self.s2(X['s2'].to(DEVICE))
    out = self.final(torch.cat([s1, s2, planet], dim=1))
    return self.softmax(out)

In [20]:
# Model object
model_swin = SwinUnet().to(DEVICE)


In [21]:
loss_weights = 1 - np.array([.85, .17, .56, .16, .11])
loss_weights = torch.tensor(loss_weights, dtype=torch.float32).cuda()

config = {
    "epochs"           : 100,
    "lr"               : 0.001,
    "label_smoothing"  : 0.2,
    "momentum"         : 0.9,
    "weight_decay"     : 0.0001,
    "loss_weights"     : loss_weights
}



#### Training

In [22]:
wandb.login(key="ed120be65ed3b503c10399eb93a51f7112e342dc") #API Key

[34m[1mwandb[0m: Currently logged in as: [33mbwirayesu[0m ([33mwn[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [23]:
run = wandb.init(
    name = "SwinT-Ghana-indivTime-1", ## Wandb creates random run names if you skip this field
    # reinit = True, ### Allows reinitalizing runs when you re-run this cell
    id ='u4ujzh2v', ### Insert specific run id here if you want to resume a previous run
    resume = "must", ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "Crop-type-segmentation", ### Project should be created in your wandb account
     ### Wandb Config for your run
)

In [24]:
import gc
gc.collect()

6717

In [25]:
# metrics, train and validation functions
def reshapeForLoss(y):
    """ Reshapes labels or preds for loss fn.
    To get them to the correct shape, we permute:
      [batch x classes x rows x cols] --> [batch x rows x cols x classes]
      and then reshape to [N x classes], where N = batch*rows*cols
    """
    # [batch x classes x rows x cols] --> [batch x rows x cols x classes]
    y = y.permute(0, 2, 3, 1)
    # [batch x rows x cols x classes] --> [batch*rows*cols x classes]
    y = y.contiguous().view(-1, y.shape[3])
    return y

In [26]:
def mask_ce_loss(y_true, y_pred, weight_scale=1):

    y_true = reshapeForLoss(y_true)
    num_examples = torch.sum(y_true, dtype=torch.float32).cuda()
    y_pred = reshapeForLoss(y_pred)

    loss_mask = torch.sum(y_true, dim=1).type(torch.LongTensor)
    loss_mask_repeat = loss_mask.unsqueeze(1).repeat(1,y_pred.shape[1]).type(torch.FloatTensor)
    _, y_true = torch.max(y_true, dim=1)
    y_true = y_true * loss_mask
    y_pred_ = y_pred * loss_mask_repeat.cuda()


    loss_fn = nn.NLLLoss(weight = config["loss_weights"] ** weight_scale)

    total_loss = torch.sum(loss_fn(y_pred, y_true.cuda()))

    if num_examples == 0:
        print("WARNING: NUMBER OF EXAMPLES IS 0")

    else: return total_loss / num_examples

In [27]:
# loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"], weight=config["loss_weights"])
optimizer = torch.optim.SGD(model_swin.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"])

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0.00005)


In [28]:


# Metrics
def crop_segmentation_metrics(y_true, y_pred):
        y_true = reshapeForLoss(y_true.cpu())
        y_pred = reshapeForLoss(y_pred.cpu())

        loss_mask = torch.sum(y_true, dim=1).type(torch.LongTensor)

        _, y_true = torch.max(y_true, dim=1)
        _, y_pred = torch.max(y_pred, dim=1)

        y_true = y_true[loss_mask == 1]
        y_pred = y_pred[loss_mask == 1]

        assert (y_true.shape == y_pred.shape)
        y_true = y_true.int()
        y_pred = y_pred.int()
        f1 = f1_score(y_true, y_pred, average='macro')
        acc = accuracy_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)

        return f1, acc, cm

# Train
def train_step(data_loader, optimizer, accuracy_fn):
  """Performs a training with model trying to learn on data-loader"""

  curr_lr = float(optimizer.param_groups[0]['lr'])
  model_swin.train() # Put model into training mode

  train_loss, train_acc, train_f1=0,0,0
  # Add a loop to loop through the training batches
  for img,label in tqdm(data_loader):

    # 1. Forwad pass
    y_pred=model_swin(img)

    # 2. Calculate loss and accuracy (per batch)
    label = label.long()
    label = torch.nn.functional.one_hot(label, num_classes=5).permute([0,3,1,2])
    # label = label.to(DEVICE)


    loss=mask_ce_loss(label, y_pred)

    train_loss+=loss.item() #  accumulate training loss
    # y_pred_masked = torch.max(y_pred,dim=1)[1]
    train_acc += accuracy_fn(y_true=label, y_pred=y_pred)[1]
    train_f1+=accuracy_fn(y_true=label, y_pred=y_pred)[0]

    # print(torch.isnan(img['s1']).sum())
    # print(torch.isnan(img['s2']).sum())
    # print(torch.isnan(img['planet']).sum())
    # print(loss.item())

    # 3. optimize zero grad
    optimizer.zero_grad()

    # 4. Loss backward
    loss.backward()

    # 5. Optimizer step
    optimizer.step()


  # Divide total train loss and acc by lenth of train dataloader
  train_loss/=len(data_loader)
  train_acc/=len(data_loader)
  train_f1/=len(data_loader)
  print(f"Train loss {train_loss: .5f}|Train acc : {train_acc:.4f} | Train f1: {train_f1:.4f} | lr: {curr_lr}\n")
  return train_loss, train_acc, train_f1, curr_lr

# Validation
def validation_step(data_loader, accuracy_fn):
  """Performs a tesing loop step on model going over data loader."""

  val_loss,val_acc, val_f1=0,0,0

  model_swin.eval() # put the model in eval mode
  # turn on inference mode context manager
  with torch.inference_mode():
    for img,label in tqdm(data_loader):

        # 1. Forward pass (outputs raw logits)
      val_pred=model_swin(img)

      # 2. Calculate the loss/acc
      label = label.long()
      label = torch.nn.functional.one_hot(label, num_classes=5).permute([0,3,1,2])
      # label = label.to(DEVICE)
      val_loss+=mask_ce_loss(label, val_pred)

      # val_pred_masked = torch.max(val_pred,dim=1)[1]
      val_acc += accuracy_fn(y_true=label, y_pred=val_pred)[1]
      val_f1+=accuracy_fn(y_true=label, y_pred=val_pred)[0]

    # Adjust metrics and print out
    val_loss/=len(data_loader)
    val_acc/=len(data_loader)
    val_f1/=len(data_loader)
    print(f"Val loss: {val_loss:.5f} | Val acc: {val_acc:.4f} | Val f1: {val_f1:.4f}\n")

    return val_loss, val_acc, val_f1

In [None]:
# Let's train
epochs=config['epochs']

# Create an optimization and evluation using train_step() and val_step()
train_loss_list=[]
train_acc_list=[]
train_f1_list=[]

val_loss_list=[]
val_acc_list=[]
val_f1_list=[]


best_val_f1 = float('inf')

for epoch in range(epochs):

  print(f"Epoch: {epoch}/{epochs} \n----------------")


  train_loss, train_acc, train_f1, curr_lr = train_step(
                                                data_loader=train_loader,
                                                optimizer=optimizer,
                                                accuracy_fn=crop_segmentation_metrics
                                                )
  train_loss_list.append(train_loss)
  train_acc_list.append(train_acc)
  train_f1_list.append(train_f1)

  val_loss, val_acc, val_f1 = validation_step(
                                            data_loader=val_loader,
                                            accuracy_fn=crop_segmentation_metrics,
                                            )
  val_loss_list.append(val_loss)
  val_acc_list.append(val_acc)
  val_f1_list.append(val_f1)

  wandb.log({"train_loss": train_loss, 'train_f1': train_f1, 'train_acc': train_acc, 'validation_f1':val_f1,
               'validation_loss': val_loss, 'validation_acc': val_acc, "learning_Rate": curr_lr})

  if val_f1 < best_val_f1:
      best_val_loss = val_f1
      torch.save(model_swin.state_dict(), './models/best.pth')  # Save the best model

      print("Saving model")
      torch.save({'model_state_dict':model_swin.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'scheduler_state_dict':scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'epoch': epoch}, './models/best.pth')

      wandb.save('./models/best.pth')

  scheduler.step()

run.finish()


Epoch: 0/100 
----------------


100%|██████████| 226/226 [21:31<00:00,  5.72s/it]


Train loss  0.00004|Train acc : 0.23 | Train f1: 0.09 | lr: 0.001



100%|██████████| 30/30 [02:46<00:00,  5.54s/it]


Val loss: 0.00004 | Val acc: 0.26 | Val f1: 0.10

Saving model
Epoch: 1/100 
----------------


100%|██████████| 226/226 [21:08<00:00,  5.61s/it]


Train loss  0.00004|Train acc : 0.29 | Train f1: 0.11 | lr: 0.0009997656161737224



100%|██████████| 30/30 [02:45<00:00,  5.51s/it]


Val loss: 0.00004 | Val acc: 0.32 | Val f1: 0.11

Saving model
Epoch: 2/100 
----------------


100%|██████████| 226/226 [21:04<00:00,  5.60s/it]


Train loss  0.00004|Train acc : 0.35 | Train f1: 0.12 | lr: 0.000999062696003429



100%|██████████| 30/30 [02:45<00:00,  5.53s/it]


Val loss: 0.00004 | Val acc: 0.38 | Val f1: 0.13

Saving model
Epoch: 3/100 
----------------


100%|██████████| 226/226 [21:10<00:00,  5.62s/it]


Train loss  0.00004|Train acc : 0.41 | Train f1: 0.13 | lr: 0.0009978919331864629



100%|██████████| 30/30 [02:45<00:00,  5.50s/it]


Val loss: 0.00003 | Val acc: 0.44 | Val f1: 0.14

Saving model
Epoch: 4/100 
----------------


100%|██████████| 226/226 [21:09<00:00,  5.62s/it]


Train loss  0.00003|Train acc : 0.47 | Train f1: 0.14 | lr: 0.000996254483124377



100%|██████████| 30/30 [02:45<00:00,  5.53s/it]


Val loss: 0.00003 | Val acc: 0.50 | Val f1: 0.15

Saving model
Epoch: 5/100 
----------------


100%|██████████| 226/226 [21:07<00:00,  5.61s/it]


Train loss  0.00003|Train acc : 0.52 | Train f1: 0.15 | lr: 0.0009941519617826901



100%|██████████| 30/30 [02:46<00:00,  5.57s/it]


Val loss: 0.00003 | Val acc: 0.55 | Val f1: 0.16

Saving model
Epoch: 6/100 
----------------


100%|██████████| 226/226 [21:14<00:00,  5.64s/it]


Train loss  0.00003|Train acc : 0.57 | Train f1: 0.16 | lr: 0.0009915864440961269



100%|██████████| 30/30 [02:45<00:00,  5.52s/it]


Val loss: 0.00003 | Val acc: 0.60 | Val f1: 0.17

Saving model
Epoch: 7/100 
----------------


100%|██████████| 226/226 [21:14<00:00,  5.64s/it]


Train loss  0.00003|Train acc : 0.62 | Train f1: 0.17 | lr: 0.0009885604619209046



100%|██████████| 30/30 [02:45<00:00,  5.53s/it]


Val loss: 0.00003 | Val acc: 0.64 | Val f1: 0.17

Saving model
Epoch: 8/100 
----------------


100%|██████████| 226/226 [21:11<00:00,  5.62s/it]


Train loss  0.00003|Train acc : 0.66 | Train f1: 0.18 | lr: 0.0009850770015360992



100%|██████████| 30/30 [02:48<00:00,  5.60s/it]


Val loss: 0.00003 | Val acc: 0.68 | Val f1: 0.18

Saving model
Epoch: 9/100 
----------------


100%|██████████| 226/226 [21:11<00:00,  5.63s/it]


Train loss  0.00003|Train acc : 0.70 | Train f1: 0.18 | lr: 0.0009811395006965474



100%|██████████| 30/30 [02:45<00:00,  5.51s/it]


Val loss: 0.00003 | Val acc: 0.72 | Val f1: 0.18

Saving model
Epoch: 10/100 
----------------


100%|██████████| 226/226 [21:12<00:00,  5.63s/it]


Train loss  0.00003|Train acc : 0.73 | Train f1: 0.18 | lr: 0.0009767518452401974



100%|██████████| 30/30 [02:45<00:00,  5.53s/it]


Val loss: 0.00003 | Val acc: 0.75 | Val f1: 0.18

Saving model
Epoch: 11/100 
----------------


100%|██████████| 226/226 [21:10<00:00,  5.62s/it]


Train loss  0.00003|Train acc : 0.76 | Train f1: 0.19 | lr: 0.0009719183652532566



100%|██████████| 30/30 [02:44<00:00,  5.50s/it]


Val loss: 0.00003 | Val acc: 0.78 | Val f1: 0.19

Saving model
Epoch: 12/100 
----------------


100%|██████████| 226/226 [21:10<00:00,  5.62s/it]


Train loss  0.00003|Train acc : 0.79 | Train f1: 0.19 | lr: 0.0009666438307969189



100%|██████████| 30/30 [02:45<00:00,  5.50s/it]


Val loss: 0.00002 | Val acc: 0.80 | Val f1: 0.19

Saving model
Epoch: 13/100 
----------------


100%|██████████| 226/226 [21:03<00:00,  5.59s/it]


Train loss  0.00002|Train acc : 0.81 | Train f1: 0.19 | lr: 0.0009609334471998905



100%|██████████| 30/30 [02:44<00:00,  5.49s/it]


Val loss: 0.00002 | Val acc: 0.82 | Val f1: 0.19

Saving model
Epoch: 14/100 
----------------


100%|██████████| 226/226 [21:05<00:00,  5.60s/it]


Train loss  0.00002|Train acc : 0.83 | Train f1: 0.19 | lr: 0.0009547928499213589



100%|██████████| 30/30 [02:45<00:00,  5.51s/it]


Val loss: 0.00002 | Val acc: 0.83 | Val f1: 0.19

Saving model
Epoch: 15/100 
----------------


100%|██████████| 226/226 [21:07<00:00,  5.61s/it]


Train loss  0.00002|Train acc : 0.84 | Train f1: 0.19 | lr: 0.0009482280989894743



100%|██████████| 30/30 [02:44<00:00,  5.49s/it]


Val loss: 0.00002 | Val acc: 0.85 | Val f1: 0.20

Saving model
Epoch: 16/100 
----------------


100%|██████████| 226/226 [21:04<00:00,  5.60s/it]


Train loss  0.00002|Train acc : 0.85 | Train f1: 0.20 | lr: 0.0009412456730208348



100%|██████████| 30/30 [02:35<00:00,  5.17s/it]


Val loss: 0.00002 | Val acc: 0.86 | Val f1: 0.20

Saving model
Epoch: 17/100 
----------------


 39%|███▉      | 88/226 [08:00<12:43,  5.54s/it]