In [None]:
import torch 
import torch.nn as nn

import torchvision.transforms as transforms
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import InterpolationMode

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import time 
import os 

import pandas as pd 

In [None]:
# class PetDataset(torch.utils.data.Dataset):
#     def __init__(self, root, split="train"):
#         self.dataset = datasets.OxfordIIITPet(
#             root=root,
#             target_types='segmentation', 
#             download=True
#         )
#         self.split = split
#         # self.rotation = transforms.RandomChoice([
#         #     transforms.RandomRotation((0, 0), interpolation=InterpolationMode.BILINEAR),
#         #     transforms.RandomRotation((90, 90), interpolation=InterpolationMode.BILINEAR),
#         #     transforms.RandomRotation((180, 180), interpolation=InterpolationMode.BILINEAR),
#         #     transforms.RandomRotation((270, 270), interpolation=InterpolationMode.BILINEAR)
#         # ])

#         # Image transforms
#         self.img_transform = transforms.Compose([
#             transforms.Resize((128, 128)),
#             # transforms.RandomHorizontalFlip(p=0.5),
#             # transforms.RandomVerticalFlip(p=0.5),
#             # self.rotation,
#             transforms.ToTensor()
#         ])
#     def __getitem__(self, idx):
#         img, mask = self.dataset[idx]
#         img = self.img_transform(img)
#         # mask = self.img_transform(mask)
#         mask = mask.resize((128, 128), Image.NEAREST)
#         mask = torch.from_numpy(np.array(mask)).long()

#         mask[mask == 2] = 0
#         mask[mask == 3] = 0
#         mask[mask == 1] = 1  # Pet


#         return img, mask

#     def __len__(self):
#         return len(self.dataset)

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
import random
import numpy as np
from PIL import Image


class PetDataset(Dataset):
    """Oxford-IIIT Pet images paired with binary pet-vs-rest segmentation masks.

    Wraps ``torchvision.datasets.OxfordIIITPet`` with segmentation targets,
    resizes image and trimap to ``size``, applies paired random flips and
    90-degree rotations for the ``"train"`` split only, and converts the trimap
    into a binary mask.

    Args:
        root: Directory the dataset is downloaded to / read from.
        split: ``"train"`` or ``"val"`` -- disjoint, reproducible partitions of
            torchvision's ``"trainval"`` split (only ``"train"`` is augmented) --
            or ``"test"`` for torchvision's held-out ``"test"`` split.
        size: Output size passed to ``transforms.Resize`` for image and mask
            (``(height, width)``); defaults to the previous fixed 128x128.
        val_fraction: Fraction of ``"trainval"`` reserved for ``"val"``.
        seed: Seed of the train/val permutation; instances built with the same
            seed and fraction share the same partition.

    Raises:
        ValueError: for an unknown ``split`` or ``val_fraction`` outside (0, 1).

    Each item is ``(img, mask)``: ``img`` is the ``F.to_tensor`` output and
    ``mask`` is a long tensor with 1 for pet pixels and 0 elsewhere.
    """

    def __init__(self, root, split="train", size=(128, 128), val_fraction=0.1, seed=0):
        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        if not 0.0 < val_fraction < 1.0:
            raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction!r}")

        # Previously no split was passed, so "train" and "val" both loaded the
        # full "trainval" set and validation ran on training images.
        self.dataset = datasets.OxfordIIITPet(
            root=root,
            split="test" if split == "test" else "trainval",
            target_types='segmentation',
            download=True,
        )
        self.split = split

        indices = list(range(len(self.dataset)))
        if split != "test":
            # A private, seeded RNG gives every instance the same permutation,
            # so the "train" and "val" index sets are disjoint and reproducible.
            random.Random(seed).shuffle(indices)
            n_val = round(len(indices) * val_fraction)
            indices = indices[n_val:] if split == "train" else indices[:n_val]
        self.indices = indices

        self.resize = transforms.Resize(size, interpolation=InterpolationMode.BILINEAR)
        # NEAREST keeps mask labels as exact integers (no blended label values).
        self.resize_mask = transforms.Resize(size, interpolation=InterpolationMode.NEAREST)

    def __getitem__(self, idx):
        img, mask = self.dataset[self.indices[idx]]

        img = self.resize(img)
        mask = self.resize_mask(mask)

        if self.split == "train":
            # Each random geometric op is applied identically to image and mask.
            if random.random() < 0.5:
                img = F.hflip(img)
                mask = F.hflip(mask)
            if random.random() < 0.5:
                img = F.vflip(img)
                mask = F.vflip(mask)
            angle = random.choice([0, 90, 180, 270])
            img = F.rotate(img, angle, interpolation=InterpolationMode.BILINEAR)
            mask = F.rotate(mask, angle, interpolation=InterpolationMode.NEAREST)

        img = F.to_tensor(img)

        # Trimap values: 1 = pet, 2 = background, 3 = border. Collapse to a
        # binary target: pet -> 1, everything else -> 0.
        mask = (torch.from_numpy(np.array(mask)) == 1).long()

        return img, mask

    def __len__(self):
        return len(self.indices)


In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged. The first convolution maps
    ``in_ch`` to ``out_ch`` channels and the second keeps ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Submodule indices (net.0 ... net.5) match the original layout, so
        # existing state_dict checkpoints still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [None]:
# class UNet(nn.Module):
#     def __init__(self, n_classes=1):
#         super().__init__()
#         self.d1 = DoubleConv(3, 64)
#         self.d2 = DoubleConv(64, 128)
#         self.d3 = DoubleConv(128, 256)
#         self.u1 = DoubleConv(256+128, 128)
#         self.u2 = DoubleConv(128+64, 64)
#         self.out_conv = nn.Conv2d(64, 1, 1)
#         self.pool = nn.MaxPool2d(2)
#         self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
#         # Add weight initialization
#         self._initialize_weights()

#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)
#             elif isinstance(m, nn.BatchNorm2d):
#                 nn.init.constant_(m.weight, 1)
#                 nn.init.constant_(m.bias, 0)

#     def forward(self, x):
#         x1 = self.d1(x)
#         x2 = self.d2(self.pool(x1))
#         x3 = self.d3(self.pool(x2))
#         x = self.up(x3)
#         x = self.u1(torch.cat([x, x2], dim=1))
#         x = self.up(x)
#         x = self.u2(torch.cat([x, x1], dim=1))
#         out = self.out_conv(x)
#         return out

In [None]:
class UNet(nn.Module):
    """Small two-level U-Net producing raw per-pixel logits.

    Encoder: three DoubleConv stages (3 -> 64 -> 128 -> 256 channels) with 2x2
    max pooling between them. Decoder: two bilinear upsampling steps, each
    concatenated with the matching encoder feature map and refined by a
    DoubleConv. A final 1x1 convolution maps 64 channels to ``n_classes``
    logits; no activation is applied.

    Args:
        n_classes: Number of output channels. Defaults to 1 (a single binary
            logit map). It was previously ignored and always 1.
    """

    def __init__(self, n_classes=1):
        super().__init__()
        self.d1 = DoubleConv(3, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        self.u1 = DoubleConv(256+128, 128)
        self.u2 = DoubleConv(128+64, 64)
        self.out_conv = nn.Conv2d(64, n_classes, 1)
        self.pool = nn.MaxPool2d(2)
        # Kept as an attribute for compatibility (it holds no parameters).
        # forward() upsamples to the skip connection's exact size instead, so
        # input sides that are not multiples of 4 no longer break torch.cat.
        self.up = nn.Upsample(scale_factor = 2, mode = "bilinear", align_corners=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights, zero biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``.

        For an exact 2x size ratio this equals ``self.up`` (align_corners=True
        derives coordinates from the sizes alone).
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Map a 3-channel batch [B, 3, H, W] to logits [B, n_classes, H, W]."""
        x1 = self.d1(x)                      # H x W, 64 ch
        x2 = self.d2(self.pool(x1))          # floor(H/2) x floor(W/2), 128 ch
        x3 = self.d3(self.pool(x2))          # floor(H/4) x floor(W/4), 256 ch
        x = self.u1(torch.cat([self._upsample_to(x3, x2), x2], dim=1))
        x = self.u2(torch.cat([self._upsample_to(x, x1), x1], dim=1))
        out = self.out_conv(x)
        return out

In [None]:
# "train" applies random flips/rotations inside PetDataset; "val" does not.
# Only the training loader reshuffles every epoch.
train_ds = PetDataset(root="./data", split="train")
val_ds = PetDataset(root="./data", split="val")

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)


# Exploratory data analysis (EDA) on the images

In [None]:
# First 10 image files, in sorted order (os.listdir order is arbitrary).
# Path is relative to the same "./data" root that PetDataset downloads into,
# instead of a Kaggle-only absolute path.
sorted(os.listdir(os.path.join("./data", "oxford-iiit-pet", "images")))[:10]

In [None]:
# Load one sample photo from the "./data" root used by PetDataset
# (previously a Kaggle-only absolute path).
img = Image.open(os.path.join("./data", "oxford-iiit-pet", "images", "wheaten_terrier_14.jpg"))

In [None]:
# PIL reports image size as (width, height).
print("This is the size of the original image: {}".format(img.size))

In [None]:
# Show the sample photo. pyplot is already imported as plt in the first cell,
# so it is not re-imported here.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("Original image")
ax.axis('off')
plt.show()

## Resize image to see the shape of the input to the network

In [None]:
# Same 128x128 resize that PetDataset applies to its inputs.
transformations = transforms.Compose(
    [transforms.Resize(size=(128, 128))]
)

In [None]:
# Apply the resize pipeline to the sample photo. The result is resized (not
# reshaped); given a PIL input, torchvision's Resize returns a PIL image, so
# `.size` in the next cell is still (width, height).
img_reshaped = transformations(img)

In [None]:
# Expected to be (128, 128) after the Resize above.
print("This is the size of the resized image: {}".format(img_reshaped.size))

In [None]:
# Show the 128x128 version the network actually sees. pyplot is already
# imported as plt in the first cell, so it is not re-imported here.
fig, ax = plt.subplots()
ax.imshow(img_reshaped)
ax.set_title("Resized image (128x128)")
ax.axis('off')
plt.show()

# EDA on the labels

In [None]:
# Read the full annotation index from the "./data" root used by PetDataset
# (previously a Kaggle-only absolute path). Explicit encoding avoids
# platform-dependent defaults.
with open(os.path.join("./data", "oxford-iiit-pet", "annotations", "list.txt"), "r", encoding="utf-8") as f:
    content = f.read()

In [None]:
# Preview the first 100 characters of the file just read.
content[0:100]

In [None]:
# Read the test-split index from the "./data" root used by PetDataset
# (previously a Kaggle-only absolute path).
with open(os.path.join("./data", "oxford-iiit-pet", "annotations", "test.txt"), "r", encoding="utf-8") as f:
    content = f.read()

In [None]:
# Preview the first 100 characters of the file just read.
content[0:100]

In [None]:
# Read the trainval-split index from the "./data" root used by PetDataset
# (previously a Kaggle-only absolute path).
with open(os.path.join("./data", "oxford-iiit-pet", "annotations", "trainval.txt"), "r", encoding="utf-8") as f:
    content = f.read()

In [None]:
# Preview the first 100 characters of the file just read.
content[0:100]

In [None]:
# First 10 trimap files, in sorted order (os.listdir order is arbitrary), from
# the "./data" root used by PetDataset.
sorted(os.listdir(os.path.join("./data", "oxford-iiit-pet", "annotations", "trimaps")))[:10]

In [None]:
# Load one trimap (segmentation label image) from the "./data" root used by
# PetDataset. Note: this rebinds `img`, which held the sample photo.
img = Image.open(os.path.join("./data", "oxford-iiit-pet", "annotations", "trimaps", "British_Shorthair_165.png"))

In [None]:
# Trimap dimensions as (width, height).
(img.width, img.height)

In [None]:
# Render the raw trimap; each label value gets its own colour.
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

## Inspect the trimap values

In [None]:
# Copy the trimap into a NumPy array of per-pixel label values.
img_array = np.array(img, copy=True)

In [None]:
# Distinct label values present in the trimap (sorted).
np.unique(img_array.ravel())

In [None]:
# Number of pixels carrying each trimap value. np.count_nonzero on the boolean
# mask counts pixels; summing the selected values (np.sum(arr[arr == k]))
# would instead give k * count. Names are kept for compatibility.
sum_1 = np.count_nonzero(img_array == 1)
sum_2 = np.count_nonzero(img_array == 2)
sum_3 = np.count_nonzero(img_array == 3)
sum_1, sum_2, sum_3

In [None]:
# One sample per pixel for the histogram.
mask_flat = img_array.flatten()

# Bin edges at -0.5, 0.5, ..., 3.5 centre one bar on each integer label.
fig, ax = plt.subplots()
ax.hist(mask_flat, bins=np.arange(5) - 0.5, rwidth=0.8)
ax.set_xticks([1, 2, 3])  # possible mask values
ax.set_xlabel("Mask Value")
ax.set_ylabel("Number of pixels")
ax.set_title("Histogram of mask pixel values")
plt.show()

### Visualize the pet, background, and border pixels separately

In [None]:
# Array shape of the trimap (rows, columns).
np.shape(img_array)

In [None]:
# White where the trimap value is 1, black elsewhere.
mask_1 = img_array == 1

fig, ax = plt.subplots()
ax.imshow(mask_1, cmap='gray')
ax.axis('off')
ax.set_title("Pixels with value 1")
plt.show()

In [None]:
# White where the trimap value is 2, black elsewhere.
mask_2 = img_array == 2

fig, ax = plt.subplots()
ax.imshow(mask_2, cmap='gray')
ax.axis('off')
ax.set_title("Pixels with value 2")
plt.show()

In [None]:
# White where the trimap value is 3, black elsewhere.
mask_3 = img_array == 3

fig, ax = plt.subplots()
ax.imshow(mask_3, cmap='gray')
ax.axis('off')
ax.set_title("Pixels with value 3")
plt.show()

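### So 1 belongs to the pet itself!
### 2 belongs to the background
### 3 belongs to the border around the pet
PetDataset keeps value 1 as the pet (1) and maps 2 and 3 to 0, so the model learns a binary pet-vs-rest mask.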

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation computed from raw logits.

    ``pred`` is passed through a sigmoid and compared with ``target`` (same
    shape, values in {0, 1}). Sums run over the whole batch, so one Dice score
    is computed per call rather than one per sample.

    Args:
        smooth: Additive term in numerator and denominator that keeps the
            ratio finite when both sums are near zero. Defaults to 1.0, the
            previously hard-coded value.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` as a scalar tensor."""
        probs = torch.sigmoid(pred)
        intersection = (probs * target).sum()
        dice = (2.0 * intersection + self.smooth) / (probs.sum() + target.sum() + self.smooth)
        return 1.0 - dice

# Use this instead of BCE
criterion = DiceLoss()

In [None]:
# Prefer a GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))

In [None]:
# Training configuration: epoch budget, a fresh U-Net on `device`, and Adam.
epochs = 30
model = UNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss.

    Args:
        model: Network whose output ``criterion`` compares with the masks.
        loader: Any iterable of ``(imgs, masks)`` batches; masks are integer
            label maps (presumably [B, H, W] from PetDataset -- confirm).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(outputs, float_masks)``.
        device: Device the batch is moved to.

    Returns:
        Mean of the per-batch loss values, or ``float("nan")`` if ``loader``
        yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        masks = masks.float().unsqueeze(1)  # Shape [B,1,H,W]

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Counting batches instead of len(loader) supports loaders without
    # __len__, and an empty loader gives NaN rather than ZeroDivisionError.
    return total_loss / n_batches if n_batches else float("nan")

In [None]:
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return mean batch loss.

    Args:
        model: Network whose output ``criterion`` compares with the masks.
        loader: Any iterable of ``(imgs, masks)`` batches; masks are integer
            label maps (presumably [B, H, W] from PetDataset -- confirm).
        criterion: Loss taking ``(outputs, float_masks)``.
        device: Device the batch is moved to.

    Returns:
        Mean of the per-batch loss values, or ``float("nan")`` if ``loader``
        yielded no batches. The model is left in eval mode.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            masks = masks.float().unsqueeze(1)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            n_batches += 1
    # Counting batches instead of len(loader) supports loaders without
    # __len__, and an empty loader gives NaN rather than ZeroDivisionError.
    return total_loss / n_batches if n_batches else float("nan")

In [None]:
# Train for `epochs` epochs, checkpointing to model.pt whenever the
# validation loss improves on the best seen so far.
best_val_loss = float('inf')
for epoch in range(epochs):
    # perf_counter is monotonic, so epoch timings are immune to wall-clock jumps.
    t0 = time.perf_counter()
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)

    if val_loss < best_val_loss:
        print("Best Result! Model will be saved")
        best_val_loss = val_loss
        torch.save(model.state_dict(), "model.pt")

    elapsed_time = time.perf_counter() - t0
    print(f"Epoch {epoch+1:02d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Time: {elapsed_time:.2f}s")
    

In [None]:
# Total number of trainable parameters in the model.
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# Rebuild the architecture and load the best checkpoint saved during training.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model = UNet().to(device)
model.load_state_dict(torch.load("model.pt", map_location=device))

In [None]:
# Show one held-out example: input image, ground-truth mask, and thresholded
# prediction. val_loader is unshuffled and PetDataset does not augment
# non-"train" splits, so the same example appears on every run and the panel
# reflects generalisation rather than fit to (randomly augmented) training data.
sample_idx = 10
model.eval()
with torch.no_grad():
    imgs, masks = next(iter(val_loader))
    preds = torch.sigmoid(model(imgs.to(device)))
    preds = (preds > 0.5).float().cpu()

# Fall back to the last item if the batch has fewer than sample_idx + 1 items.
sample_idx = min(sample_idx, imgs.size(0) - 1)
img = imgs[sample_idx].numpy().transpose(1, 2, 0)
mask = masks[sample_idx].squeeze().numpy()
pred = preds[sample_idx].squeeze().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (img, "Input Image", None),
    (mask, "Ground Truth", 'gray'),
    (pred, "Prediction", 'gray'),
)
for ax, (image, title, cmap) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()
plt.show()