In this second experiment, I build a minimum viable product for another use case: urban-scene semantic segmentation for autonomous driving (Cityscapes). One model is trained with weak supervision (image-level labels only), while the other uses a fully supervised approach (pixel-level masks). Code to load the saved models and run a more detailed comparison between the two will be provided here:

Problem: downloading the datasets requires a lot of storage space.

## Weakly supervised approach

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Number of semantic classes used for training (the 19-class Cityscapes
# "trainId" label set).
num_classes = 19

# Custom dataset wrapper for weak supervision.
# It uses the Cityscapes 'fine' segmentation masks to derive image-level
# multi-label targets (which classes are present), discarding pixel layout.
class CityscapesWeak(Dataset):
    """Cityscapes wrapper yielding ``(image, multi_hot_labels)`` pairs.

    Args:
        root, split, mode, target_type: forwarded to ``datasets.Cityscapes``.
        transform: optional callable applied to the image exactly once.
        num_classes: length of the multi-hot label vector.
        remap_to_train_ids: when True (default), raw Cityscapes labelIds
            (0..33) are mapped to trainIds 0..18 (void -> 255) before the
            labels are built. Pass False if the masks on disk are already
            stored as trainIds.
    """

    # Raw Cityscapes labelId -> trainId lookup (cityscapesScripts labels.py):
    # the listed labelIds map, in order, to trainIds 0..18; all others are
    # void and map to the ignore value 255.
    _TRAIN_ID_LUT = np.full(256, 255, dtype=np.int64)
    _TRAIN_ID_LUT[[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]] = np.arange(19)

    def __init__(self, root, split="train", mode="fine", target_type="semantic", transform=None,
                 num_classes=num_classes, remap_to_train_ids=True):
        # The transform is deliberately NOT passed to the inner dataset: it is
        # applied in __getitem__, and applying it twice crashes (ToTensor
        # rejects an already-converted tensor).
        self.cityscapes = datasets.Cityscapes(root=root, split=split, mode=mode,
                                              target_type=target_type)
        self.transform = transform
        self.num_classes = num_classes
        self.remap_to_train_ids = remap_to_train_ids

    def __len__(self):
        return len(self.cityscapes)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and a float32 multi-hot vector."""
        image, mask = self.cityscapes[idx]
        mask_np = np.asarray(mask)
        if self.remap_to_train_ids:
            # Clip so any out-of-table value lands on a void entry instead of
            # raising IndexError.
            mask_np = self._TRAIN_ID_LUT[np.clip(mask_np, 0, 255)]
        present = np.unique(mask_np)
        # Keep only valid class indices; 255 is the ignore label.
        present = present[(present != 255) & (present >= 0) & (present < self.num_classes)]
        labels = np.zeros(self.num_classes, dtype=np.float32)
        labels[present.astype(np.int64)] = 1.0
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.from_numpy(labels)

# Image preprocessing for the weak (classification) model: squash every frame
# to 224x224 and normalise with ImageNet channel statistics.
transform_img = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Weakly supervised training data: images paired with multi-hot class vectors.
batch_size = 8
cityscapes_weak = CityscapesWeak(
    root='./data/cityscapes',
    split="train",
    mode="fine",
    target_type="semantic",
    transform=transform_img,
)
weak_loader = DataLoader(cityscapes_weak, batch_size=batch_size, shuffle=True, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 classifier built without pretrained weights.
# NOTE(review): the fully supervised model starts from pretrained weights
# while this one does not, which may bias the comparison - confirm intended.
model_ws = models.resnet50(pretrained=False)
# Swap the stock classification head for a num_classes-way multi-label head.
backbone_out_features = model_ws.fc.in_features
model_ws.fc = nn.Linear(backbone_out_features, num_classes)
model_ws = model_ws.to(device)

# One independent sigmoid/BCE term per class (multi-label targets).
criterion_ws = nn.BCEWithLogitsLoss()
optimizer_ws = optim.Adam(model_ws.parameters(), lr=1e-3)

# Weakly supervised training: minimise the multi-label BCE loss on
# image-level targets, reporting the mean batch loss per epoch.
num_epochs = 5  # Adjust as needed.
ws_checkpoint_path = "cityscapes_weak_classification.pth"
model_ws.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_imgs, batch_targets in weak_loader:
        batch_imgs = batch_imgs.to(device)
        batch_targets = batch_targets.to(device)
        optimizer_ws.zero_grad()
        logits = model_ws(batch_imgs)  # (B, num_classes)
        batch_loss = criterion_ws(logits, batch_targets)
        batch_loss.backward()
        optimizer_ws.step()
        running_loss += batch_loss.item()
    print(f"Weak Sup Epoch [{epoch+1}/{num_epochs}] - Loss: {running_loss/len(weak_loader):.4f}")

# Persist only the learned parameters (state_dict).
torch.save(model_ws.state_dict(), ws_checkpoint_path)
print("Weakly supervised model trained and saved.")

# --- Generating CAM from the Trained Weakly Supervised Model ---

# Latest activation captured by hook_feature_ws (None until a hooked forward
# pass has run).
feature_maps_ws = None
def hook_feature_ws(module, input, output):
    """Forward hook: store a detached copy of ``output`` in ``feature_maps_ws``.

    ``module`` and ``input`` belong to PyTorch's hook signature and are unused.
    """
    global feature_maps_ws
    captured = output.detach()
    feature_maps_ws = captured

# Attach the capture hook to model_ws.layer4; the returned handle is kept so
# the hook can be detached later with `_ws_hook_handle.remove()`.
_ws_hook_handle = model_ws.layer4.register_forward_hook(hook_feature_ws)

def generate_cam_ws(model, img_tensor, target_class):
    """Compute a Class Activation Map (CAM) for one class.

    The CAM is the ReLU of the ``fc``-weighted sum of the ``layer4`` feature
    maps, bilinearly resized to the input's spatial size and min-max scaled.

    Args:
        model: network exposing ``layer4`` and a linear ``fc`` whose weight
            rows index classes.
        img_tensor: input batch of shape (B, C, H, W); only the first image
            is used.
        target_class: row of ``model.fc.weight`` to visualise.

    Returns:
        float32 numpy array of shape (H, W) with values in [0, 1]; all zeros
        if the class has no positive activation.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    captured = []
    # Temporary local hook: no dependence on a globally registered hook or
    # on the global `feature_maps_ws`; removed even if the forward fails.
    handle = model.layer4.register_forward_hook(
        lambda module, inputs, output: captured.append(output.detach()))
    try:
        # Run on the model's own device instead of a module-level `device`.
        model_device = next(model.parameters()).device
        with torch.no_grad():
            model(img_tensor.to(model_device))
    finally:
        handle.remove()

    fmap = captured[-1][0]  # (channels, h, w) for the first image
    weights = model.fc.weight.detach()[target_class].to(fmap.dtype)  # (channels,)
    # Channel-weighted sum in one vectorised op (replaces a per-channel loop).
    cam = torch.einsum("c,chw->hw", weights, fmap).clamp(min=0)
    # Bilinear resize to the input resolution using torch, so OpenCV is no
    # longer required.
    cam = nn.functional.interpolate(
        cam[None, None], size=tuple(img_tensor.shape[-2:]),
        mode="bilinear", align_corners=False)[0, 0]
    cam = cam.cpu().numpy().astype(np.float32)
    cam -= cam.min()
    peak = cam.max()
    if peak != 0:
        cam /= peak
    return cam

# Visualize CAM for a sample image.
sample_img, sample_labels = cityscapes_weak[0]
sample_tensor = sample_img.unsqueeze(0)
# For demonstration, pick a target class (e.g., class 0). In practice, choose based on predicted output.
target_class = 0
cam_output = generate_cam_ws(model_ws, sample_tensor, target_class)

# Undo the ImageNet normalisation applied by transform_img so the image is
# shown with real colours (imshow expects floats in [0, 1]).
_disp_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_disp_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_disp = (sample_img.cpu() * _disp_std + _disp_mean).clamp(0, 1).permute(1, 2, 0).numpy()

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(img_disp)
plt.title("Sample Image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(cam_output, cmap='jet')
# Title reflects the class actually visualised.
plt.title(f"CAM for Class {target_class}")
plt.axis("off")
plt.show()


## Fully supervised approach

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Define image transformation for Cityscapes images.
transform_img = transforms.Compose([
    # torchvision's Resize takes (height, width), whereas PIL's Image.resize in
    # mask_transform takes (width, height) = (512, 256). (256, 512) makes the
    # image 256x512 like the mask; the former (512, 256) produced 512x256
    # images against 256x512 masks, which breaks the pixel-wise loss.
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Raw Cityscapes labelId -> 19-class trainId lookup (cityscapesScripts
# labels.py): the listed labelIds map, in order, to trainIds 0..18; every
# other id is void and maps to 255, the loss's ignore_index.
_CITYSCAPES_TRAIN_ID_LUT = np.full(256, 255, dtype=np.int64)
_CITYSCAPES_TRAIN_ID_LUT[[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]] = np.arange(19)

def mask_transform(mask, remap_to_train_ids=True):
    """Resize a Cityscapes label mask to 512x256 (W x H) and return int64 labels.

    Args:
        mask: PIL image of per-pixel label ids.
        remap_to_train_ids: when True (default), convert raw labelIds
            (0..33, what torchvision's Cityscapes loads for
            ``target_type='semantic'``) to trainIds 0..18, with 255 for void.
            Pass False if the masks are already stored as trainIds.

    Returns:
        torch.int64 tensor of shape (256, 512).
    """
    # Nearest-neighbour resampling never invents new label values.
    mask = mask.resize((512, 256), resample=Image.NEAREST)
    mask_np = np.array(mask).astype(np.int64)
    if remap_to_train_ids:
        # Without remapping, ids up to 33 reach a 19-class CrossEntropyLoss
        # and raise an out-of-range target error. Clip keeps indexing safe.
        mask_np = _CITYSCAPES_TRAIN_ID_LUT[np.clip(mask_np, 0, 255)]
    return torch.from_numpy(mask_np)

# Fully supervised training data: (normalised image, per-pixel mask) pairs.
batch_size = 4
cityscapes_dataset = datasets.Cityscapes(
    root='./data/cityscapes',
    split='train',
    mode='fine',
    target_type='semantic',
    transform=transform_img,
    target_transform=mask_transform,
)
cityscapes_loader = DataLoader(cityscapes_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DeepLabV3 with a ResNet-50 backbone, initialised from pretrained weights.
model_fs = models.segmentation.deeplabv3_resnet50(pretrained=True)
# Fresh 19-class head (in_channels=2048).
# NOTE(review): any auxiliary head of the pretrained model is left untouched
# and its 'aux' output is never used by the training loop - confirm intended.
model_fs.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, 19)
model_fs = model_fs.to(device)

# Pixel-wise cross-entropy; pixels labelled 255 are excluded from the loss.
criterion_fs = nn.CrossEntropyLoss(ignore_index=255)
optimizer_fs = optim.Adam(model_fs.parameters(), lr=1e-4)
num_epochs = 10  # Adjust epochs as needed.

def train_one_epoch_fs(model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader``; return the mean batch loss.

    ``model(images)`` must return a dict whose ``'out'`` entry holds the
    logits that ``criterion`` compares with the batch masks. The model is
    left in training mode.
    """
    model.train()
    batch_losses = []
    for batch_imgs, batch_masks in dataloader:
        optimizer.zero_grad()
        logits = model(batch_imgs.to(device))['out']
        loss = criterion(logits, batch_masks.to(device))
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(dataloader)

# Fully supervised training: one train_one_epoch_fs call per epoch.
fs_checkpoint_path = "cityscapes_fully_supervised.pth"
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch_fs(
        model=model_fs,
        dataloader=cityscapes_loader,
        optimizer=optimizer_fs,
        criterion=criterion_fs,
        device=device,
    )
    print(f"Fully Sup Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")

# Persist only the learned parameters (state_dict).
torch.save(model_fs.state_dict(), fs_checkpoint_path)
print("Fully supervised model trained and saved.")

# --- Visualization of a Segmentation Prediction ---
def visualize_fs_prediction(model, dataset, index=0):
    """Plot the input image, ground-truth mask and predicted mask side by side.

    Args:
        model: segmentation model returning a dict whose ``'out'`` entry is a
            (1, num_classes, H, W) logits tensor.
        dataset: indexable yielding ``(image, mask)``, where ``image`` is a
            (3, H, W) tensor normalised with ImageNet mean/std.
        index: dataset item to visualise.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    image, true_mask = dataset[index]
    # Use the model's own device rather than a module-level `device` global.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(model_device))['out']
    # Class axis of the single-image batch; robust even if num_classes == 1.
    pred_mask = torch.argmax(output[0], dim=0).cpu().numpy()
    n_classes = output.shape[1]

    # Undo the ImageNet normalisation (previously built but never applied)
    # and clamp into imshow's valid [0, 1] float range.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image_disp = (image.cpu() * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    # Hide ignore pixels (255). Draw both masks on one fixed colour scale so
    # equal class ids get equal shades.
    gt_disp = np.ma.masked_equal(np.asarray(true_mask), 255)
    mask_kwargs = dict(cmap='gray', vmin=0, vmax=max(n_classes - 1, 1), interpolation='nearest')

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(image_disp)
    axs[0].set_title("Input Image")
    axs[0].axis("off")
    axs[1].imshow(gt_disp, **mask_kwargs)
    axs[1].set_title("Ground Truth Mask")
    axs[1].axis("off")
    axs[2].imshow(pred_mask, **mask_kwargs)
    axs[2].set_title("Predicted Mask")
    axs[2].axis("off")
    plt.tight_layout()
    plt.show()

# Show input, ground truth and prediction for the first training sample.
visualize_fs_prediction(model=model_fs, dataset=cityscapes_dataset, index=0)
