# Import Libraries and set device

In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.segmentation as segmentation
import torch.optim as optim
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from pycocotools.coco import COCO
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF 
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


# Set seed for reproducibility

In [2]:
# Seed every random number generator the notebook uses so that repeated runs
# start from the same state.
seed_val = 15
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
# Setting PYTHONHASHSEED at runtime only influences interpreters started after
# this point (e.g. worker subprocesses); the running kernel's hash seed was
# fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed_val)

if torch.cuda.is_available():
    # Seed all visible GPUs, not just the current one.
    torch.cuda.manual_seed_all(seed_val)

# cuDNN autotuning (benchmark=True) may pick different, non-deterministic
# kernels from run to run, which defeats the purpose of seeding. Trade a bit
# of speed for repeatable results. These flags are harmless on CPU-only runs.
cudnn.benchmark = False
cudnn.deterministic = True

# Function to resize the image and mask to a fixed size (256x256) and convert them to tensors.

In [3]:
def img_transform(image, mask, size=(256, 256)):
    """Resize an image/mask pair to a common size and convert both to tensors.

    Args:
        image: Input picture, handed to ``TF.resize`` and ``TF.to_tensor``.
        mask: Per-pixel class-ID mask, handed to ``TF.resize`` and then
            converted through ``np.array``.
        size: Target ``(height, width)`` passed to ``TF.resize``. Defaults to
            ``(256, 256)``, the size the notebook trains at.

    Returns:
        tuple: ``(image, mask)`` where ``image`` is the output of
        ``TF.to_tensor`` and ``mask`` is a ``torch.long`` tensor of class IDs.

    NOTE(review): no mean/std normalization is applied here. torchvision's
    pretrained segmentation weights are documented as expecting
    ImageNet-normalized inputs; confirm whether ``TF.normalize`` should be
    added for fine-tuning.
    """
    # Resize the image (default interpolation) and convert it to a tensor.
    image = TF.to_tensor(TF.resize(image, size))
    # Nearest-neighbour keeps mask values as valid class IDs; any blending
    # interpolation would invent labels that do not exist. Use torchvision's
    # InterpolationMode enum rather than the PIL integer constant.
    mask = TF.resize(mask, size, interpolation=TF.InterpolationMode.NEAREST)
    # CrossEntropyLoss expects integer (long) class-index targets.
    mask = torch.as_tensor(np.array(mask), dtype=torch.long)
    return image, mask

# Define the COCO dataset class

