In [22]:
import torch
import numpy as np
import pandas as pd
import os
import time

from PIL import Image
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam, AdamW, SGD

In [23]:
# !git clone https://github.com/VikramShenoy97/Human-Segmentation-Dataset.git

In [24]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transforms.Compose([
            transforms.Resize((512,512)),
            transforms.ToTesnsor()
        ])

        valid_extension = {".jpg",".jpeg",".png"}
        self.images = [f for f in os.listdir(image_dir) if os.path.splittext(f)[1].lower() in valid_extension]

    def __len(self):
        return len(self.images)

    def __getitem(self, idx):
        image_path= os.path.join(self.image_dir, self.images[idx])
        name, ext = os.path.splitext(self.images[idx])
        mask_path= os.path.join(self.mask_dir, f"{name}.png")

        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        mask = (mask >0.5).float()

        return image, mask


In [25]:
## Data loader 

def get_dataloader(image_dir, maks_dir, batch_size =2, shuffle =True):
    dataset = SegmentationDataset(image_dir, maks_dir)
    return DataLoader(dataset, batch_size = batch_size, shuffle =shuffle)


In [26]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding =1),
            nn.ReLU(inplace =True),
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding =1),
            nn.ReLU(inplace =True)
        )
    def forward(self, x):
        return self.conv_op(x)
        

In [27]:
## this downsample need to take care of as two conv and one max pooling

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool =nn.MaxPool2d(kernel_size =2, stride =2)
    def forward(self, x):
        down = self.conv(x) ## convolve, this need be saved for upsampling
        p = self.pool(down) ## the pooling stuff
        return down, p

In [28]:
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size =2, stride =2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1,x2],1)
        return self.conv(x)
        

In [29]:
### Unet

class Unet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_conv1= DownSample(in_channels, 64) ## output is 64
        self.down_conv2= DownSample(64, 128)
        self.down_conv3= DownSample(128,256)
        self.down_conv4= DownSample(256,512) 

        self.bottle_neck = DoubleConv(512, 1024)
        
        self.up_conv1= UpSample(1024, 512) ## output is 64
        self.up_conv2= UpSample(512, 256)
        self.up_conv3= UpSample(256,128)
        self.up_conv4= UpSample(128,64)

        self.out =nn.Conv2d(in_channels = 64, out_channels= num_classes, kernel_size =1)

    def forward(self,x):
        down_1, p1 = self.down_conv1(x)
        down_2, p2 = self.down_conv2(p1)
        down_3, p3 = self.down_conv3(p2)
        down_4, p4 = self.down_conv4(p3)

        b = self.bottle_neck(p4)

        up1= self.up_conv1(b,down_4)
        up2= self.up_conv2(up1,down_3)
        up3= self.up_conv3(up2,down_2)
        up4= self.up_conv4(up3,down_1)

        out =self.out(up4)
        return out
        

        

In [30]:
## 

class DiceLoss(nn.Module):
    def __init__(self, smooth = 1e-16):
        super(DiceLoss, self).__init__()
        self.smooth =smooth
    def forward(self,inputs, targets):
        inputs = inputs.view(-1)
        targets =targets.view(-1)

        intersection =(inputs * targets).sum()
        dice_score = (2 * intersection + self.smooth)/(inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice_score



class BCEWithDiceLoss(nn.Module): 
    def __init__(self, smooth =1e-6):
        super(BCEWithDiceLoss, self).__init__()
        self.bce =nn.BCEWithDiceLoss()
        self.dice = DiceLoss()

    def forward(self, inputs, targets):
        bce_loss = self.bce(inputs, targets)
        dice_loss =self.dice(inputs, targets)
        return 0.5 * bce_loss + dice_loss
    

In [31]:
## Training loop  Loss

def train(model, dataloader, epochs =2, lr =0.001, save_path  ="unet_model", load_path = None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if load_path and os.path.exists(load_path):
        print(f"Loading model weights  form{load_path}")
        model.load_state_dict(torch.load(load_path, map_location =device))
    else : 
        print(f"no checkpoint found, training from scratch..")

    print(device)
    model.to(device)

    criterion = BCEWithDiceLoss()

    optimizer = SGD(model.parameters(), lr =lr)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for images, masks, in dataloader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()

            out = model(images)
            
            loss =criterion(output, masks)
            loss.backward()
            optimizer.steps()

            epoch_loss +=loss.item()
        
        avg_loss =epoch_loss/ len(dataloader)

        print(f"epoch [{epoch+1}/{epochs}], Loss : {avg_loss: .4f}, LR: {lr}")

        if epoch %10==0 and epoch>0:
            torch.save(model.state_dict(), f"{save_path}.pth")

    torch.save(model.state_dict(), f"{save_path}_final.pth")
    print(f"Model_ saved to {save_path}")

        

In [32]:
dataloader = get_dataloader("", "", batch_size =8, shuffle=True)

model = Unet(in_channels=3, num_classes=1)

AttributeError: module 'torchvision.transforms' has no attribute 'ToTesnsor'

In [None]:
train(model, dataloader, epochs =2, lr =0.001)

### Inference on trained model

In [33]:
import numpy as np

# Load model and predict with stats
def predict(model_path, input_image_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    model = UNet(in_channels=3, num_classes=1)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Track start time
    total_start_time = time.time()

    # Image preprocessing
    preprocess_start_time = time.time()
    image = Image.open(input_image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)
    preprocess_end_time = time.time()

    # Model inference
    inference_start_time = time.time()
    with torch.no_grad():
        output = model(image_tensor)
        output = torch.sigmoid(output)
    inference_end_time = time.time()

    # Postprocessing
    postprocess_start_time = time.time()
    mask = output.squeeze(0).squeeze(0).cpu().numpy()
    mask = (mask > 0.4).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)

    combined = Image.new("RGB", (512 * 2, 512))
    combined.paste(image.resize((512, 512)), (0, 0))
    combined.paste(mask_image.convert("RGB"), (512, 0))
    combined.save("output.jpg")
    postprocess_end_time = time.time()

    # Calculate timing stats
    total_end_time = time.time()

    preprocess_time = preprocess_end_time - preprocess_start_time
    inference_time = inference_end_time - inference_start_time
    postprocess_time = postprocess_end_time - postprocess_start_time
    total_time = total_end_time - total_start_time

    # Print stats
    print("\nPrediction completed! Stats:")
    print(f"  Image Preprocessing Time: {preprocess_time:.4f} seconds")
    print(f"  Model Inference Time: {inference_time:.4f} seconds")
    print(f"  Postprocessing Time: {postprocess_time:.4f} seconds")
    print(f"  Total Prediction Time: {total_time:.4f} seconds")
    print("Prediction saved as output.jpg")


In [None]:
predict(model_path="/content/unet_model_80.pth", input_image_path="/content/Human-Segmentation-Dataset/Training_Images/5.jpg")