In [7]:
#!pip install -r /kaggle/input/pylit-wandb-smp-requirements/requirements.txt -q
!pip install segmentation_models_pytorch -q
!pip install icecream -q


In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import transforms
import matplotlib.pyplot as plt
from datetime import time 
import torch
from torch import nn
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import random
import metrics
import numpy as np
import cv2
import glob
import matplotlib.pyplot as plt
from icecream import ic

from idd_lite_helpers.idd_lite_helpers import IDD_Main_Dataset
from idd_lite_helpers.idd_lite_helpers import IDDRoadSegmentationDatamodule as dmidd

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import typing 
import os
import math
from datetime import datetime

# Seed random generator for repeatibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


In [None]:
torch.set_float32_matmul_precision('medium')

In [5]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdayaalex[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def decode_segmap(image, threshold=0.5):#changing single channel to 3 channel
    
    #print(image)#RGB
    image = image>threshold
    #print(image.shape)
    Background_scene = [255,255,255]
    Road = [51, 153, 255]

    label_colours = np.array([Road,Background_scene]).astype(np.uint8)
    
    #print(label_colours.shape)
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    
    for l in range(2):
        r[image == l] = label_colours[l, 0]
        g[image == l] = label_colours[l, 1]
        b[image == l] = label_colours[l, 2]

    rgb = np.zeros((image.shape[0], image.shape[1], 3)).astype(np.uint8)
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb


In [7]:
# def get_wts(loader):

#     hist = torch.zeros(2)
#     for batch in tqdm(loader):
#         _, mask = batch
#         #print(mask.shape)
#         mask = mask.squeeze(0).float()
#         #print(mask.shape)

#         hist += torch.histc(mask, 2, 0,1 )

#     norm_hist = hist/torch.sum(hist)

#     class_wts = torch.ones(2)
#     for idx in range(2):
#         if hist[idx]<1 :
#             class_wts[idx] = 0
#         else:
#             class_wts[idx] = 1/torch.log(1.02 + norm_hist[idx])

#     return class_wts

# dm = dmidd(1,224)
# dm.setup()
# tr_weights = get_wts(dm.train_dataloader())
# vl_weights = get_wts(dm.val_dataloader())
# print('class wts of train ', tr_weights)
# print('class wts of val ', vl_weights)

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MCC_Loss(nn.Module):
    """
    Calculates the proposed Matthews Correlation Coefficient-based loss.
    Args:
        eps (float): Small epsilon to prevent division by zero.
    """
    def __init__(self, eps=1e-10):
        super(MCC_Loss, self).__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """
        Applies sigmoid to logits and computes MCC loss.
        """
        inputs = torch.sigmoid(logits)  # Convert logits to probabilities
        tp = torch.sum(inputs * targets)
        tn = torch.sum((1 - inputs) * (1 - targets))
        fp = torch.sum(inputs * (1 - targets))
        fn = torch.sum((1 - inputs) * targets)

        numerator = tp * tn - fp * fn
        denominator = torch.sqrt(
            (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + self.eps
        )

        mcc = numerator / denominator
        return 1 - mcc  # Minimize loss = 1 - MCC during training

In [12]:
class BinarySegmentationForIdd(pl.LightningModule):
    def __init__(self,
                 model_name :str = 'unet',
                 encoder_name : str = 'efficientnet-b2',
                 encoder_weights :str = 'imagenet',
                 lr_e : float = 1e-1,
                 lr_d : float = 1e-3,
                    ):
        super().__init__()
        self.save_hyperparameters()
        self.tp, self.fp, self.fn, self.tn = 0,0,0,0
        self.Dice = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits = True)
        self.softbce = smp.losses.SoftBCEWithLogitsLoss()
        self.focal = smp.losses.FocalLoss(smp.losses.BINARY_MODE, alpha=None, gamma=5.0, ignore_index=None, reduction='mean', normalized=False, reduced_threshold=None)
        #self.mcc = smp.losses.MCC_Loss()
        self.lovasz = smp.losses.LovaszLoss(smp.losses.BINARY_MODE, per_image=False, ignore_index=None, from_logits=True)
        
        self.name = model_name
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.maxmiou = 1e-4
        self.start_time = 0
        if self.name == 'unetplusplus':
            self.model = smp.UnetPlusPlus(
                        encoder_name = self.encoder_name,
                        encoder_weights = self.encoder_weights,
                        decoder_attention_type = None,
                        in_channels = 3,
                        classes = 1
                        )
        elif self.name == 'unet':
            self.model = smp.Unet(
                        encoder_name = self.encoder_name,
                        encoder_weights = self.encoder_weights,
                        decoder_attention_type = 'scse',
                        in_channels = 3,
                        classes = 1
                        )
        elif self.name == 'pan':
            self.model = smp.PAN(
                        encoder_name = self.encoder_name,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1
                        )
        elif self.name == 'pspnet':
            self.model = smp.PSPNet(
                        encoder_name = self.encoder_name,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1
                        ) 
        elif self.name =='fpn':
            self.model = smp.FPN(
                        encoder_name = self.encoder_name, 
                        encoder_depth = 5,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1
            )
        elif self.name =='manet':
            self.model = smp.MAnet(
                        encoder_name = self.encoder_name,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1)
        elif self.name =='linknet':
            self.model = smp.Linknet(
                        encoder_name = self.encoder_name, 
                        encoder_depth = 5,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1
            )
        elif self.name =='deeplabplus':
             self.model = smp.DeepLabV3Plus(
                        encoder_name = self.encoder_name, 
                        encoder_depth = 5,
                        encoder_weights = self.encoder_weights,
                        in_channels = 3,
                        classes = 1
            )
    def forward(self,x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        outputs = self(imgs)
        train_loss = self.combined_losses(outputs, masks)
        self.log('train/train_loss', train_loss, on_step = True, on_epoch = True)
        return train_loss

    def on_validation_epoch_start(self): 
        self.val_step_outputs = []

    def validation_step(self,batch, batch_idx):
        imgs, masks = batch
        outputs = self(imgs)#batch,channel,height, width
        self.val_step_outputs.append(torch.sigmoid(outputs))
        val_loss = self.combined_losses(outputs, masks)
        self.log('val/val_loss', val_loss, on_step=False, on_epoch=True)

        this_tp, this_fp, this_fn, this_tn = metrics.get_stats(outputs.squeeze(), masks, mode = 'binary', threshold=0.5)

        self.tp += this_tp
        self.fp += this_fp
        self.fn += this_fn
        self.tn += this_tn

        return val_loss

    def combined_losses(self, outputs, masks):
        if masks.dtype == torch.uint8:
            masks = masks.float()

        loss1 = self.Dice(outputs, masks)
        loss2 = self.softbce(outputs.squeeze(1), masks)
#         loss3 = self.focal(outputs, masks)
#         loss4 = self.mcc(outputs, masks)
        loss5 = self.lovasz(outputs, masks)
        return loss1 + loss2
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam([
            {'params': self.model.encoder.parameters(), 'lr': self.hparams.lr_e},
            {'params': self.model.decoder.parameters(), 'lr': self.hparams.lr_d},
            ])
        scheduler = StepLR(optimizer, step_size=30, gamma=0.1, verbose=True)
        return {'optimizer':optimizer,'lr_scheduler':{'scheduler': scheduler, 'monitor': 'val_loss'}}
    
    def on_validation_epoch_end(self):
        val_miou = metrics.iou_score(sum(self.tp), sum(self.fp), sum(self.fn), sum(self.tn), reduction = 'micro')
        self.log('val/val_accuracy', val_miou)

        if val_miou> self.maxmiou:
            self.maxmiou = val_miou
            checkpoint = {
                'epochs': self.current_epoch,
                'state_dict': self.state_dict(),
                'miou': self.maxmiou,
                #to do add optimizer state dict if using lr scheduler 
            }
            torch.save(checkpoint, f'./{self.name}_{self.encoder_name}_accuracy{self.maxmiou:.4f}.pth')
            ckpt_artifact = wandb.Artifact(
                                f'{self.name}_artifact_ckpt', type = 'model'    
                                )
            ckpt_artifact.add_file(f'./{self.name}_{self.encoder_name}_accuracy{self.maxmiou:.4f}.pth')
            self.logger.experiment.log_artifact(ckpt_artifact)
            self.log('New best model saved with miou',self.maxmiou)

        self.tp, self.fp, self.fn, self.tn = 0,0,0,0

        flattened_prob = torch.flatten(torch.cat(self.val_step_outputs)).cpu().detach()
        try:
            self.logger.experiment.log({
            'valid/sigmoid': wandb.Histogram(flattened_prob),
            'epoch': self.current_epoch
            })
        except Exception as e:
            print(f"Error logging to WandB: {e}")
            
    def test_step(self, batch,batch_idx):
        imgs, masks = batch
        outputs = self(imgs)#batch,channel,height, width
        test_loss = self.combined_losses(outputs, masks)
        self.log('test/test_loss', test_loss, on_step=False, on_epoch=True)

        this_tp, this_fp, this_fn, this_tn = metrics.get_stats(outputs.squeeze(), masks, mode = 'binary', threshold=0.5)

        self.tp += this_tp
        self.fp += this_fp
        self.fn += this_fn
        self.tn += this_tn

        return outputs 
                
    def on_test_epoch_start(self):
        self.start_time = datetime.now()
    def on_test_epoch_end(self):
        
        test_miou = metrics.iou_score(sum(self.tp), sum(self.fp), sum(self.fn), sum(self.tn), reduction = 'micro')
        self.log('test/test_accuracy', test_miou)
        final_time = datetime.now()-self.start_time 
        print('time taken for 1 epoch inference is {}'.format(final_time))
        self.start_time = 0
        self.tp, self.fp, self.fn, self.tn = 0,0,0,0
        
#         #ic(self.trainer.max_epochs-1)
#         #ic(self.current_epoch)
        if self.current_epoch == (self.trainer.max_epochs):
            dummy_input = torch.zeros((1,3,256,256 ), device=self.device)
            model_filename = f"model_{self.name}_ep{self.current_epoch}_test_iou{test_miou:.4f}.onnx"
            torch.onnx.export(self, dummy_input, model_filename, opset_version=11)
            onnx_artifact = wandb.Artifact(name=f"model_{self.name}_onnx_", type="model")
            onnx_artifact.add_file(model_filename)
            self.logger.experiment.log_artifact(onnx_artifact)


In [111]:
class ImagePredictionLogger(pl.Callback):
    def __init__(self, val_samples, num_samples=3):
        super().__init__()
        self.X_img_samples, self.mask_samples = val_samples
        self.X_img_samples= self.X_img_samples[:num_samples]
        self.mask_samples= self.mask_samples[:num_samples] 

    def on_validation_epoch_end(self, trainer, pl_module):#remember model is now pl_module

        
        self.X_img_samples = self.X_img_samples.to(pl_module.device)
        output_samples = pl_module(self.X_img_samples)


        #output_samples = output_samples*torch.Tensor([0.2588, 0.2734, 0.2997]) + torch.Tensor([0.3606, 0.3771, 0.3724])

        table = wandb.Table(columns = ["images", "predictions", "targets"] 
            )
        for X_img, output, mask in zip(self.X_img_samples.to("cpu"), output_samples.to("cpu"), self.mask_samples.to("cpu")):
            segmap_pred = decode_segmap(output.squeeze().numpy())
            segmap_gt = decode_segmap(mask.numpy())

            table.add_data(wandb.Image(X_img.numpy().transpose(1,2,0)*255), 
                    wandb.Image(segmap_pred), 
                    wandb.Image(segmap_gt)
                    )    

        trainer.logger.experiment.log(
            {'val_images_table': table}
        )

In [112]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Setup the ModelCheckpoint to save the best model based on 'val_loss' (change it as per your metric)

class BestCheckpoint(pl.Callback):
    def __init__(self):
        super().__init__()
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val/val_loss',  # Specify the metric to monitor
            dirpath='/kaggle/working/',  # Directory to save the model
            filename='bin-segmentation-{val_accuracy:02d}-{val_loss:.2f}',
            save_top_k=1,  # Save only the top 1 model which has the best 'val_loss'
            mode='min',  # 'min' mode means lower 'val_loss' is better
            save_weights_only=True,  # If set to False, saves the whole model
            verbose=True)

    def on_train_start(self, trainer, pl_module):
        # Ensure the checkpoint callback is attached to the trainer
        trainer.callbacks.append(self.checkpoint_callback)

from pytorch_lightning.callbacks import EarlyStopping

In [113]:
# sweep_config = {
#     'method': 'random'
#     }

# metric = {
#     'name': 'New best model saved with miou',
#     'goal': 'maximize'
#     }

# sweep_config['metric'] = metric

# parameters_dict = {
#     'batch_size':{
#         'values':[4,8,16,32]
#     },
#     'lr_e': {
#     'distribution': 'log_uniform_values',
#     'min': 5e-3,
#     'max': 5e-1
#     },
#     'lr_d':{
#     'distribution': 'log_uniform_values',
#     'min': 5e-5,
#     'max': 5e-3   
#     },
#     'image_ip_size':{
#         'values': [224,384,512]
#     }
    
#     }

# parameters_dict.update({
#     'epochs':{'value': 15 },
#     'model_name': {'value':'unet'},
#     'encoder_name' :  {'value':'mobilenet_v2'},
#     'encoder_weights' :{'value':'imagenet'},

# })

# sweep_config['parameters'] = parameters_dict
# sweep_id = wandb.sweep(sweep_config, project='multiprocessed dataloader,idd_lite_unet_binary_road_segmenation')

In [115]:
#            ##########################SWEEP RUNNING#####################################
# def train_using_wandb():
#     run = wandb.init(project = 'multiprocessed dataloader,idd_lite_unet_binary_road_segmenation',
#                 config = wandb.config
#                 )
#     config = run.config
#     run_name = f' 1 gpu lr {config.lr_d:.4f}, epochs {config.epochs}, batch_size: {config.batch_size}'
#     wandb.run.name = run_name

#     datamod = dmidd(batch_size=config.batch_size, size = config.image_ip_size)
#     datamod.setup()

#     model = BinarySegmentationForIdd(model_name= config.model_name,
#                                      encoder_name = config.encoder_name,
#                                      encoder_weights = config.encoder_weights,
#                                      lr_e = config.lr_e,
#                                      lr_d = config.lr_d)       

#     logger = WandbLogger()
#     wandb.watch(model, model.loss_function, log= 'all', log_freq = 160 )
#     val_samples = next(iter(datamod.val_dataloader()))


#     trainer = pl.Trainer(
#         accelerator="gpu", devices=1,
#         logger = logger,
#         log_every_n_steps = 1,
#         max_epochs = config.epochs,
#         callbacks = [ImagePredictionLogger(val_samples)]#, gpu_stats]
#     )
    
#     trainer.fit(model,datamod)
#     trainer.test(datamodule = datamod, ckpt_path='best')


In [116]:
      ################################# RUN RUNNING ###################################
import torchinfo

def train_using_wandb():
    run = wandb.init(project = 'Trying different smp models on iddlite',
                config = {'model_name':'unet',
                          'encoder_name':'mobilenet_v2',
                          'encoder_weights':'imagenet',
                          'lr_e': 0.015,
                          'lr_d': 0.001,
                          'epochs':15,
                          'batch_size':32,
                          'image_ip_size':256,
                          #'accumulate_grad_batches':   
                         }
                )
    config = run.config
    run_name = f' unet with scse batch_size: {config.batch_size}, lr {config.lr_e:.4f}, epochs {config.epochs}'
    wandb.run.name = run_name

    datamod = dmidd(batch_size=config.batch_size, size = config.image_ip_size)
    datamod.setup()

    model = BinarySegmentationForIdd(model_name= config.model_name,
                                     encoder_name = config.encoder_name,
                                     encoder_weights = config.encoder_weights,
                                     lr_e = config.lr_e,
                                     lr_d = config.lr_d)
    
    logger = WandbLogger()
    wandb.watch(model, model.combined_losses, log= 'all', log_freq = 1800 )#log every 360th batch, the grad, weights
    val_samples = next(iter(datamod.val_dataloader()))


    trainer = pl.Trainer(
#         gradient_clip_val=0.5,
        accelerator="gpu", devices=1,
        logger = logger,
        log_every_n_steps = 1,
        max_epochs = config.epochs,
        callbacks = [ImagePredictionLogger(val_samples), BestCheckpoint(),EarlyStopping(monitor='val/val_loss',patience=5, mode='min')]#, gpu_stats]
        #default_root_dir=''
    )
    trainer.fit(model,datamod)
    trainer.test(datamodule = datamod, ckpt_path='best')
    wandb.finish()

In [117]:
train_using_wandb()

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.5000e-02.
Adjusting learning rate of group 1 to 1.0000e-03.


Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

time taken for 1 epoch inference is 0:00:00.352015


VBox(children=(Label(value='215.845 MB of 215.845 MB uploaded (2.930 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
New best model saved with miou,▁▄▆▇▇█
epoch,▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇███
test/test_accuracy,▁
test/test_loss,▁
train/train_loss_epoch,█▄▃▃▂▂▂▂▁▁▁▁▁▁▁
train/train_loss_step,█▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val/val_accuracy,▁▄▆▅▇▇█▆▇█▂▇▇▇▇
val/val_loss,█▄▃▃▂▂▂▂▁▁▂▁▁▁▁

0,1
New best model saved with miou,0.91301
epoch,15.0
test/test_accuracy,0.93114
test/test_loss,0.12569
train/train_loss_epoch,0.15469
train/train_loss_step,0.15168
trainer/global_step,645.0
val/val_accuracy,0.90008
val/val_loss,0.15876


In [13]:
from idd_lite_helpers.idd_lite_helpers import IDD_Main_Dataset
from idd_lite_helpers.idd_lite_helpers import IDDRoadSegmentationDatamodule as dmidd

# /kaggle/usr/lib/idd_lite_helpers/idd_lite_helpers.py
from idd_lite_helpers.idd_lite_helpers import count
print(count)

4


In [68]:
import torchinfo
datamod = dmidd(batch_size=1, size = 256)
datamod.setup()

model = BinarySegmentationForIdd(model_name= 'unet',
                                 encoder_name = 'mobilenet_v2',
                                 encoder_weights ='imagenet',
                                 lr_e = 0.015,
                                 lr_d = 0.001)

path = '/kaggle/input/unet-mobilenetv2/pytorch/91miou/1/unet_mobilenet_v2_accuracy0.9138.pth'
checkpoint = torch.load(path)

state_dict = checkpoint['state_dict']
model.load_state_dict(checkpoint, strict = False)


model = model.to('cuda')                 
torchinfo.summary(model, input_size=(1,3, 256, 256))

Layer (type:depth-idx)                                            Output Shape              Param #
BinarySegmentationForIdd                                          [1, 1, 256, 256]          --
├─Unet: 1-1                                                       [1, 1, 256, 256]          --
│    └─MobileNetV2Encoder: 2-1                                    [1, 3, 256, 256]          --
│    │    └─Sequential: 3-1                                       --                        2,223,872
│    └─UnetDecoder: 2-2                                           [1, 16, 256, 256]         --
│    │    └─Identity: 3-2                                         [1, 1280, 8, 8]           --
│    │    └─ModuleList: 3-3                                       --                        4,671,553
│    └─SegmentationHead: 2-3                                      [1, 1, 256, 256]          --
│    │    └─Conv2d: 3-4                                           [1, 1, 256, 256]          145
│    │    └─Identity: 3-5     

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [59]:
device

device(type='cuda')

In [60]:
img, msk = next(iter(datamod.test_dataloader()))

In [61]:
print(device)

cuda


In [62]:
print(img.shape)
print(img.dtype)
img = img.to(device)
print(img.dtype)
print(model.dtype)
model.to(device);

torch.Size([1, 3, 256, 256])
torch.float32
torch.float32
torch.float32


In [63]:
%%timeit
pred_mask = model(img)

14.5 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [64]:
%%timeit
pred_mask = model(img)

14.8 ms ± 214 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [65]:
def process_images(model, image, h, w):
  
    image = cv2.resize(image,(224,224))
    
    #print(image.shape)
    image_tensor = torch.tensor(image, dtype=torch.float32)
    image_tensor = image_tensor / 255.0  # Normalize to [0, 1]
    image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0).to(device) 
    model.eval()
    with torch.inference_mode():
        pred_mask = model(image_tensor)
        
    pred_mask = pred_mask.cpu().numpy()
    mask = pred_mask.squeeze()>0.5
    zero_image = np.zeros_like(mask)
    mask = np.stack((mask, mask, mask), axis=-1)*255
    mask = np.asarray(mask, np.uint8)
#     print("Image shape:", image.shape)
#     print("Mask shape:", mask.shape)
    
    
    final_image = cv2.addWeighted(image, 1.0,mask,0.5,0.0)
    final_image = cv2.resize(final_image, (w, h))
    return final_image

    
    

In [67]:
path = '/kaggle/input/inf-vid-2/A_ one_ minute_tour_of_RIT_2k17_(www.KeepVid.to)_BIG.mp4'
vid_object = cv2.VideoCapture(path)
frame_width = int(vid_object.get(3))
frame_height = int(vid_object.get(4))

fourcc = cv2.VideoWriter_fourcc('m','p','4','v')
fps =vid_object.get(cv2.CAP_PROP_FPS)
print(fps)
output = cv2.VideoWriter(
          '/kaggle/working/unetp_inf_p100.mp4',
          fourcc,
          fps,
          (frame_width,frame_height)
        )

while(vid_object.isOpened()):
    
    ret, frame = vid_object.read()
    if ret == True:
#         print('working')
        tqdm(output.write(process_images(model,frame,frame_height, frame_width)))
    else:
        break
vid_object.release()
output.release()


29.85999500845624


0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, 