In [None]:
# Reload edited local modules automatically and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Imports

In [None]:
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

In [None]:
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import torchvision.transforms.functional as TF

In [None]:
import lightning as pl

# Section A: UNet Segmentation Training with Pretrained ResNet34 Backbone

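In this section we train a UNet model for binary semantic segmentation whose encoder is a pretrained ResNet34 backbone. We use the Oxford-IIIT Pet dataset (downloaded via TorchVision) and train the model with PyTorch Lightning. The dataset's segmentation targets are trimaps (1 = pet, 2 = background, 3 = border), so they must be reduced to a single 0/1 foreground mask before they can serve as targets for a binary loss.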

In [None]:
# ImageNet statistics expected by the pretrained ResNet34 encoder.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (128, 128)

# Transformation pipeline for the input images.
transform = Compose([
    Resize(IMAGE_SIZE),            # bilinear resize is fine for RGB images
    ToTensor(),                    # PIL image -> float tensor in [0, 1]
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


def trimap_to_binary(mask):
    """Convert an Oxford-IIIT Pet trimap to a binary foreground mask.

    Trimap labels are 1 = pet, 2 = background, 3 = border. Pet and border pixels
    become 1.0, background becomes 0.0.

    Args:
        mask: PIL label image (already resized with nearest-neighbour interpolation).

    Returns:
        float32 tensor of shape [1, H, W] with values in {0, 1}.

    Note: ToTensor() must NOT be used on label images. It divides uint8 data by 255,
    turning labels {1, 2, 3} into ~{0.004, 0.008, 0.012}.
    """
    labels = TF.pil_to_tensor(mask)   # uint8 [1, H, W], raw label values
    return (labels != 2).float()


target_trf = Compose([
    # Nearest-neighbour keeps label values intact (bilinear would blend 1/2/3 into new values).
    Resize(IMAGE_SIZE, interpolation=TF.InterpolationMode.NEAREST),
    trimap_to_binary,
])

# Download and load the Oxford-IIIT Pet dataset (for segmentation)
pet_dataset = OxfordIIITPet(
    root="oxford-iiit-pet",
    download=True,
    target_types="segmentation",
    transform=transform,
    target_transform=target_trf,
)
print("Oxford-IIIT Pet dataset loaded:", len(pet_dataset), "samples")

In [None]:
# Fetch the first (image, mask) pair to check what the transforms produce.
(img, trg) = pet_dataset[0]

In [None]:
# Shapes plus the distinct mask values; a binary segmentation target should contain only 0 and 1.
img.shape, trg.shape, trg.unique()

In [None]:
class PretrainedUNet(nn.Module):
    """UNet for semantic segmentation with an ImageNet-pretrained ResNet34 encoder.

    Args:
        n_class: number of output channels (logits); 1 for binary segmentation.

    ``forward`` maps images ``[B, 3, H, W]`` to logits ``[B, n_class, H, W]``. Sizes that
    are not multiples of 32 are supported. Each upsampled decoder map is resized to the
    spatial size of its skip connection before concatenation.
    """

    def __init__(self, n_class=1):
        super().__init__()
        # Pass `weights` by keyword; positional use is deprecated since torchvision 0.13.
        resnet = models.resnet34(weights=models.resnet.ResNet34_Weights.DEFAULT)

        # Encoder: reuse the ResNet34 stages (shapes given for an H x W input).
        self.encoder0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # [B, 64, H/2, W/2]
        self.pool0 = resnet.maxpool    # halves the spatial size
        self.encoder1 = resnet.layer1  # [B, 64, H/4, W/4]
        self.encoder2 = resnet.layer2  # [B, 128, H/8, W/8]
        self.encoder3 = resnet.layer3  # [B, 256, H/16, W/16]
        self.encoder4 = resnet.layer4  # [B, 512, H/32, W/32]

        # Decoder: 2x transposed-conv upsampling, concat with the matching encoder output,
        # then two 3x3 convs. in_channels of each double_conv = upsampled + skip channels.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = self.double_conv(256 + 256, 256)  # skip: encoder3

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self.double_conv(128 + 128, 128)  # skip: encoder2

        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self.double_conv(64 + 64, 64)     # skip: encoder1

        self.upconv1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.decoder1 = self.double_conv(64 + 64, 64)     # skip: encoder0

        self.out_conv = nn.Conv2d(64, n_class, kernel_size=1)  # per-pixel logits

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers that preserve spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _decode_stage(x, skip, upconv, decoder):
        """Upsample ``x``, align it to ``skip``'s H x W, concatenate on channels and refine."""
        up = upconv(x)
        if up.shape[-2:] != skip.shape[-2:]:
            # Odd encoder sizes (input not a multiple of 32) make 2x upsampling overshoot by one pixel.
            up = F.interpolate(up, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return decoder(torch.cat([up, skip], dim=1))

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        # Encoder
        x0 = self.encoder0(x)
        x1 = self.encoder1(self.pool0(x0))
        x2 = self.encoder2(x1)
        x3 = self.encoder3(x2)
        x4 = self.encoder4(x3)  # bottleneck

        # Decoder with skip connections
        d4 = self._decode_stage(x4, x3, self.upconv4, self.decoder4)
        d3 = self._decode_stage(d4, x2, self.upconv3, self.decoder3)
        d2 = self._decode_stage(d3, x1, self.upconv2, self.decoder2)
        d1 = self._decode_stage(d2, x0, self.upconv1, self.decoder1)

        out = self.out_conv(d1)  # resolution of encoder0, i.e. about H/2 x W/2
        # Restore the input resolution.
        return F.interpolate(out, size=x.shape[-2:], mode="bilinear", align_corners=False)

# Instantiate UNet with 1 output channel (binary segmentation)
unet_model = PretrainedUNet(n_class=1)

# Sanity-check the output shape on a dummy batch instead of printing the whole module tree.
# eval() keeps BatchNorm running statistics untouched; the model is returned to train mode after.
unet_model.eval()
with torch.no_grad():
    _dummy_out = unet_model(torch.zeros(2, 3, 128, 128))
unet_model.train()
assert _dummy_out.shape == (2, 1, 128, 128), _dummy_out.shape
n_params = sum(p.numel() for p in unet_model.parameters())
print(f"PretrainedUNet: {n_params:,} parameters, output shape {tuple(_dummy_out.shape)}")

In [None]:
class UNetLightning(pl.LightningModule):
    """Lightning wrapper that trains a binary-segmentation network with BCE-with-logits loss.

    Args:
        lr: learning rate for the Adam optimizer.
        model: network mapping images ``[B, 3, H, W]`` to logits ``[B, 1, H, W]``.
            When ``None`` (default), a fresh ``PretrainedUNet(n_class=1)`` is built. The
            module then does not silently reuse (possibly already-trained) global state
            when the cell is re-run.
    """

    def __init__(self, lr=1e-3, model=None):
        super().__init__()
        self.model = model if model is not None else PretrainedUNet(n_class=1)
        self.lr = lr
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute the loss for one (images, masks) batch and log it as ``{stage}_loss``."""
        imgs, masks = batch
        # Lightning already moves the batch to self.device; only the target dtype needs fixing.
        logits = self(imgs)
        loss = self.criterion(logits, masks.float())
        self.log(f"{stage}_loss", loss, prog_bar=True, batch_size=imgs.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=self.lr)

# Instantiate the Lightning module for UNet training
unet_lightning = UNetLightning(lr=1e-3)

In [None]:
# Split the Oxford-IIIT Pet dataset into train and validation subsets.
# A seeded generator makes the split identical on every Restart & Run All.
SPLIT_SEED = 42
train_size = int(0.8 * len(pet_dataset))
val_size = len(pet_dataset) - train_size
train_dataset, val_dataset = random_split(
    pet_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

In [None]:
# Single-device Lightning Trainer for the UNet, 5 epochs; accelerator="auto" lets Lightning
# choose the available hardware.
trainer = pl.Trainer(devices=1, accelerator='auto', max_epochs=5)

In [None]:
# Inspect one training sample without dumping whole tensors. Show shapes plus the distinct
# mask values, which should be {0, 1} for a binary segmentation target.
sample_img, sample_mask = train_loader.dataset[0]
sample_img.shape, sample_mask.shape, sample_mask.unique()

In [None]:
# Train the UNet on the training split; Lightning runs validation on val_loader
# (once per epoch by default).
trainer.fit(model=unet_lightning, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Mean/std used by Normalize() in the image transform; needed to undo it for display.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

unet_lightning.model.eval()
# Get one batch from the validation loader
imgs, true_masks = next(iter(val_loader))
imgs = imgs.to(unet_lightning.device)

with torch.no_grad():
    preds = unet_lightning.model(imgs)

# Visualize the first image's prediction.
# to_pil_image expects values in [0, 1], so undo Normalize() first. Showing the normalized
# tensor directly produces wrong colours.
img_tensor = (imgs[0].cpu() * norm_std + norm_mean).clamp(0.0, 1.0)
true_mask = true_masks[0].cpu().squeeze()  # single-channel mask -> [H, W]
pred_mask = torch.sigmoid(preds[0]).cpu().squeeze()
pred_mask_bin = (pred_mask > 0.5).float()

# Fixed 0..1 range so an all-background mask shows as black instead of being auto-rescaled.
mask_style = {"cmap": "gray", "vmin": 0, "vmax": 1}
panels = [
    (TF.to_pil_image(img_tensor), "Input Image", {}),
    (true_mask, "True Mask", mask_style),
    (pred_mask_bin, "Predicted Mask", mask_style),
]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (image, title, style) in zip(axes, panels):
    ax.imshow(image, **style)
    ax.set_title(title)
    ax.axis("off")
plt.show()

# Section B: Mask-RCNN Instance Segmentation Training with PyTorch Lightning

In this section we train a Mask R-CNN model on the Penn-Fudan pedestrian dataset (PennFudanPed). We download and extract the dataset, define a custom `Dataset` class and DataLoaders, then wrap TorchVision's Mask R-CNN in a PyTorch Lightning module for training and evaluation.

In [None]:
import requests, zipfile, io
import os

url = "https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip"
dataset_dir = "PennFudanPed"

if os.path.isdir(os.path.join(dataset_dir, "PNGImages")):
    # Skip the download on re-runs so Restart & Run All stays cheap.
    print(f"PennFudanPed dataset already present in './{dataset_dir}'")
else:
    response = requests.get(url, timeout=120)
    # Later cells cannot run without the data, so fail loudly instead of printing and continuing.
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        # The archive already contains a top-level "PennFudanPed/" folder. Extracting into "."
        # yields ./PennFudanPed/PNGImages; extracting into "PennFudanPed" would nest it twice.
        z.extractall(".")
    print("PennFudanPed dataset downloaded and extracted to './PennFudanPed'")

In [None]:
def collate_fn(batch):
    """Transpose a list of (image, target) samples into an (images, targets) pair of tuples.

    Detection samples differ in image size and number of objects, so they are kept as
    tuples rather than stacked into one tensor by the default collate.
    """
    columns = zip(*batch)
    return tuple(columns)

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, random_split

class PennFudanDataset(torch.utils.data.Dataset):
    """Penn-Fudan pedestrian dataset with targets in the format used by torchvision's Mask R-CNN.

    ``root`` must contain ``PNGImages/`` and ``PedMasks/``. Images and masks are paired by
    their sorted file order. In a mask, 0 is background and every pedestrian has its own
    non-zero id.

    Each item is ``(img, target)`` where ``target`` holds:
      * ``boxes``    - float32 ``[N, 4]`` as ``(xmin, ymin, xmax, ymax)`` pixel coordinates
      * ``labels``   - int64 ``[N]``, all ones (single "person" class)
      * ``masks``    - uint8 ``[N, H, W]`` binary instance masks (required by Mask R-CNN training)
      * ``image_id`` - int64 ``[1]``

    Args:
        root: dataset directory.
        transforms: optional callable applied to the image only. It must not change the
            image geometry, otherwise ``boxes``/``masks`` no longer line up. Without it the
            image is returned as a PIL image.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        mask = np.array(Image.open(mask_path))

        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]            # drop background explicitly
        masks = mask == obj_ids[:, None, None]     # bool [N, H, W], one mask per pedestrian

        boxes = []
        for obj_mask in masks:
            ys, xs = np.where(obj_mask)
            boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])
        # reshape keeps the required [0, 4] shape for images without pedestrians
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        num_objs = len(obj_ids)
        target = {
            "boxes": boxes,
            "labels": torch.ones((num_objs,), dtype=torch.int64),  # one class (person)
            "masks": torch.as_tensor(masks, dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
        }
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        return len(self.imgs)

# ToTensor turns each PIL image into the float [0, 1] CHW tensor Mask R-CNN expects. It does not
# change geometry, so boxes and masks stay aligned.
dataset_full = PennFudanDataset("PennFudanPed", transforms=ToTensor())
n = len(dataset_full)
n_train = int(0.8 * n)
n_val = n - n_train
# Seeded split so the train/val partition is reproducible across runs.
dataset_train, dataset_val = random_split(
    dataset_full, [n_train, n_val], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(dataset_train, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_loader = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

In [None]:
import pytorch_lightning as pl
import torchvision.models.detection as detection

class MaskRCNNLightning(pl.LightningModule):
    """Lightning wrapper around torchvision's Mask R-CNN (ResNet50-FPN) for fine-tuning.

    Both the box predictor and the mask predictor are replaced so that they output
    ``num_classes`` classes (background included).

    Args:
        num_classes: number of classes including background (2 = background + person).
        lr: initial SGD learning rate.

    NOTE(review): the import right above rebinds ``pl`` to ``pytorch_lightning``, while
    Section A uses the ``lightning`` package. The Trainer that fits this module must come
    from the same package.
    """

    def __init__(self, num_classes=2, lr=0.005):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13).
        self.model = detection.maskrcnn_resnet50_fpn(
            weights=detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
        )
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
        # The pretrained (COCO) mask head predicts 91 classes; resize it to num_classes as well.
        in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        self.model.roi_heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes
        )
        self.lr = lr

    def forward(self, images, targets=None):
        return self.model(images, targets)

    def _compute_loss(self, batch):
        """Return ``(summed Mask R-CNN loss, batch size)`` for a collated (images, targets) batch."""
        images, targets = batch
        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss.mean() for loss in loss_dict.values())
        return loss, len(images)

    def training_step(self, batch, batch_idx):
        loss, batch_size = self._compute_loss(batch)
        # Explicit batch_size: Lightning may fail to infer it from tuples of dicts.
        self.log("train_loss", loss, prog_bar=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        # torchvision detection models return a loss dict only in train mode; in eval mode
        # (which Lightning sets for validation) they return a list of detections. Switch
        # temporarily. Lightning disables gradients during validation. torchvision builds
        # pretrained detection backbones with FrozenBatchNorm2d, so no normalization statistics
        # are updated (verify if the backbone is changed).
        was_training = self.model.training
        self.model.train()
        try:
            loss, batch_size = self._compute_loss(batch)
        finally:
            self.model.train(was_training)
        self.log("val_loss", loss, prog_bar=True, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        # Only optimize trainable parameters; torchvision may freeze early pretrained backbone layers.
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=self.lr, momentum=0.9, weight_decay=0.0005)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
        return [optimizer], [scheduler]

maskrcnn_model = MaskRCNNLightning(num_classes=2, lr=0.005)

In [None]:
import pytorch_lightning as pl

# Fine-tune Mask R-CNN for 3 epochs on one device, validating on val_loader.
trainer = pl.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=3,
)
trainer.fit(model=maskrcnn_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

SCORE_THRESHOLD = 0.5  # minimum detection confidence to display
MASK_THRESHOLD = 0.5   # probability cut-off for turning soft masks into binary masks

maskrcnn_model.model.eval()
images, targets = next(iter(val_loader))
images = [img.to(maskrcnn_model.device) for img in images]

with torch.no_grad():
    outputs = maskrcnn_model.model(images)

# Visualize predictions (instance masks + boxes) for the first image in the batch.
img = images[0].cpu()
img_vis = (img * 255).type(torch.uint8)
pred = {k: v.detach().cpu() for k, v in outputs[0].items()}
keep = pred["scores"] >= SCORE_THRESHOLD
boxes = pred["boxes"][keep]
# torchvision returns soft masks as [N, 1, H, W]; binarize and drop the channel axis.
masks = pred["masks"][keep].squeeze(1) > MASK_THRESHOLD

drawn_img = img_vis
if keep.any():  # nothing to draw when no detection passes the threshold
    drawn_img = draw_segmentation_masks(drawn_img, masks, alpha=0.5)
    drawn_img = draw_bounding_boxes(drawn_img, boxes, colors="red", width=2)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(TF.to_pil_image(drawn_img))
ax.axis("off")
ax.set_title(f"Mask-RCNN Validation Predictions ({int(keep.sum())} detections)")
plt.show()