In [1]:
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as tt
import torchvision
from abc import ABC, abstractmethod
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2 as cv
import glob
from typing import Union

In [2]:
img_path = "/kaggle/input/segmentation-full-body-mads-dataset/segmentation_full_body_mads_dataset_1192_img/segmentation_full_body_mads_dataset_1192_img/images"
mask_path = "/kaggle/input/segmentation-full-body-mads-dataset/segmentation_full_body_mads_dataset_1192_img/segmentation_full_body_mads_dataset_1192_img/masks"
img_files = sorted(glob.glob(img_path + "/*"))
mask_files = sorted(glob.glob(mask_path + "/*"))

In [3]:
def plot_example(index):
    img_input = np.array(Image.open(img_files[index]))
    mask = np.array(Image.open(mask_files[index]))
    print(mask.shape)
    fig, ax = plt.subplots(ncols=2, figsize=(12,8))
    ax[0].imshow(img_input)
    ax[1].imshow(mask)
    ax[0].axis("off")
    ax[1].axis("off")
    plt.show()

In [4]:
mean = [0.485, 0.456, 0.406]
std =[0.229, 0.224, 0.225]

class Humans(Dataset):
    def __init__(self, img_paths, label_paths, H=1500, transforms=None):
        self.img_paths = img_paths
        self.label_paths = label_paths
        self.transforms = transforms
        self.H = H
    
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, index):
        img_input = (Image.open(self.img_paths[index]))
        mask = (Image.open(self.label_paths[index]))
        
        img_transforms = [
            tt.ToTensor(),
            tt.Resize((self.H, self.H), antialias=True),
            tt.Normalize(mean,std)
        ]
        
        if self.transforms is not None:
            img_transforms = img_transforms + self.transforms
        
        img_transforms = tt.Compose(img_transforms)
        transforms_mask = tt.Compose([
            tt.Grayscale(),
            tt.Resize((self.H, self.H), antialias=True),
            tt.ToTensor()
        ])
        mask_tensor = (transforms_mask(mask) > 0.05).float()
        img_tensor = img_transforms(img_input)
        
        return img_tensor, mask_tensor

In [5]:
class SegLoss(nn.Module):
    def __init__(self):
        super().__init__()
    
    def dice_loss(self, pred: torch.Tensor, gt: torch.tensor):
        intersection = torch.sum((pred * gt))
        epsilon = 1e-8
        union = (pred + gt).sum() + epsilon
        return 1 - ((2 * intersection) / union)
    
    def forward(self, gt, pred):
        
        dice_loss = 0
        for cl in range(pred.shape[1]):
            dice_loss += self.dice_loss(pred[:,cl,:,:], gt.eq(cl).float())
        
        dice_loss /= pred.shape[1]        
        return dice_loss

