In [1]:
!git clone https://github.com/VikramShenoy97/Human-Segmentation-Dataset

Cloning into 'Human-Segmentation-Dataset'...
remote: Enumerating objects: 596, done.[K
remote: Total 596 (delta 0), reused 0 (delta 0), pack-reused 596 (from 1)[K
Receiving objects: 100% (596/596), 13.60 MiB | 43.79 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [2]:
import os
import time
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam, AdamW, SGD

In [28]:
class SegmentationDataset(Dataset):
  def __init__(self, image_dir, mask_dir):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transforms.Compose([
        transforms.Resize((512,512)),
        transforms.ToTensor()
    ])

    valid_extension = {".jpg", ".jpeg", ".png"}
    self.images = [f for f in os.listdir(image_dir) if os.path.splitext(f)[1].lower() in valid_extension]

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img_path = os.path.join(self.image_dir, self.images[idx])
    name, ext = os.path.splitext(self.images[idx])
    mask_path = os.path.join(self.mask_dir, f"{name}.png")

    image = Image.open(img_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")

    image = self.transform(image)
    mask = self.transform(mask)

    mask = (mask > 0.5).float()

    return image, mask

In [29]:
def get_dataloader(image_dir, mask_dir, batch_size=2, shuffle=True):
  dataset = SegmentationDataset(image_dir, mask_dir)
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [40]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv_op = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )

  def forward(self, x):
    return self.conv_op(x)

In [41]:
class DownSample(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv = DoubleConv(in_channels, out_channels)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    down = self.conv(x)
    p = self.pool(down)

    return down, p

In [42]:
class UpSample(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
    self.conv = DoubleConv(in_channels, out_channels)

  def forward(self, x1, x2):
    x1 = self.up(x1)
    x = torch.cat([x2, x1], dim=1)
    return self.conv(x)

In [43]:
class UNet(nn.Module):
  def __init__(self, in_channels, num_classes):
    super().__init__()
    self.down_convolution_1 = DownSample(in_channels, 64)
    self.down_convolution_2 = DownSample(64, 128)
    self.down_convolution_3 = DownSample(128, 256)
    self.down_convolution_4 = DownSample(256, 512)

    self.bottle_neck = DoubleConv(512, 1024)

    self.up_convolution_1 = UpSample(1024, 512)
    self.up_convolution_2 = UpSample(512, 256)
    self.up_convolution_3 = UpSample(256, 128)
    self.up_convolution_4 = UpSample(128, 64)

    self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

  def forward(self, x):
    down_1, p1 = self.down_convolution_1(x)
    down_2, p2 = self.down_convolution_2(p1)
    down_3, p3 = self.down_convolution_3(p2)
    down_4, p4 = self.down_convolution_4(p3)

    b = self.bottle_neck(p4)

    up_1 = self.up_convolution_1(b, down_4)
    up_2 = self.up_convolution_2(up_1, down_3)
    up_3 = self.up_convolution_3(up_2, down_2)
    up_4 = self.up_convolution_4(up_3, down_1)

    out = self.out(up_4)

    return out

In [44]:
class DiceLoss(nn.Module):
  def __init__(self, smooth=1e-6):
    super(DiceLoss, self).__init__()
    self.smooth = smooth

  def forward(self, inputs, targets):
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    intersection = (inputs * targets).sum()
    dice_score = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

    return 1 - dice_score

In [45]:
class BCEWithDiceLoss(nn.Module):
  def __init__(self, smooth=1e-6):
    super(BCEWithDiceLoss, self).__init__()
    self.bce = nn.BCEWithLogitsLoss()
    self.dice = DiceLoss(smooth)

  def forward(self, inputs, targets):
    bce_loss = self.bce(inputs, targets)
    dice_loss = self.dice(inputs, targets)

    return bce_loss + dice_loss

In [46]:
#Training Loop

def train(model, dataloader, epochs=2, lr=0.001, save_path="unet_model", load_path=None):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  if load_path and os.path.exists(load_path):
    print(f"Loading model weights from {load_path}")
    model.load_state_dict(torch.load(load_path), map_location=device)
  else:
    print(f"No previous model found, training from scratch")

  print(device)
  model.to(device)

  criterion = BCEWithDiceLoss()
  optimizer = SGD(model.parameters(), lr=lr)

  for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    for images, masks in dataloader:
      images = images.to(device)
      masks = masks.to(device)
      optimizer.zero_grad()

      output = model(images)

      loss = criterion(output, masks)
      loss.backward()
      optimizer.step()

      epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    #save every 10th epoch
    if epoch % 10 == 0 and epoch > 0:
      torch.save(model.state_dict(), f"{save_path}.pth")

  #Save final
  torch.save(model.state_dict(), f"{save_path}_final.pth")
  print(f"Model Saved to {save_path}")

In [47]:
#Starting..
dataloader = get_dataloader("/content/Human-Segmentation-Dataset/Training_Images", "/content/Human-Segmentation-Dataset/Ground_Truth", batch_size=8, shuffle=True)

In [48]:
model = UNet(in_channels=3, num_classes=1)

In [50]:
train(model, dataloader, epochs=50, lr=0.001)

No previous model found, training from scratch
cuda
Epoch 1/50, Loss: 1.5347
Epoch 2/50, Loss: 1.5191
Epoch 3/50, Loss: 1.5144
Epoch 4/50, Loss: 1.5034
Epoch 5/50, Loss: 1.4972
Epoch 6/50, Loss: 1.4970
Epoch 7/50, Loss: 1.4876
Epoch 8/50, Loss: 1.4836
Epoch 9/50, Loss: 1.4904
Epoch 10/50, Loss: 1.4823
Epoch 11/50, Loss: 1.4810
Epoch 12/50, Loss: 1.4827
Epoch 13/50, Loss: 1.4798
Epoch 14/50, Loss: 1.4793
Epoch 15/50, Loss: 1.4771
Epoch 16/50, Loss: 1.4770
Epoch 17/50, Loss: 1.4773
Epoch 18/50, Loss: 1.4786
Epoch 19/50, Loss: 1.4775
Epoch 20/50, Loss: 1.4712
Epoch 21/50, Loss: 1.4703
Epoch 22/50, Loss: 1.4702
Epoch 23/50, Loss: 1.4712
Epoch 24/50, Loss: 1.4734
Epoch 25/50, Loss: 1.4664
Epoch 26/50, Loss: 1.4690
Epoch 27/50, Loss: 1.4749
Epoch 28/50, Loss: 1.4689
Epoch 29/50, Loss: 1.4685
Epoch 30/50, Loss: 1.4728
Epoch 31/50, Loss: 1.4774
Epoch 32/50, Loss: 1.4707
Epoch 33/50, Loss: 1.4662
Epoch 34/50, Loss: 1.4678
Epoch 35/50, Loss: 1.4743
Epoch 36/50, Loss: 1.4717
Epoch 37/50, Loss: 1.

In [53]:
#Inference on trained model
import numpy as np

#Load model and predict with stats
def predict(model_path, input_image_path):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Using device: {device}")

  #Load model
  model = UNet(in_channels=3, num_classes=1)
  model.load_state_dict(torch.load(model_path, map_location=device))
  model.to(device)
  model.eval()

  #Track start time
  total_start_time = time.time()

  #Image preprocessing
  preprocess_start_time = time.time()
  image = Image.open(input_image_path).convert("RGB")
  transform = transforms.Compose([
      transforms.Resize((512,512)),
      transforms.ToTensor()
  ])
  image_tensor = transform(image).unsqueeze(0).to(device)
  preprocess_end_time = time.time()

  #Model Inference
  inference_start_time = time.time()
  with torch.no_grad():
    output = model(image_tensor)
    output = torch.sigmoid(output)
  inference_end_time = time.time()

  #Postprocessing
  postprocess_start_time = time.time()
  mask = output.squeeze(0).squeeze(0).cpu().numpy()
  mask = (mask > 0.4).astype(np.uint8) * 255
  mask_image = Image.fromarray(mask)

  combined = Image.new("RGB", (512*2, 512))
  combined.paste(image.resize((512,512)), (0,0))
  combined.paste(mask_image.convert("RGB"), (512,0))
  combined.save("output.jpg")
  postprocess_end_time = time.time()

  #Calculate timing stats
  total_end_time = time.time()

  preprocess_time = preprocess_end_time - preprocess_start_time
  inference_time = inference_end_time - inference_start_time
  postprocess_time = postprocess_end_time - postprocess_start_time
  total_time = total_end_time - total_start_time

  #Print Stats
  print("Prediction Completed! Some Stats:")
  print(f"Image Preprocessing Time: {preprocess_time:.4f} seconds")
  print(f"Model Inference Time: {inference_time:.4f} seconds")
  print(f"Postprocessing Time: {postprocess_time:.4f} seconds")
  print(f"Total Prediction Time: {total_time:.4f} seconds")
  print("Prediction saved as output.jpg")

In [54]:
predict("/content/unet_model_final.pth", "/content/Human-Segmentation-Dataset/Training_Images/104.jpg")

Using device: cuda
Prediction Completed! Some Stats:
Image Preprocessing Time: 0.0085 seconds
Model Inference Time: 0.0034 seconds
Postprocessing Time: 0.1182 seconds
Total Prediction Time: 0.1300 seconds
Prediction saved as output.jpg