In [4]:
class COCOSegmentationDataset(Dataset):
    """Semantic-segmentation dataset built from a COCO annotation file.

    Each item is an ``(image, mask)`` pair. The mask holds 0 for background
    and a contiguous label (1..N) for every selected COCO category.

    Attributes are set in ``__init__``:
        root_dir: Directory containing the image files.
        coco: The ``COCO`` index built from ``ann_file``.
        transform: Optional callable ``(image, mask) -> (image, mask)``.
        cat_ids: COCO category IDs in use.
        coco_id_to_contiguous: Maps each COCO category ID to its label (1..N).
        classes: Class names, with ``'background'`` at index 0.
        ids: IDs of images containing at least one selected category.
    """

    def __init__(self, root_dir, ann_file, transform=None, target_classes=None):
        self.root_dir = root_dir
        self.coco = COCO(ann_file)
        self.transform = transform

        # Category selection: the requested names if given, otherwise every
        # category (sorted so the label assignment is stable).
        self.cat_ids = (
            self.coco.getCatIds(catNms=target_classes)
            if target_classes
            else sorted(self.coco.getCatIds())
        )

        # COCO IDs are sparse; remap them to contiguous labels starting at 1
        # so that 0 stays reserved for background.
        self.coco_id_to_contiguous = {
            coco_id: label for label, coco_id in enumerate(self.cat_ids, start=1)
        }

        # Human-readable names, aligned with the contiguous labels.
        self.classes = ['background']
        self.classes.extend(cat['name'] for cat in self.coco.loadCats(self.cat_ids))

        # Keep every image that contains at least one selected category;
        # the set removes images that matched several categories.
        matched = [
            img_id
            for cat_id in self.cat_ids
            for img_id in self.coco.getImgIds(catIds=[cat_id])
        ]
        self.ids = list(set(matched))

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load the image at ``index`` and rasterize its annotations into a mask."""
        img_id = self.ids[index]

        # Image pixels
        meta = self.coco.loadImgs(img_id)[0]
        image = Image.open(os.path.join(self.root_dir, meta['file_name'])).convert('RGB')

        # Annotations restricted to the selected categories
        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.cat_ids)
        annotations = self.coco.loadAnns(ann_ids)

        # Start from an all-background (0) canvas at the original resolution
        label_map = np.zeros((meta['height'], meta['width']), dtype=np.uint8)

        for ann in annotations:
            label = self.coco_id_to_contiguous.get(ann['category_id'])
            if label is None:
                # Category is not part of this dataset's mapping
                continue
            # Later annotations overwrite earlier ones where they overlap
            label_map[self.coco.annToMask(ann) > 0] = label

        mask = Image.fromarray(label_map)

        if self.transform:
            image, mask = self.transform(image, mask)

        return image, mask

# FCN-ResNet50: hyperparameters, datasets, model, and optimizer

In [10]:
# Training configuration
NUM_EPOCHS = 10
BATCH_SIZE = 4        # Kept small because FCN-ResNet50 is memory hungry
LEARNING_RATE = 1e-4  # A low learning rate suits fine-tuning a pre-trained model

In [11]:
# Build the train/val datasets, the FCN-ResNet50 model with heads resized to
# the dataset's class count, the optimizer, the loss, and the data loaders.
print("Initializing Full COCO Dataset...")
data_path = './coco/train2017' 
ann_path = './coco/annotations/instances_train2017.json'

val_data_path = './coco/val2017'
val_ann_file = './coco/annotations/instances_val2017.json'

# Training dataset; target_classes=None selects every category in the file
train_dataset = COCOSegmentationDataset(
    root_dir=data_path, 
    ann_file=ann_path, 
    transform=img_transform,
    target_classes=None 
)

# Validation dataset, built the same way from the val2017 annotations
val_dataset = COCOSegmentationDataset(
    root_dir=val_data_path, 
    ann_file=val_ann_file, 
    transform=img_transform,
    target_classes=None
)
# 'background' plus one entry per category (81 in the recorded run).
# NOTE(review): the validation set builds its own ID mapping; this assumes both
# annotation files list the same categories — confirm.
NUM_CLASSES = len(train_dataset.classes)

# 1. Load Pretrained FCN with ResNet-50 Backbone
# weights='DEFAULT' loads torchvision's default pretrained weights for this
# model, which helps convergence.
# NOTE(review): per the torchvision docs these were trained on a COCO subset
# using the Pascal VOC label set — confirm.
model = segmentation.fcn_resnet50(weights='DEFAULT')

# 2. Modify the Classifier Heads
# Replace the last layer of each head so it outputs NUM_CLASSES channels (e.g., 81).
# The model has a main classifier ('classifier') and an auxiliary one ('aux_classifier')

# -- Main Head --
# Input channels are read from the existing final conv rather than hard-coded.
# NOTE(review): torchvision's FCNHead shrinks the channel count before its final
# conv, so this is presumably 512 (not the backbone's 2048) — confirm.
in_features = model.classifier[4].in_channels
model.classifier[4] = nn.Conv2d(in_features, NUM_CLASSES, kernel_size=1)

# -- Auxiliary Head (its output is added to the training loss with a smaller weight) --
# NOTE(review): presumably 256 here (not 1024), for the same reason — confirm.
in_features_aux = model.aux_classifier[4].in_channels
model.aux_classifier[4] = nn.Conv2d(in_features_aux, NUM_CLASSES, kernel_size=1)

# 3. Move the model to the selected device (GPU when available)
model = model.to(device)

# Optimizer & Loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Pixels labelled 255 are excluded from the loss; the training loop rewrites
# out-of-range labels to 255 for this purpose.
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Data loaders over the datasets built above; only the training loader shuffles.
# num_workers=0 means samples are loaded in the main process.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"FCN-ResNet50 model ready with {NUM_CLASSES} output classes.")

Initializing Full COCO Dataset...
loading annotations into memory...
Done (t=16.72s)
creating index...
index created!
loading annotations into memory...
Done (t=0.56s)
creating index...
index created!
FCN-ResNet50 model ready with 81 output classes.


# Training loop

In [None]:
# Fine-tune the model for NUM_EPOCHS epochs: one training pass and one
# validation pass per epoch, then save a checkpoint of the model weights.
print("Starting FCN-ResNet50 Training...")

# Largest valid class index. Defined once, before any loop: it is
# loop-invariant, and the validation pass needs it even when the training
# loader yields no batches (previously that case raised NameError).
mask_max = NUM_CLASSES - 1
# Label value the loss ignores; must match the criterion's ignore_index.
ignore_label = 255

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    # Progress bar, refreshed at most every 5 seconds
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", mininterval=5.0)

    for i, (images, masks) in enumerate(loop):
        images = images.to(device)
        masks = masks.to(device)

        # Safety patch: any label outside the valid range is excluded from the loss
        masks[masks > mask_max] = ignore_label

        optimizer.zero_grad()

        # Model returns a dict: {'out': tensor, 'aux': tensor}
        output_dict = model(images)
        output = output_dict['out']
        aux_output = output_dict['aux']

        # Training loss = main loss + down-weighted auxiliary loss
        loss_main = criterion(output, masks)
        loss_aux = criterion(aux_output, masks)
        loss = loss_main + 0.4 * loss_aux

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Update the progress-bar readout only every 50 batches
        if i % 50 == 0:
            loop.set_postfix(loss=loss.item())

    # Guard against an empty loader so the average never divides by zero
    avg_train_loss = running_loss / max(len(train_loader), 1)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for v_img, v_mask in val_loader:
            v_img = v_img.to(device)
            v_mask = v_mask.to(device)
            v_mask[v_mask > mask_max] = ignore_label

            # Validation uses only the main head's output
            val_out = model(v_img)['out']
            val_loss += criterion(val_out, v_mask).item()

    avg_val_loss = val_loss / max(len(val_loader), 1)

    # Note: the train figure includes the 0.4-weighted auxiliary loss while the
    # val figure is main-head only, so the two are not directly comparable.
    print(f"Epoch {epoch+1} Results | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Save the model weights after every epoch (the file is large, ~100MB+)
    torch.save(model.state_dict(), f'fcn_resnet50_epoch_{epoch+1}.pth')

print("Training Complete.")

Starting FCN-ResNet50 Training...


Epoch 1/10: 100%|████████████████████████████████████████████████████| 29317/29317 [47:57<00:00, 10.19it/s, loss=0.571]


Epoch 1 Results | Train Loss: 1.0346 | Val Loss: 0.6243


Epoch 2/10: 100%|████████████████████████████████████████████████████| 29317/29317 [39:25<00:00, 12.40it/s, loss=0.767]


Epoch 2 Results | Train Loss: 0.8080 | Val Loss: 0.5964


Epoch 3/10:  42%|█████████████████████▉                              | 12369/29317 [14:17<19:54, 14.19it/s, loss=0.248]

Initialize weights and start training loop