# Implementing UNet Architecture From Scratch Using PyTorch

In [None]:
# Mount Google Drive so the dataset folders under /content/drive are readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import torchvision as TV
import matplotlib.pyplot as plt
import os
from PIL import Image

# UNet Model Architecture:

In [None]:
# Define custom UNet model
class UNet(nn.Module):
    """Four-level UNet for per-pixel prediction.

    The encoder applies four double-conv stages (64 -> 128 -> 256 -> 512
    channels) separated by 2x2 max-pooling. The decoder upsamples with 2x2
    transposed convolutions, concatenates the matching encoder feature map
    (skip connection) and applies another double-conv stage. A final 1x1
    convolution maps the 64 decoder channels to ``out_channels`` raw logits;
    no activation is applied.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of output channels (e.g. 1 for a binary mask).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: each stage is a double 3x3 conv block.
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        # MaxPool2d is a torch.nn layer, not an attribute of this module.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Upsampling: each transposed conv doubles H/W and halves the channels.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        # Decoder: input channels = upsampled channels + skip channels.
        self.dec1 = self.conv_block(512, 256)
        self.dec2 = self.conv_block(256, 128)
        self.dec3 = self.conv_block(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers.

        ``padding=1`` keeps H and W unchanged, so encoder and decoder feature
        maps at the same level line up for concatenation.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(x, skip):
        """Zero-pad ``x`` symmetrically so its H and W equal those of ``skip``.

        Max-pooling floors odd sizes, so when an input side is not divisible
        by 8 the upsampled map can be one pixel smaller than its skip tensor.
        """
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        return x

    def forward(self, x):
        """Return logits of shape (B, out_channels, H, W) for a (B, in_channels, H, W) input."""
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))

        x = self._match_size(self.upconv3(x4), x3)
        x = self.dec1(t.cat([x, x3], dim=1))

        x = self._match_size(self.upconv2(x), x2)
        x = self.dec2(t.cat([x, x2], dim=1))

        x = self._match_size(self.upconv1(x), x1)
        x = self.dec3(t.cat([x, x1], dim=1))

        return self.out(x)


# Custom Dataset Class:

In [None]:
class LungSegmentationDataset(t.utils.data.Dataset):
    """Pairs chest X-ray images with their lung segmentation masks.

    Images and masks are matched by position after sorting each directory
    listing by filename. Both directories must therefore contain the same
    number of files, and their sorted orders must correspond.

    Args:
        image_paths: Directory containing the input images.
        mask_paths: Directory containing the mask images.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the grayscale PIL mask.

    Raises:
        ValueError: If the two directories hold different numbers of entries.
    """

    def __init__(self, image_paths, mask_paths, image_transform=None, mask_transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_filenames = sorted(os.listdir(image_paths))
        self.mask_filenames = sorted(os.listdir(mask_paths))
        self.image_transform = image_transform
        self.mask_transform = mask_transform

        # Pairing is positional, so differing counts would silently mis-pair
        # samples (or fail later with an IndexError); fail early instead.
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Found {len(self.image_filenames)} images in {image_paths!r} but "
                f"{len(self.mask_filenames)} masks in {mask_paths!r}; images and "
                "masks are paired by sorted position, so the counts must match."
            )

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; the mask is a float {0, 1} tensor of shape (1, H, W)."""
        image_file = os.path.join(self.image_paths, self.image_filenames[idx])
        mask_file = os.path.join(self.mask_paths, self.mask_filenames[idx])

        # convert() returns a new image, so the opened file can be closed right away.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        with Image.open(mask_file) as msk:
            mask = msk.convert('L')

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        # Without a tensor-producing mask transform, convert the PIL mask here.
        if not isinstance(mask, t.Tensor):
            mask = T.functional.pil_to_tensor(mask)

        return image, self._binarize_mask(mask)

    @staticmethod
    def _binarize_mask(mask):
        """Return ``mask > 0`` as a float tensor with a leading channel axis (1, H, W).

        The channel axis is kept so that a batch of masks is (B, 1, H, W),
        matching the model's logits as BCEWithLogitsLoss requires.
        """
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        return (mask > 0).float()

# Image & Mask Transformations

In [None]:
# Target (height, width) for every image/mask pair. The UNet halves the
# resolution three times, so keep both sides multiples of 8.
IMAGE_SIZE = (256, 256)

image_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    # Standard ImageNet per-channel mean/std.
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

