# Big Problem

We don't want the dataloader to read and load the depth maps during training, to save time.

So if we want to train simultaneously, we have to load the depth for the other images as well.

In [3]:
import os
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import h5py

from pprint import pprint
from torch.utils.data import DataLoader

import numpy as np
from torchmetrics import ConfusionMatrix


In [4]:
import dataloader_utils #Not using the above dataset but from the class


In [5]:
root = '../../learning_blenerproc/images_robocup'
# Build the train / validation splits from the same dataset root.
train_dataset = dataloader_utils.SimpleRoboCupDataset(root, "train")
valid_dataset = dataloader_utils.SimpleRoboCupDataset(root, "valid")

# It is good practice to check that the splits don't intersect.
assert set(train_dataset.filenames).isdisjoint(set(valid_dataset.filenames)), \
    "train and valid splits share files"

print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(valid_dataset)}")

# os.cpu_count() returns None when the core count cannot be determined.
n_cpu = os.cpu_count() or 1
n_batch_size = 32
# Each loader gets half the cores (0 means loading in the main process).
n_workers = n_cpu // 2
print (" CPU ", n_cpu)
train_dataloader = DataLoader(train_dataset, batch_size=n_batch_size, shuffle=True, num_workers=n_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=n_batch_size, shuffle=False, num_workers=n_workers)


Train size: 480
Valid size: 54
 CPU  16