In [6]:
def fit(epochs: int, optimizer: torch.optim, model: nn.Module, learning_rate: float,
       train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, 
        learning_rate_scheduler: torch.optim.lr_scheduler,
        model_name="best_model.pt", **kwargs):
    
    optimizer = optimizer(model.parameters(), lr=learning_rate)   
    lrs = learning_rate_scheduler(optimizer, **kwargs)
    history = []
    min_val_loss = 10e20
    old_lr = learning_rate
    for epoch in range(epochs): 
        #model.select_loss(epoch)
        for num, batch in enumerate(train_loader):
            train_losses = []
            optimizer.zero_grad()
            loss = model.training_step(batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_losses.append(loss.detach())
            print(f"Batch num [{num}]: loss {loss})")
        
        
        result = model.evaluate(val_loader)
        result["train_loss"] = torch.stack(train_losses).detach().mean().item()
        model.epoch_end_val(epoch, result)
        history.append(result)
        lrs.step(result["val_loss"])
        
        if optimizer.param_groups[0]["lr"] != old_lr:
            old_lr = optimizer.param_groups[0]["lr"]
            print(f"Updated learning rate to {old_lr:.4f}")
        
        if result["val_loss"] < min_val_loss:
            torch.save(model, model_name)
            min_val_loss = result["val_loss"]

    return history       


In [7]:
class ModelBase(ABC):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def forward(self):
        pass

    def training_step(self, batch):
        """
        Training step which computes the loss and accuracy of a train batch
        :param batch: batch of pytorch dataloader
        :type batch: torch.utils.data.DataLoader
        :return: loss, accuracy and f1_score of batch
        :rtype: tuple[torch.tensor,...]
        """
        # Runs the forward pass with autocasting.
        #with torch.cuda.amp.autocast():
        self.train()
        images, target_mask = batch
        prediction_mask = self(images)
        train_loss = self.loss(target_mask, prediction_mask)
        return train_loss

    def validation_step(self, batch):
        self.eval()
        with torch.no_grad():
            images, target_mask = batch
            prediction_mask = self(images)
            val_loss = self.loss(target_mask, prediction_mask)
            return {"val_loss": val_loss}

    def validation_epoch_end(self, outputs):
        """
        Returns the epoch losses after computing the mean loss and accuracy of the test batches

        :param outputs: List of test step outputs
        :type outputs: list
        :return: epoch loss and epoch accuracy
        :rtype: dict
        """
        batch_losses = [x["val_loss"] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean().item()

        return {"val_loss": epoch_loss}

    def evaluate(self, dl):
        outputs = [self.validation_step(batch) for batch in dl]
        return self.validation_epoch_end(outputs)

    def epoch_end_val(self, epoch, results):
        """
        Prints validation epoch summary after every epoch

        :param epoch: epoch number
        :type epoch: int
        :param results: results from the evaluate method
        :type results: dictionary
        :return: None
        """

        print(
            f"Epoch:[{epoch}]: |validation loss: {results['val_loss']}|"
        )

In [8]:
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP) layer as described in DeepLab architectures.
    This layer uses multiple dilated convolutions to capture multi-scale information.
    
    :param in_ch: Number of input channels.
    
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        
        
        # Convolution with dilation rate of 1 (normal convolution)
        self.aconv1 = nn.Conv2d(in_ch, 256, 3, dilation=1, padding="same")
        
        # Convolution with dilation rate of 6
        self.aconv2 = nn.Conv2d(in_ch, 256, 3, dilation=6, padding="same")
        
        # Convolution with dilation rate of 12
        self.aconv3 = nn.Conv2d(in_ch, 256, 3, dilation=12, padding="same")
        
        # Convolution with dilation rate of 18
        self.aconv4 = nn.Conv2d(in_ch, 256, 3, dilation=18, padding="same")
        
        # Convolution with dilation rate of 24
        self.aconv5 = nn.Conv2d(in_ch, 256, 3, dilation=24, padding="same")
        
        # Batch normalization for concatenated feature maps
        self.bn = nn.BatchNorm2d(256 * 5)
        
        # ReLU activation
        self.relu = nn.ReLU(inplace=True)
        
        # Prediction convolution
        self.pred_conv = nn.Conv2d(256 * 5 ,out_ch, 1, padding="same")

    def forward(self, x):
        """
        Forward pass through the ASPP layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after ASPP.
        """
        
        # Pass the input through each dilated convolution and apply ReLU
        out1 = self.relu(self.aconv1(x))
        out2 = self.relu(self.aconv2(x))
        out3 = self.relu(self.aconv3(x))
        out4 = self.relu(self.aconv4(x))
        out5 = self.relu(self.aconv5(x))
        
        # Concatenate the outputs along channel dimension
        cat = torch.cat((out1, out2, out3, out4, out5), dim=1)
        
        # Apply batch normalization
        out = self.bn(cat)
        
        # Apply the prediction convolution
        pred = self.pred_conv(out)
        
        return pred


In [9]:
class DeepLabv3(nn.Module, ModelBase):
    def __init__(self, loss, final_softmax=True):
        super().__init__()
        backbone = torchvision.models.resnet101(
            weights=torchvision.models.ResNet101_Weights.DEFAULT
        )
        self.backbone = nn.Sequential(*list(backbone.children())[:-3])
        self.aspp = ASPP(1024, 2)
        self.up = nn.Upsample(scale_factor = 16, mode="bilinear")
        
        if final_softmax:
            self.sm = nn.Softmax2d()
        else:
            self.sm = nn.Identity()
            
        self.loss = loss
        
    def forward(self, x):
        x = self.backbone(x)
        out = self.aspp(x)
        out = self.sm(out)
        up = self.up(out)
        
        return up
    
    def select_loss(self, epoch):
        if epoch < 10:
            self.loss = SegLoss(0)
        
        else:
            self.loss = SegLoss(0)

In [10]:
class DeepLabv3Plus(nn.Module, ModelBase):
    def __init__(self, loss, final_softmax=True):
        super().__init__()
        backbone = torchvision.models.resnet101(
            weights=torchvision.models.ResNet101_Weights.DEFAULT
        )
        self.backbone = nn.Sequential(*list(backbone.children())[:-3])
        self.aspp = ASPP(1024, 256)
        self.onebyone_conv = nn.Conv2d(1024,256,1, padding="same")
        self.low_level_upsampler = nn.Upsample(scale_factor=4, mode="bilinear")
        self.up_encoder = nn.Upsample(scale_factor=4, mode="bilinear")
        self.up_backbone = nn.Upsample(scale_factor=4, mode="bilinear")
        self.decoder_conv = nn.Conv2d(512, 2, 3, padding="same")
        self.up_final = nn.Upsample(scale_factor=4, mode="bilinear")
        
        if final_softmax:
            self.sm = nn.Softmax2d()
        else:
            self.sm = nn.Identity()
            
        self.loss = loss 
        
    def forward(self, x):
        x = self.backbone(x)
        low_level_features = self.onebyone_conv(x)
        low_level_features_up = self.low_level_upsampler(low_level_features)
        encoder_out = self.aspp(x)
        encoder_out_up = self.up_encoder(encoder_out)
        

        combined = torch.cat((low_level_features_up, encoder_out_up), dim=1)
        out = self.decoder_conv(combined)
        out = self.sm(out)
        up = self.up_final(out)
        
        return up

In [11]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ELU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear")
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        if x2 is not None:
            x = torch.cat([x2, x1], dim=1)
            # input is CHW
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]
            x = F.pad(x, (diffX // 2, diffX - diffX // 2), (diffY // 2, diffY - diffY // 2))

        else:
            x = x1

        x = self.conv(x)
        return(x)
    
    
class UnetResNet32(nn.Module, ModelBase):
    def __init__(self, loss, final_softmax: bool = True, 
                 num_classes=2,train_backbone: bool = True):
        """
        UnetResNet32 class.

        :param final_resolution: The final resolution of the heatmap
        :param num_classes: The number of output classes. Defaults to 1.
        :param model_name: The name of the backbone model. Defaults to "resnet18".
        """
        nn.Module.__init__(self)

        basemodel = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
        self.basemodel = nn.Sequential(*list(basemodel.children())[:-2])

        for param in self.basemodel.parameters():
            param.requires_grad = train_backbone

        self._NUM_CLASSES = num_classes
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Instantiate hooks
        # Keep the outputs of the intermediate layers
        self.outputs = {}

        for i, layer in enumerate(list(self.basemodel.children())):
            layer.register_forward_hook(self.save_output)

        self.maxpool = nn.MaxPool2d(2)

        # Create decoding part
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up5 = nn.ConvTranspose2d(64, 32, 2, stride=2)

        self.up_conv4 = DoubleConv(256 + 256, 256)
        self.up_conv3 = DoubleConv(2 + 128, 128)
        self.up_conv3 = DoubleConv(128 + 128, 128)
        self.up_conv2 = DoubleConv(64 + 64, 64)
        self.up_conv1 = DoubleConv(32 + 64, 64)

        self.outc3 = nn.Conv2d(128, self._NUM_CLASSES, kernel_size=1)
        self.outc2 = nn.Conv2d(64, self._NUM_CLASSES, kernel_size=1)
        self.outc1 = nn.Conv2d(64, self._NUM_CLASSES, kernel_size=1)
        
        if final_softmax:
            self.sm = nn.Softmax2d()
        else:
            self.sm = nn.Identity()
            
        self.loss = loss 

    def save_output(self, module, input, output):
        self.outputs[module] = output

    def forward(self, x):
        self.basemodel(x)

        x = self.up2(self.outputs[list(self.basemodel.children())[-1]])
        x = torch.cat((x, self.outputs[list(self.basemodel.children())[-2]]), dim=1)
        x = self.up_conv4(x)

        x = self.up3(x)
        x = torch.cat((x, self.outputs[list(self.basemodel.children())[-3]]), dim=1)
        x = self.up_conv3(x)
        output_first_map = nn.Upsample(scale_factor=8, mode="bilinear")(self.outc3(x))

        x = self.up4(x)
        x = torch.cat((x, self.outputs[list(self.basemodel.children())[-4]]), dim=1)
        x = self.up_conv2(x)
        intermediate_output_second = nn.Upsample(scale_factor=4, mode="bilinear")(self.outc2(x))
        output_second_map = torch.add(output_first_map, intermediate_output_second)

        x = self.up5(x)
        x = torch.cat((x, self.outputs[list(self.basemodel.children())[-6]]), dim=1)
        x = self.up_conv1(x)
        intermediate_output_third = nn.Upsample(scale_factor=2, mode="bilinear")(self.outc1(x))
        final_output = torch.add(output_second_map, intermediate_output_third)
        final_output = self.sm(final_output)

        return final_output

In [12]:
def to_device(data,device):
    if isinstance(data, (list,tuple)):
        return [x.to(device) for x in data]
    return data.to(device)

class DeviceDataLoader:
    def __init__(self, dl): 
        self.dl = dl
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    
    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)
    
    def __len__(self):
        return len(self.dl)

In [15]:
transforms = [
    tt.RandomAdjustSharpness(0.3),
    tt.RandomAutocontrast(),
    tt.RandomGrayscale(),
    tt.RandomApply([
        tt.ColorJitter()
    ])]

ds = Humans(img_files, mask_files, 512, transforms)
ds_train, ds_val = random_split(ds, (0.88,0.12))
train_loader = DeviceDataLoader(DataLoader(ds_train, num_workers=2, batch_size=8, shuffle=True))
val_loader = DeviceDataLoader(DataLoader(ds_val, num_workers=2, batch_size=8, shuffle=False))

ValueError: num_samples should be a positive integer value, but got num_samples=0