In [6]:
import torch
import numpy as np
import pandas as pd
import os
import time

from PIL import Image
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam, AdamW, SGD

In [3]:
!git clone https://github.com/VikramShenoy97/Human-Segmentation-Dataset.git

Cloning into 'Human-Segmentation-Dataset'...
remote: Enumerating objects: 596, done.[K
remote: Total 596 (delta 0), reused 0 (delta 0), pack-reused 596 (from 1)[K
Receiving objects: 100% (596/596), 13.60 MiB | 35.44 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [13]:
class SegmentationDataset(Dataset):
    """Pairs RGB images with single-channel binary masks for segmentation.

    Every file in ``image_dir`` whose extension is ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) is treated as an image; its mask is looked up
    in ``mask_dir`` under the same base name with a ``.png`` extension.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        image_size: Side length both image and mask are resized to
            (default 512, the value previously hard-coded).

    Each item is ``(image, mask)``: ``image`` is the ToTensor output of the
    resized RGB image (shape (3, image_size, image_size), values in [0, 1])
    and ``mask`` is a float tensor of shape (1, image_size, image_size)
    containing only 0.0 / 1.0.
    """

    VALID_EXTENSIONS = {".jpg", ".jpeg", ".png"}

    def __init__(self, image_dir, mask_dir, image_size=512):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps mask pixels at their original
        # values instead of blending foreground/background along edges.
        self.mask_transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor(),
        ])
        # Sorted so sample indices are stable across runs (os.listdir order
        # is arbitrary).
        self.images = sorted(
            f for f in os.listdir(image_dir)
            if os.path.splitext(f)[1].lower() in self.VALID_EXTENSIONS
        )

    def __len__(self):
        return len(self.images)

    def _paths(self, idx):
        """Return ``(image_path, mask_path)`` for the sample at ``idx``."""
        filename = self.images[idx]
        name, _ = os.path.splitext(filename)
        image_path = os.path.join(self.image_dir, filename)
        mask_path = os.path.join(self.mask_dir, f"{name}.png")
        return image_path, mask_path

    def __getitem__(self, idx):
        image_path, mask_path = self._paths(idx)
        # `with` closes the file handles PIL otherwise keeps open lazily.
        with Image.open(image_path) as img:
            image = self.transform(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = self.mask_transform(mask_img.convert("L"))
        # ToTensor scales 8-bit pixels to [0, 1]; threshold to a hard 0/1
        # mask. NOTE(review): assumes masks are stored as 0/255 — a mask
        # saved as 0/1 would become all zeros here; confirm for the dataset.
        mask = (mask > 0.5).float()
        return image, mask


In [14]:
## Data loader 

def get_dataloader(image_dir, maks_dir, batch_size=2, shuffle=True):
    """Wrap a SegmentationDataset built from the two directories in a DataLoader.

    Args:
        image_dir: Directory of input images.
        maks_dir: Directory of masks (name kept as-is, typo included, so
            existing keyword callers keep working).
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader shuffles the samples.

    Returns:
        A ``DataLoader`` over ``SegmentationDataset(image_dir, maks_dir)``.
    """
    return DataLoader(
        SegmentationDataset(image_dir, maks_dir),
        batch_size=batch_size,
        shuffle=shuffle,
    )


In [17]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU (the basic U-Net unit).

    ``padding=1`` keeps the spatial size unchanged, so an input of shape
    (N, in_channels, H, W) produces (N, out_channels, H, W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # The second conv consumes the first conv's output, so its input
            # width is out_channels (in_channels breaks whenever they differ).
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv_op(x)
        

In [18]:
## DownSample: a DoubleConv (two 3x3 convolutions) followed by one 2x2 max pooling

class DownSample(nn.Module):
    """U-Net encoder stage: DoubleConv followed by 2x2 / stride-2 max pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the DoubleConv
    output, kept as the skip connection for the decoder. ``pooled`` is the
    max-pooled tensor fed to the next encoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)

In [19]:
class UpSample(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv.

    Args:
        in_channels: Channels of the incoming (coarser) feature map. The
            transposed conv reduces it to ``in_channels // 2``, and the skip
            tensor is expected to supply the remaining channels so the
            concatenation has ``in_channels`` channels for the DoubleConv.
        out_channels: Channels produced by the final DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registered on self: forward() calls self.up.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip tensor ``x2`` along channels."""
        x1 = self.up(x1)
        # If an encoder input had an odd height/width, max pooling floored it,
        # so the upsampled map can be smaller than the skip tensor. Pad x1
        # (split across both sides) so torch.cat sees matching spatial sizes.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)
        

In [21]:
### Unet

class Unet(nn.Module):
    """U-Net with four encoder/decoder stages (64-128-256-512, 1024 bottleneck).

    Args:
        in_channels: Channels of the input image (e.g. 3 for RGB).
        num_classes: Channels of the output map.

    ``forward`` returns the raw output of the final 1x1 convolution (no
    activation is applied), i.e. logits.
    """

    def __init__(self, in_channels, num_classes):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()
        self.down_conv1 = DownSample(in_channels, 64)
        self.down_conv2 = DownSample(64, 128)
        self.down_conv3 = DownSample(128, 256)
        self.down_conv4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        # Decoder: each stage upsamples and concatenates the matching encoder
        # skip connection.
        self.up_conv1 = UpSample(1024, 512)
        self.up_conv2 = UpSample(512, 256)
        self.up_conv3 = UpSample(256, 128)
        self.up_conv4 = UpSample(128, 64)

        # 1x1 conv maps the 64 decoder channels to num_classes output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_conv1(x)
        down_2, p2 = self.down_conv2(p1)
        down_3, p3 = self.down_conv3(p2)
        down_4, p4 = self.down_conv4(p3)

        b = self.bottle_neck(p4)

        up1 = self.up_conv1(b, down_4)
        up2 = self.up_conv2(up1, down_3)
        up3 = self.up_conv3(up2, down_2)
        up4 = self.up_conv4(up3, down_1)

        return self.out(up4)
        

        

In [22]:
## Loss functions

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2*sum(x*y) + smooth) / (sum(x) + sum(y) + smooth)``.

    ``inputs`` are used as-is (no sigmoid is applied here), so they should
    already be probabilities. All elements of the batch are flattened and
    pooled into a single score.

    Args:
        smooth: Added to numerator and denominator so the loss is defined
            (0.0) when both ``inputs`` and ``targets`` are all zeros.
    """

    def __init__(self, smooth=1e-16):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # reshape (unlike view) also works on non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice_score = (2 * intersection + self.smooth) / (
            inputs.sum() + targets.sum() + self.smooth
        )
        return 1 - dice_score



class BCEWithDiceLoss(nn.Module):
    """Combined segmentation loss ``0.5 * BCE + Dice``, computed from logits.

    ``inputs`` are raw logits. BCE is computed with ``nn.BCEWithLogitsLoss``,
    which applies the sigmoid internally. The Dice term is computed on
    ``torch.sigmoid(inputs)`` so it sees probabilities.

    Args:
        smooth: Smoothing constant forwarded to the Dice term.
    """

    def __init__(self, smooth=1e-6):
        super(BCEWithDiceLoss, self).__init__()
        # torch.nn has no BCEWithDiceLoss; the logits-based BCE is
        # BCEWithLogitsLoss.
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, inputs, targets):
        bce_loss = self.bce(inputs, targets)
        dice_loss = self.dice(torch.sigmoid(inputs), targets)
        return 0.5 * bce_loss + dice_loss
    

In [None]:
## Training loop

def train(model, dataloader, epochs =2, lr =0.001, save_path  ="unet_model", load_path = None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if load_path and os.path.exists(load_path):
        print(f"Loading model weights  form{load_path}")
        model.load_state_dict(torch.load(load_path, map_location =device))
    else : 
        print(f"no checkpoint found, training from scratch..")

    print(device)
    model.to(device)

    criterion = BCEWithDiceLoss()

    optimizer = SGD(model.parameters(), lr =lr)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for images, masks, in dataloader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()

            out = model(images)
             loss