In [36]:
import os
import json
import gc
from glob import glob
from tqdm import tqdm
from typing import List, Tuple, Optional, Any, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from evaluate import load

import segmentation_models_pytorch as smp
from torchmetrics.functional import dice
import wandb

from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from operator import itemgetter
from PIL import Image

from collections import namedtuple

import time
import copy
from collections import defaultdict

from colorama import Fore, Back, Style
# Console colour codes used by run_training to highlight the "score improved" message.
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
# NOTE(review): this silences every warning for the whole kernel session,
# including deprecation notices from torch / albumentations.
warnings.filterwarnings("ignore")

## Config

In [37]:
class CFG:
    """Experiment-wide configuration, read as plain class attributes (never instantiated here)."""
    seed          = 101
    debug         = False
    exp_name      = 'cityscapes'
    comment       = 'segformer-b5-cityscapes-1024x1024-ep=5'
    model_name    = 'segformer'
    train_bs      = 1
    valid_bs      = 1
    # Only image_size[0] is read by prepare_dataloader, which resizes to a square of that side.
    image_size    = [1024, 1024]
    epochs        = 5
    # NOTE(review): the optimizer cell hard-codes lr=1e-3 and scheduler=None, so `lr`,
    # `scheduler`, `min_lr`, `T_max` and `T_0` are not consumed by the visible cells.
    lr            = 2e-3
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    # NOTE(review): 30000 looks like a hard-coded sample count -- confirm it matches this dataset.
    T_max         = int(30000/train_bs*epochs) + 50
    T_0           = 25
    warmup_epochs = 0
    weight_decay  = 1e-6
    n_fold        = 5
    # Number of distinct train ids (0..18) in the `labels` table.
    num_classes   = 19
    # Device the epoch functions move each batch to.
    device        = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


## Reproducibility

In [38]:
def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy and PyTorch, and request deterministic cuDNN.

    Args:
        seed: value handed to every random number generator and written to
            the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The three generators are independent, so the seeding order is irrelevant.
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)

    # Trade cuDNN auto-tuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    print('--- SEEDING DONE ---')

# Seed every RNG from the configured value before any dataset or model is built.
set_seed(CFG.seed)

--- SEEDING DONE ---


## Dataset

In [39]:
# Field names of one Cityscapes label record, in positional order.
_LABEL_FIELDS = (
    # Identifier of the label, e.g. 'car', 'person', ...; uniquely names a class.
    'name',

    # Integer ID used to represent the label in ground truth images.
    # -1 means the label has no ID and is therefore ignored when ground truth
    # images are created (e.g. license plate). Do not modify these IDs: the
    # evaluation server expects exactly these values.
    'id',

    # Training ID; free to adapt to the method at hand, after which ground truth
    # images with train IDs are created using the tools in the 'preparation'
    # folder. Validation and submissions to the evaluation server must still use
    # the regular IDs above. Several labels may share one train ID, in which
    # case they collapse into the same class in the ground truth images; the
    # inverse mapping then picks the label defined first in the list below.
    # Mapping all void-type classes to one ID is an example where this makes
    # sense. Max value is 255!
    'trainId',

    # Name of the category the label belongs to.
    'category',

    # ID of that category; used to create category-level ground truth images.
    'categoryId',

    # Whether the label distinguishes between single instances or not.
    'hasInstances',

    # Whether pixels with this ground truth class are ignored during evaluation.
    'ignoreInEval',

    # Colour of the label.
    'color',
)

Label = namedtuple('Label', _LABEL_FIELDS)


# Label table: one `Label` record per class. CityScapesDataset keeps only the
# entries whose trainId is neither 255 nor -1 as training targets.
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