In [None]:
class RoboCupModel(pl.LightningModule):
    """Multiclass segmentation LightningModule built on segmentation_models_pytorch.

    Wraps an ``smp`` architecture, normalizes images with the encoder's
    preprocessing statistics, trains with a multiclass Dice loss and logs
    per-image and dataset-wide IoU. The confusion matrix summed over the
    last finished epoch is stored in ``self.cm``.

    ``fusion_1d`` (a 1x1 conv mapping ``2*out_classes -> out_classes``
    channels) has its own optimizer but is not used by ``forward`` yet.
    TODO: wire it in (presumably to fuse RGB and depth logits).
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        """Build the network, losses and metrics.

        Args:
            arch: smp architecture name passed to ``smp.create_model``.
            encoder_name: smp encoder name; also selects the normalization stats.
            in_channels: number of input image channels.
            out_classes: number of segmentation classes; label 0 is the one
                ignored by ``loss_fn_without_background``.
            **kwargs: forwarded to ``smp.create_model``.
        """
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # Preprocessing parameters for the image, shaped (1, 3, 1, 1) so they
        # broadcast over a (batch, 3, H, W) image in forward().
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # Dice loss on logits; the first variant ignores pixels labelled 0.
        self.loss_fn_without_background = smp.losses.DiceLoss(
            smp.losses.MULTICLASS_MODE, from_logits=True, ignore_index=0
        )
        self.loss_fn = smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE, from_logits=True)
        self.n_classes = out_classes

        # Number of completed *training* epochs; drives the loss schedule.
        self.epoch = 0

        self.confmat = ConfusionMatrix(num_classes=out_classes, normalize='none')
        self.fusion_1d = torch.nn.Conv2d(
            in_channels=2 * self.n_classes, out_channels=self.n_classes, kernel_size=1
        )

    def forward(self, image):
        """Normalize ``image`` with the encoder statistics and return mask logits.

        ``image`` must have 3 channels on dim 1 to broadcast against the
        (1, 3, 1, 1) mean/std buffers.
        """
        image = (image - self.mean) / self.std
        return self.model(image)

    def shared_step(self, batch, stage, optimizer_idx=0):
        """Run one forward pass and compute the loss and metric statistics.

        Args:
            batch: dict with ``"image"`` and ``"mask"`` tensors.
            stage: ``"train"``, ``"valid"`` or ``"test"``.
            optimizer_idx: index of the optimizer Lightning is stepping during
                training. Only index 0 (the segmentation model) is implemented.

        Returns:
            dict with ``loss``, the ``tp``/``fp``/``fn``/``tn`` statistics from
            ``smp.metrics.get_stats`` and the batch ``confusion_matrix``; or
            ``None`` for a training step of any optimizer other than 0, which
            makes Lightning skip that optimizer step.
        """
        # fusion_1d is not used in forward(), so the loss does not depend on
        # the second optimizer's parameters: there is nothing for it to train.
        if stage == "train" and optimizer_idx != 0:
            return None

        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width).
        # For grayscale images, expand the channel dim to [batch_size, 1, height, width].
        assert image.ndim == 4

        # Image dimensions must be divisible by 32: encoder and decoder are
        # connected by skip connections, and encoders usually have 5 stages of
        # downsampling by 2 (2 ^ 5 = 32). For a 65x65 image the encoder/decoder
        # feature sizes would not line up and the concatenation would fail.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0, f"image size {h}x{w} is not divisible by 32"

        mask = batch["mask"]

        # The mask must be 4-D, presumably (batch_size, 1, height, width) of
        # class indices. TODO: confirm against dataloader_utils.
        assert mask.ndim == 4

        # Dice loss, confusion matrix and get_stats all index by class id, so
        # every label must lie in [0, n_classes).
        assert mask.min() >= 0 and mask.max() < self.n_classes, (
            f"mask labels must be in [0, {self.n_classes}), "
            f"got [{mask.min().item()}, {mask.max().item()}]"
        )
        target = mask.long()

        logits_mask = self.forward(image)

        # Loss schedule: every third training epoch uses the Dice loss over all
        # classes. The other training epochs, and validation/test, ignore
        # background (label 0) pixels. Both losses take logits (from_logits=True).
        if stage == "train" and self.epoch % 3 == 0:
            loss = self.loss_fn(logits_mask, target)
        else:
            loss = self.loss_fn_without_background(logits_mask, target)

        # argmax over logits equals argmax over softmax probabilities.
        pred_mask = logits_mask.argmax(dim=1, keepdim=True)

        # Per-batch confusion matrix (summed over the epoch in shared_epoch_end).
        confusion_matrix = self.confmat(pred_mask.ravel(), target.ravel())

        # True/false positive/negative pixel counts for each image and class.
        # They are aggregated at the end of the epoch to compute IoU both
        # dataset-wise and image-wise.
        tp, fp, fn, tn = smp.metrics.get_stats(
            pred_mask, target, mode="multiclass", num_classes=self.n_classes
        )

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
            "confusion_matrix": confusion_matrix,
        }

    @staticmethod
    def _flatten_step_outputs(outputs):
        """Yield the step-output dicts from a possibly nested list of outputs.

        Lightning (1.x) may nest training outputs in lists when several
        optimizers are configured; non-dict entries are skipped.
        """
        for item in outputs:
            if isinstance(item, dict):
                yield item
            elif isinstance(item, (list, tuple)):
                yield from RoboCupModel._flatten_step_outputs(item)

    def shared_epoch_end(self, outputs, stage):
        """Aggregate step statistics, log the IoU metrics and store ``self.cm``."""
        if stage == "train":
            # Count training epochs only; validation (including Lightning's
            # sanity check) also ends here and would skew the loss schedule.
            self.epoch += 1

        step_outputs = list(self._flatten_step_outputs(outputs))
        if not step_outputs:
            return

        tp = torch.cat([x["tp"] for x in step_outputs])
        fp = torch.cat([x["fp"] for x in step_outputs])
        fn = torch.cat([x["fn"] for x in step_outputs])
        tn = torch.cat([x["tn"] for x in step_outputs])

        # Per-image IoU: IoU per image first, then the mean over images.
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # Dataset IoU: intersection and union aggregated over the whole dataset.
        # Images without the target class weigh heavily on per_image_iou and
        # much less on dataset_iou, so the two can diverge.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        # Confusion matrix summed over all steps of the epoch.
        self.cm = torch.stack([x["confusion_matrix"] for x in step_outputs]).sum(dim=0)

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        return self.shared_step(batch, "train", optimizer_idx)

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        # Lightning passes no optimizer index to validation_step. A third
        # positional argument would be a dataloader index, so it is not
        # forwarded as optimizer_idx.
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Create one AdamW optimizer and warm-restart cosine schedule per submodule.

        Returns the two-list form Lightning uses for multiple optimizers:
        index 0 trains ``self.model`` and index 1 trains ``self.fusion_1d``.
        """
        # NOTE(review): eta_min equals the base lr (1e-4), so both cosine
        # schedules keep the learning rate constant. Lower eta_min if
        # annealing is actually wanted.
        optimizer_model = torch.optim.AdamW(
            self.model.parameters(), lr=0.0001, weight_decay=1e-5, amsgrad=True
        )
        scheduler_model = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer_model, T_0=10, T_mult=2, eta_min=1e-4, last_epoch=-1
        )
        optimizer_fusion1d = torch.optim.AdamW(
            self.fusion_1d.parameters(), lr=0.0001, weight_decay=1e-5, amsgrad=True
        )
        scheduler_fusion1d = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer_fusion1d, T_0=10, T_mult=2, eta_min=1e-4, last_epoch=-1
        )
        return [optimizer_model, optimizer_fusion1d], [scheduler_model, scheduler_fusion1d]
     