# Task 1: Set up the Colab GPU runtime environment

In [None]:
!pip install segmentation-models-pytorch
!pip install -U git+https://github.com/albumentations-team/albumentations
!pip install --upgrade opencv-contrib-python

# Download Dataset

Original author of the dataset:

https://github.com/VikramShenoy97/Human-Segmentation-Dataset

In [None]:
!git clone https://github.com/parth1620/Human-Segmentation-Dataset-master.git

# Import Libraries 

In [2]:
import sys
# Make the cloned repo importable (for `helper`); the membership check keeps
# re-runs of this cell from piling duplicate entries onto sys.path.
DATASET_REPO_DIR = '/content/Human-Segmentation-Dataset-master'
if DATASET_REPO_DIR not in sys.path:
    sys.path.append(DATASET_REPO_DIR)

In [3]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import helper

# Set Up Configuration

In [None]:
CSV_FILE = '/content/Human-Segmentation-Dataset-master/train.csv' # Contains the images & corresponding masks
DATA_DIR = '/content/'

# Use the GPU when the runtime has one; otherwise fall back to the CPU instead
# of failing at the first `.to('cuda')`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

EPOCHS = 25
LR = 0.003
IMAGE_SIZE = 320
BATCH_SIZE = 16

ENCODER = 'timm-efficientnet-b0'
WEIGHTS = 'imagenet'

# Seed the Python, NumPy and torch RNGs (DataLoader shuffling, decoder weight
# init, and any augmentation randomness drawn from them) so reruns are more
# repeatable. GPU kernels may still add small nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
df = pd.read_csv(CSV_FILE)

# Later cells read each row's `images` and `masks` paths; fail fast here if the
# CSV does not provide them.
missing_columns = {'images', 'masks'} - set(df.columns)
assert not missing_columns, f'{CSV_FILE} lacks columns {sorted(missing_columns)}'

df.head()

# Inspect data

In [None]:
row = df.iloc[10]  # an arbitrary sample to inspect
image_path, mask_path = row['images'], row['masks']

In [None]:
# cv2.imread returns None (rather than raising) for a missing or unreadable file.
image = cv2.imread(image_path)
if image is None:
  raise FileNotFoundError(f'cv2.imread could not read {image_path!r}')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
  raise FileNotFoundError(f'cv2.imread could not read {mask_path!r}')
mask = mask / 255.0  # scale 0..255 gray levels to 0..1

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.set_title('IMAGE')
ax1.imshow(image)

ax2.set_title('GROUND TRUTH')
ax2.imshow(mask, cmap='gray')

# Render explicitly so the cell does not also echo the AxesImage repr.
plt.show()

**Split the data**

In [None]:
train_df, valid_df = train_test_split(
    df,
    test_size=0.2,    # hold out 20% of the rows for validation
    random_state=42,  # fixed seed -> the same split on every run
)

# Data Augmentation Functions

For instance-segmentation and semantic-segmentation tasks, you need to augment both the input image and one or more output masks.

Albumentations ensures that the input image and the output mask will receive the same set of augmentations with the same parameters.



Albumentations documentation: https://albumentations.ai/docs/

In [None]:
import albumentations as A

In [None]:
def get_train_augs():
  """Build the training-time augmentation pipeline.

  Resizes to IMAGE_SIZE x IMAGE_SIZE, then applies a horizontal flip and a
  vertical flip, each with probability 0.5. Albumentations applies the same
  sampled transforms to the image and its mask.

  ``is_check_shapes=False`` disables albumentations' image/mask shape-match
  check (presumably because some masks differ in size from their images; both
  are resized to the same size first).
  """
  transforms = [
      A.Resize(IMAGE_SIZE, IMAGE_SIZE),
      A.HorizontalFlip(p=0.5),
      A.VerticalFlip(p=0.5),
  ]
  return A.Compose(transforms, is_check_shapes=False)

def get_valid_augs():
  """Build the validation pipeline: a deterministic resize only, no random flips.

  Image and mask are both resized to IMAGE_SIZE x IMAGE_SIZE. The shape-match
  check is disabled, as in the training pipeline.
  """
  resize_only = [A.Resize(IMAGE_SIZE, IMAGE_SIZE)]
  return A.Compose(resize_only, is_check_shapes=False)

# **Create Custom Dataset**

We need a PyTorch `Dataset` that loads each image together with its mask.

In [None]:
from torch.utils.data import Dataset