In [40]:
class CityScapesDataset(data.Dataset):
    """Cityscapes semantic-segmentation dataset yielding image / train-id mask pairs.

    Every item is a dict with the keys:
        'image'    - augmented, normalised image tensor from the albumentations pipeline,
        'mask'     - label mask in which raw label ids are replaced by train ids and
                     every id without a train id becomes ``IGNORE_INDEX``,
        'colored'  - the train-id mask expanded to a trailing 3-value colour axis
                     (ignored pixels stay at ``IGNORE_INDEX`` in all three channels),
        'instance' - the label image exactly as decoded from disk.

    Args:
        image_paths: image files, index-aligned with ``mask_paths``.
        mask_paths: label-id images, one per entry of ``image_paths``.
        stage: ``"train"`` enables random rotation / flip / colour jitter; any
            other value builds the deterministic resize + normalise pipeline.
        image_size: side length of the square both image and mask are resized to.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    # Value written into the mask for every label that has no train id.
    IGNORE_INDEX = 255

    def __init__(self, image_paths: List[str], mask_paths: List[str], stage: str = "train", image_size: int = 1024) -> None:
        super().__init__()
        if len(image_paths) != len(mask_paths):
            # A mismatch would silently pair images with the wrong annotation.
            raise ValueError(
                f"Got {len(image_paths)} images but {len(mask_paths)} masks; the lists must be index-aligned."
            )

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_size = image_size

        # interpolation=0 is cv2.INTER_NEAREST (kept from the original pipeline).
        resize = A.Resize(height=self.image_size, width=self.image_size, interpolation=0)
        if stage == "train":
            self.transforms = A.Compose([
                resize,
                A.Rotate(limit=90, p=0.5),
                A.HorizontalFlip(p=0.5),
                A.ColorJitter(p=0.5),
                A.Normalize(),
                ToTensorV2(),
            ])
        else:
            self.transforms = A.Compose([
                resize,
                A.Normalize(),
                ToTensorV2(),
            ])

        # Keep only the labels that take part in training.
        targets = [
            label for label in labels
            if label.trainId != self.IGNORE_INDEX and label.trainId != -1
        ]
        self.target_train_ids = [label.trainId for label in targets]
        self.target_ids = [label.id for label in targets]
        self.target_colors = [label.color for label in targets]
        self.target_labels = [label.name for label in targets]

        self.cmap = dict(zip(self.target_train_ids, self.target_colors))
        self.idmap = dict(zip(self.target_ids, self.target_train_ids))

        # Lookup tables built once, so each sample needs a single gather instead of
        # one full-image pass per class. A table cannot suffer from the chained
        # replacement collisions that sequential `where` calls are prone to.
        id_lut = torch.full((256,), self.IGNORE_INDEX, dtype=torch.uint8)
        for label_id, train_id in self.idmap.items():
            id_lut[label_id] = train_id
        self._id_lut = id_lut

        color_lut = torch.full((256, 3), self.IGNORE_INDEX, dtype=torch.long)
        for train_id, color in self.cmap.items():
            color_lut[train_id] = torch.tensor(color, dtype=torch.long)
        self._color_lut = color_lut

    def __len__(self) -> int:
        """Number of image / mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        """Load, augment and relabel the sample at position ``idx``."""
        image = np.array(Image.open(self.image_paths[idx]))
        # Decoded once and reused for both the transform and the 'instance' entry.
        raw_mask = np.array(Image.open(self.mask_paths[idx]))

        transformed = self.transforms(image=image, mask=raw_mask)
        image, mask = transformed['image'], transformed['mask']

        # assumes 8-bit label images (values 0..255) -- larger values would index
        # past the lookup table.
        mask = self._id_lut[mask.long()]
        colored = self._color_lut[mask.long()]

        return {
            'image': image.clone(),
            'mask': mask,
            'colored': colored,
            'instance': raw_mask,
        }

## DataLoader

In [41]:
def prepare_dataloader():
    """Build the Cityscapes training and validation loaders.

    Images and label-id masks are paired by sorting both glob results, so the
    two directory trees must contain matching files.

    Returns:
        ``(train_loader, valid_loader)``. Only the training loader shuffles, and
        only the training dataset applies random augmentation.
    """
    image_root = './data/leftImg8bit_trainvaltest/leftImg8bit'
    mask_root = './data/gtFine_trainvaltest/gtFine'

    def build_dataset(split: str, stage: str) -> CityScapesDataset:
        # One dataset per split; `stage` selects the augmentation pipeline.
        return CityScapesDataset(
            image_paths=sorted(glob(f'{image_root}/{split}/*/*.png')),
            mask_paths=sorted(glob(f'{mask_root}/{split}/*/*_gtFine_labelIds.png')),
            stage=stage,
            image_size=CFG.image_size[0],
        )

    train_dataset = build_dataset('train', stage="train")
    # Validation must be deterministic: any stage other than "train" skips the
    # random rotation / flip / colour jitter.
    valid_dataset = build_dataset('val', stage="valid")

    train_loader = data.DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True, pin_memory=True)
    valid_loader = data.DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False, pin_memory=True)

    return train_loader, valid_loader