mask_transform = T.Compose([
    # Nearest-neighbour keeps the mask binary. The default bilinear filter
    # creates soft edge values, and a later `> 0` threshold would turn them
    # into foreground, slightly dilating every mask.
    T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor()
])

# Load Data

In [None]:
images_path = "/content/drive/MyDrive/Chest X-Ray Dataset/images"
masks_path = "/content/drive/MyDrive/Chest X-Ray Dataset/mask"


dataset = LungSegmentationDataset(images_path, masks_path, image_transform=image_transform, mask_transform=mask_transform)

# Split 80/20 into training and validation. The seeded generator makes the
# split reproducible: re-running this cell later (e.g. before visualizing
# val_dataset) yields the same split instead of mixing training images
# into the validation set.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = t.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=t.Generator().manual_seed(SPLIT_SEED),
)

dataloader_train = t.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
dataloader_val = t.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

# Train Model

In [None]:
# Use the GPU when present; fall back to CPU instead of crashing on .cuda().
device = t.device("cuda" if t.cuda.is_available() else "cpu")
model = UNet(3, 1).to(device)  # 3 input channels (RGB), 1 output logit map
# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Per-epoch mean losses, kept for plotting.
train_losses = []
val_losses = []


def _match_target_shape(outputs, masks):
    """Add a channel axis to (B, H, W) masks so they match (B, 1, H, W) logits."""
    if masks.dim() == outputs.dim() - 1:
        masks = masks.unsqueeze(1)
    return masks


# Run on whatever device the model lives on (GPU or CPU).
device = next(model.parameters()).device

# Training and Validation Loop
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0

    for images, masks in dataloader_train:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, _match_target_shape(outputs, masks))

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Weight by batch size: the final batch may be smaller than the rest.
        running_loss += loss.item() * images.size(0)

    num_train = len(dataloader_train.dataset)
    train_loss = running_loss / num_train if num_train else float("nan")
    train_losses.append(train_loss)
    print(f"Epoch {epoch}/{num_epochs}, Training Loss: {train_loss}")

    # Validation Loop
    model.eval()
    val_loss = 0.0
    with t.no_grad():
        for images, masks in dataloader_val:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, _match_target_shape(outputs, masks))
            val_loss += loss.item() * images.size(0)

    num_val = len(dataloader_val.dataset)
    # An empty validation split would otherwise divide by zero.
    val_loss = val_loss / num_val if num_val else float("nan")
    val_losses.append(val_loss)
    print(f"Epoch {epoch}/{num_epochs}, Validation Loss: {val_loss}")

# Plot Training and Validation Losses.
# Each x-axis is derived from the recorded values, so a partial run (or a
# changed num_epochs) cannot cause an x/y length mismatch.
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.show()

# Test Model

In [None]:
def test_and_visualize(model, dataset, num_samples=5):
    model.eval()
    indices = t.randint(0, len(dataset), (num_samples,))
    plt.figure(figsize=(15, 10))

    for i, idx in enumerate(indices):
        image, true_mask = dataset[idx]
        image = image.unsqueeze(0).cuda()

        with t.no_grad():
            pred_mask = model(image)
            pred_mask = t.sigmoid(pred_mask).squeeze().cpu().numpy()

        pred_mask = (pred_mask > 0.5).astype(float)

        plt.subplot(num_samples, 4, i * 4 + 1)
        plt.imshow(image.squeeze().permute(1, 2, 0).cpu().numpy())
        plt.title("Original Image")
        plt.axis("off")

        plt.subplot(num_samples, 4, i * 4 + 2)
        plt.imshow(true_mask.squeeze(0).cpu(), cmap="gray")
        plt.title("True Mask")
        plt.axis("off")

        plt.subplot(num_samples, 4, i * 4 + 3)
        plt.imshow(pred_mask, cmap="gray")
        plt.title("Predicted Mask")
        plt.axis("off")

        plt.subplot(num_samples, 4, i * 4 + 4)
        plt.imshow(image.squeeze(0).permute(1, 2, 0).cpu())  # Original image
        plt.imshow(pred_mask, cmap="jet", alpha=0.5)  # Overlay mask
        plt.title("Overlay")
        plt.axis("off")
    
    plt.tight_layout()
    plt.show()

In [None]:
# Show predictions for a handful of validation samples.
test_and_visualize(model=model, dataset=val_dataset)