In [None]:
class SegmentationDataset(Dataset):
  """Map-style dataset that yields (image, mask) tensor pairs.

  Args:
    df: DataFrame with one row per sample; its ``images`` and ``masks``
      columns hold file paths passed to ``cv2.imread``.
    augmentations: callable invoked as ``augmentations(image=..., mask=...)``
      that returns a dict with ``'image'`` and ``'mask'`` keys (e.g. an
      ``A.Compose``), or ``None`` to skip augmentation.
  """

  def __init__(self, df, augmentations):
    self.df = df
    self.augmentations = augmentations

  def __len__(self):
    return len(self.df)

  @staticmethod
  def _read(path, *flags):
    """``cv2.imread`` that raises instead of silently returning ``None``."""
    array = cv2.imread(path, *flags)
    if array is None:
      raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    return array

  def __getitem__(self, index):
    """Return the image & mask pair stored at row ``index`` of ``self.df``.

    Returns:
      image: float32 tensor (3, H, W), scaled from 0..255 to 0..1.
      mask: float32 tensor (1, H, W), scaled to 0..1 and rounded to {0, 1}
        (assuming the augmentations keep the mask's trailing channel axis).

    Raises:
      FileNotFoundError: if the image or mask file cannot be read.
    """
    row = self.df.iloc[index]

    image = self._read(row.images)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # (H, W, 3) in RGB order

    mask = self._read(row.masks, cv2.IMREAD_GRAYSCALE)  # (H, W)
    mask = np.expand_dims(mask, axis=-1)  # (H, W, 1), laid out like the image

    # `is not None` rather than truthiness: A.Compose defines __len__, so the
    # truth value of a pipeline depends on how many transforms it holds.
    if self.augmentations is not None:
      augmented = self.augmentations(image=image, mask=mask)
      image = augmented['image']
      mask = augmented['mask']

    # PyTorch expects channels-first (C, H, W); the arrays here are (H, W, C).
    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)), dtype=np.float32)
    mask = np.ascontiguousarray(np.transpose(mask, (2, 0, 1)), dtype=np.float32)

    image = torch.from_numpy(image) / 255.0
    mask = torch.round(torch.from_numpy(mask) / 255.0)

    return image, mask

In [None]:
# Random flips for training; a deterministic resize only for validation.
trainset = SegmentationDataset(df=train_df, augmentations=get_train_augs())
validset = SegmentationDataset(df=valid_df, augmentations=get_valid_augs())

In [None]:
for split_name, split in (('Trainset', trainset), ('Validset', validset)):
  print(f"Size of {split_name} : {len(split)}")

In [None]:
# Utility function to display an image, its ground-truth mask and, optionally,
# a predicted mask side by side.
import matplotlib.pyplot as plt 
import numpy as np 
import torch

def show_image(image, mask, pred_image=None):
    """Plot ``image`` next to ``mask`` and, when given, ``pred_image``.

    Args:
        image: tensor laid out as (C, H, W); permuted to (H, W, C) for display.
        mask: ground-truth mask tensor laid out as (C, H, W).
        pred_image: optional predicted mask tensor laid out as (C, H, W); when
            provided, a third "MODEL OUTPUT" panel is drawn.

    Pass CPU tensors (callers in this notebook apply ``.cpu()`` to predictions).
    """
    panels = [('IMAGE', image), ('GROUND TRUTH', mask)]
    # `is not None`, not `!= None`: comparison operators on arrays/tensors are
    # element-wise (or fall back to identity), so only identity is reliable.
    if pred_image is not None:
        panels.append(('MODEL OUTPUT', pred_image))

    fig, axes = plt.subplots(1, len(panels), figsize=(10, 5))
    for ax, (title, tensor) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(tensor.permute(1, 2, 0).squeeze(), cmap='gray')
        
        


In [None]:
# Visual sanity check of one (augmented) training pair.
idx = 55
image, mask = trainset[idx]
show_image(image, mask, pred_image=None)  # alternative: helper.show_image(image, mask)

# Load dataset into batches

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Shuffle only the training batches; validation order stays fixed.
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True)
validloader = DataLoader(dataset=validset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Batches per epoch; the last batch may be partial since drop_last is not set.
print(f'Number of batches in trainloader : {len(trainloader)}')
print(f'Number of batches in validloader : {len(validloader)}')

In [None]:
# Grab a single batch to inspect its shapes. next(iter(...)) raises
# StopIteration on an empty loader, whereas `for ...: break` would silently
# leave `image`/`mask` bound to the single sample from the earlier cell.
image, mask = next(iter(trainloader))

In [None]:
for label, batch in (('Image', image), ('Mask', mask)):
  print(f'{label} batch shape {batch.shape}')

# Create Segmentation Model

`segmentation_models_pytorch` documentation: https://smp.readthedocs.io/en/latest/

In [None]:
from torch import nn
import segmentation_models_pytorch as smp 
from segmentation_models_pytorch.losses import DiceLoss

In [None]:
class SegmenetationModel(nn.Module):
  """U-Net (smp) with one output channel, trained with Dice + BCE loss.

  The class name keeps its original spelling because later cells use it.
  """

  def __init__(self):
    super().__init__()
    self.arc = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=WEIGHTS,
        in_channels=3,
        classes=1,
        activation=None,  # emit raw logits; both losses below expect logits
    )
    # Built once here instead of on every forward() call. Neither loss has
    # parameters or persistent buffers, so the state_dict keys are unchanged
    # and previously saved checkpoints still load.
    self.dice_loss = DiceLoss(mode='binary')
    self.bce_loss = nn.BCEWithLogitsLoss()

  def forward(self, images, masks=None):
    """Run the network; also compute the training loss when masks are given.

    Args:
      images: batch of 3-channel images, e.g. (N, 3, H, W).
      masks: optional binary targets shaped like the logits, e.g. (N, 1, H, W).

    Returns:
      ``logits`` when ``masks`` is None, else ``(logits, dice_loss + bce_loss)``.
    """
    logits = self.arc(images)

    # `is None`, not `!= None`: comparing a tensor to None is not a reliable
    # emptiness test.
    if masks is None:
      return logits

    loss = self.dice_loss(logits, masks) + self.bce_loss(logits, masks)
    return logits, loss