In [42]:
# Build both loaders once; run_training reads these module-level names.
train_loader, valid_loader = prepare_dataloader()

In [43]:
# Echo the loader object as a quick check that construction succeeded.
train_loader

<torch.utils.data.dataloader.DataLoader at 0x2fb3c33d0>

In [44]:
# Pull a single training batch and list the shape of every field it carries.
batch = next(iter(train_loader))
{k:v.shape for k, v in batch.items()}

{'image': torch.Size([1, 3, 1024, 1024]),
 'mask': torch.Size([1, 1024, 1024]),
 'colored': torch.Size([1, 1024, 1024, 3]),
 'instance': torch.Size([1, 1024, 2048])}

In [45]:
# Same shape check for the validation loader (note: this rebinds `batch`).
batch = next(iter(valid_loader))
{k:v.shape for k, v in batch.items()}

{'image': torch.Size([1, 3, 1024, 1024]),
 'mask': torch.Size([1, 1024, 1024]),
 'colored': torch.Size([1, 1024, 1024, 3]),
 'instance': torch.Size([1, 1024, 2048])}

In [46]:
# Run a full garbage-collection pass; the echoed number is the count of unreachable objects found.
gc.collect()

36553

## Model

In [47]:
class SegmentationModel(nn.Module):
    """SegFormer-B5 (Cityscapes fine-tuned checkpoint) with logits upsampled to a target size."""

    def __init__(self):
        super().__init__()
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b5-finetuned-cityscapes-1024-1024",
            ignore_mismatched_sizes=True,
        )

    def forward(self, batch: Any) -> torch.Tensor:
        """Compute per-pixel class logits.

        Args:
            batch: either a dict holding the pixel values under ``'image'`` and,
                optionally, the target under ``'mask'``, or the image tensor
                itself (the epoch functions call ``model(images)`` this way).

        Returns:
            Logits bilinearly resized to the last two dimensions of the mask
            when one is supplied, otherwise to those of the input image.
        """
        if isinstance(batch, torch.Tensor):
            pixel_values = batch
            reference = batch
        else:
            pixel_values = batch['image']
            mask = batch.get('mask')
            # Without a mask (e.g. at inference) fall back to the image size.
            reference = pixel_values if mask is None else mask

        outputs = self.model(pixel_values=pixel_values)

        # The raw logits are resized here so they can be compared with the target.
        return F.interpolate(
            outputs.logits,
            size=reference.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

## Training Function

In [31]:
def train_one_epoch(model, optimizer, scheduler, loss_fn, loader):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: network called with a ``{'image', 'mask'}`` dict; expected to
            return logits matching the mask's spatial size.
        optimizer: optimiser stepped once per batch.
        scheduler: optional LR scheduler, stepped once per batch; may be ``None``.
        loss_fn: callable invoked as ``loss_fn(y_pred=logits, y_true=mask)``.
        loader: iterable of dict batches with ``'image'`` and ``'mask'`` tensors.

    Returns:
        Sample-weighted mean loss over the epoch (``0.0`` for an empty loader).
    """
    model.train()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty loader cannot leave it unbound

    pbar = tqdm(loader, total=len(loader), desc='Train ')
    for batch in pbar:
        images = batch['image'].to(CFG.device, dtype=torch.float32)
        # The mask holds class indices, which multi-class losses need as integers.
        mask = batch['mask'].to(CFG.device, dtype=torch.long)

        batch_size = images.size(0)

        # SegmentationModel.forward reads the batch as a dict, not a bare tensor.
        y_pred = model({'image': images, 'mask': mask})
        loss = loss_fn(y_pred=y_pred, y_true=mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(
            train_loss=f'{epoch_loss:.3f}',
            lr=f'{current_lr:.3f}'
        )
        gc.collect()

    return epoch_loss

## Validation Function

In [32]:
@torch.no_grad()
def valid_one_epoch(model, loss_fn, loader):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean loss.

    Args:
        model: network called with a ``{'image', 'mask'}`` dict; expected to
            return logits matching the mask's spatial size.
        loss_fn: callable invoked as ``loss_fn(y_pred=logits, y_true=mask)``.
        loader: iterable of dict batches with ``'image'`` and ``'mask'`` tensors.

    Returns:
        Sample-weighted mean loss over the loader (``0.0`` for an empty loader).
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty loader cannot leave it unbound

    pbar = tqdm(loader, total=len(loader), desc='Valid ')
    for batch in pbar:
        # The dataset emits singular keys: 'image' and 'mask'.
        images = batch['image'].to(CFG.device, dtype=torch.float32)
        # The mask holds class indices, which multi-class losses need as integers.
        masks = batch['mask'].to(CFG.device, dtype=torch.long)

        batch_size = images.size(0)

        # SegmentationModel.forward reads the batch as a dict, not a bare tensor.
        y_pred = model({'image': images, 'mask': masks})
        loss = loss_fn(y_pred=y_pred, y_true=masks)

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        pbar.set_postfix(
            valid_loss=f'{epoch_loss:.3f}',
        )

        gc.collect()

    return epoch_loss

In [33]:
def run_training(model, optimizer, scheduler, device, num_epochs, loss_fn=None, fold=0):
    """Train for ``num_epochs`` epochs and keep the weights with the lowest validation loss.

    Reads the module-level ``train_loader`` and ``valid_loader``. Metrics are
    sent to Weights & Biases only when a run is active (``wandb.init()`` was
    called); otherwise training proceeds without logging.

    Args:
        model: network to optimise; its best weights are restored before returning.
        optimizer: optimiser handed to ``train_one_epoch``.
        scheduler: optional LR scheduler handed to ``train_one_epoch``; may be ``None``.
        device: kept for backward compatibility; the epoch functions move
            batches to ``CFG.device`` themselves.
        num_epochs: number of train + validation rounds.
        loss_fn: callable invoked as ``loss_fn(y_pred=logits, y_true=mask)``.
            Defaults to cross-entropy that ignores label 255.
        fold: index embedded in the checkpoint file names.

    Returns:
        ``(model, history)`` where ``history`` maps 'Train Loss' / 'Valid Loss'
        to one value per epoch.
    """
    if loss_fn is None:
        def default_loss(y_pred, y_true):
            # 255 marks pixels without a train id; they must not contribute.
            return F.cross_entropy(y_pred, y_true.long(), ignore_index=255)
        loss_fn = default_loss

    log_to_wandb = wandb.run is not None
    if log_to_wandb:
        # To automatically log gradients
        wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss      = np.inf
    best_epoch     = -1
    history = defaultdict(list)

    best_path = f"best_epoch-{fold:02d}.bin"
    last_path = f"last_epoch-{fold:02d}.bin"

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}')
        train_loss = train_one_epoch(model, optimizer, scheduler, loss_fn, train_loader)
        val_loss = valid_one_epoch(model, loss_fn, valid_loader)

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)

        # Read the LR from the optimiser so this also works when scheduler is None.
        current_lr = optimizer.param_groups[0]['lr']
        if log_to_wandb:
            wandb.log({"Train Loss": train_loss,
                       "Valid Loss": val_loss,
                       "LR": current_lr})

        print(f'Train Loss: {train_loss:0.4f} | Valid Loss: {val_loss:0.4f}')

        # Snapshot the weights whenever validation loss improves.
        if val_loss <= best_loss:
            print(f"{c_}Valid Loss Improved ({best_loss:0.4f} ---> {val_loss:0.4f})")
            best_loss  = val_loss
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, best_path)
            if log_to_wandb:
                wandb.run.summary["Best Valid Loss"] = best_loss
                wandb.run.summary["Best Epoch"]      = best_epoch
                # Save a model file from the current directory
                wandb.save(best_path)
            print(f"Model Saved{sr_}")

        # Always keep the most recent weights so training can be resumed.
        torch.save(model.state_dict(), last_path)

        print(); print()

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f} (epoch {})".format(best_loss, best_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history

In [34]:
# Pick the best available accelerator. `torch.has_mps` is deprecated, so query
# the MPS backend directly (guarded for builds that do not ship it).
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# The epoch functions move every batch to CFG.device, so it has to match the
# device the model lives on or the forward pass fails with a device mismatch.
CFG.device = device

model = SegmentationModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = None

In [35]:
# Keep a handle on the W&B run: run_training refers to it as `run`.
run = wandb.init()

# Bind the result so the cell does not echo the whole model repr as its output.
model, history = run_training(model, optimizer, scheduler, device, CFG.epochs)

# Close the run so its metrics are flushed.
run.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016751231933327896, max=1.0…

Epoch 1/5