In [None]:
# nn.Module.to moves the parameters and returns the module itself.
model = SegmenetationModel().to(DEVICE)

## Create Train and Validation Functions

In [None]:
def train_fn(dataloader, model, optimizer):
  """Run one training epoch and return the mean per-batch loss.

  Each ``(images, masks)`` batch is moved to ``DEVICE`` and fed to
  ``model(images, masks)``, which must return ``(logits, loss)``; the loss is
  back-propagated and the optimizer takes one step per batch.
  """
  model.train()
  running_loss = 0

  for batch_images, batch_masks in tqdm(dataloader):
    batch_images = batch_images.to(DEVICE)
    batch_masks = batch_masks.to(DEVICE)

    optimizer.zero_grad()
    _, batch_loss = model(batch_images, batch_masks)
    batch_loss.backward()
    optimizer.step()

    running_loss += batch_loss.item()

  return running_loss / len(dataloader)

In [None]:
def eval_fn(dataloader, model, optimizer=None):
  """Return the mean per-batch loss of ``model`` over ``dataloader``.

  Runs with ``model.eval()`` and autograd disabled, so no weights change.

  Args:
    dataloader: sized iterable of ``(images, masks)`` batches.
    model: module whose ``forward(images, masks)`` returns ``(logits, loss)``.
    optimizer: unused. It is now optional (it used to be required), and
      existing calls that pass the training optimizer still work.

  Returns:
    Mean of ``loss.item()`` across batches.
  """
  model.eval()  # inference behaviour for any dropout / batch-norm layers
  total_loss = 0.0

  with torch.no_grad():
    for images, masks in tqdm(dataloader):
      _, loss = model(images.to(DEVICE), masks.to(DEVICE))
      total_loss += loss.item()

  return total_loss / len(dataloader)

## Train Model

In [None]:
# Adam over every model parameter at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

In [None]:
# Train for EPOCHS epochs, checkpointing whenever the validation loss improves.
# np.inf (lower-case): the np.Inf alias was removed in NumPy 2.0.
best_valid_loss = np.inf
# Absolute path under DATA_DIR, so the checkpoint does not depend on the
# kernel's working directory and matches the path the inference cell loads.
best_model_path = os.path.join(DATA_DIR, 'best_model.pt')

for epoch in range(1, EPOCHS + 1):
  train_loss = train_fn(trainloader, model, optimizer)
  valid_loss = eval_fn(validloader, model, optimizer)

  if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    torch.save(model.state_dict(), best_model_path)
    print('MODEL SAVED!!')
  print(f'Epoch: {epoch} Train Loss: {train_loss} Valid Loss: {valid_loss}')

# Inference

In [None]:
idx = 5
# Reload the best checkpoint. map_location allows loading even when DEVICE
# differs from the device the weights were saved on; weights_only=True limits
# unpickling to tensors and plain containers.
model.load_state_dict(torch.load(os.path.join(DATA_DIR, 'best_model.pt'),
                                 map_location=DEVICE, weights_only=True))
model.eval()  # inference behaviour for any batch-norm / dropout layers

image, mask = validset[idx]
with torch.no_grad():  # no autograd graph is needed for a prediction
  logits_mask = model(image.to(DEVICE).unsqueeze(0))  # (C, H, W) -> (1, C, H, W): add the batch dimension
pred_mask = torch.sigmoid(logits_mask)
pred_mask = (pred_mask > 0.5) * 1.0  # threshold probabilities into a {0, 1} float mask

In [None]:
# pred_mask is (1, 1, H, W) on DEVICE: drop the batch dim and move it to the CPU for plotting.
show_image(image, mask, pred_image=pred_mask.detach().cpu().squeeze(0))