<a href="https://colab.research.google.com/github/Kristina-26/DEEP-LEARNING-TASK-1-Kazlauskaite/blob/main/DL_task1_Kazlauskaite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Author: Kristina Kazlauskaitė

LSP: S2416112

Data Science (full-time studies), group 1

In [14]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
import torchvision
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm import tqdm

In [15]:
NUM_CLASSES = 4  # 'background', 'aeroplane', 'sofa', 'dog'
CLASSES = ['background', 'aeroplane', 'sofa', 'dog']
COLORS = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]  # For visualization
EPOCHS = 15
BATCH_SIZE = 2
LEARNING_RATE = 0.00005

In [16]:
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cpu


In [17]:
# DeepLabv3 model with ResNet50 backbone for semantic segmentation
class SimpleSegmentationModel(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(SimpleSegmentationModel, self).__init__()
        self.model = torchvision.models.segmentation.deeplabv3_resnet50(weights=torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT)
        self.model.classifier[-1] = nn.Conv2d(256, num_classes, kernel_size=1) # classifier layer with our number of classes (4)

    def forward(self, x):
        return self.model(x)['out']

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], # using the mean and SD of ImageNet dataset
                         std=[0.229, 0.224, 0.225])
])

In [18]:
def mask_transform(mask):
    mask = mask.resize((256, 256), resample=Image.NEAREST)  # resize mask safely
    mask = np.array(mask)

    # create a new mask with our classes
    new_mask = np.zeros_like(mask)

    # VOC class IDs (0-indexed):
    # aeroplane: 1, sofa: 18, dog: 12, background: 0
    if 1 in mask:
        print("Found aeroplane in mask")
    if 18 in mask:
        print("Found sofa in mask")
    if 12 in mask:
        print("Found dog in mask")

    new_mask[mask == 1] = 1  # aeroplane
    new_mask[mask == 18] = 2  # sofa
    new_mask[mask == 12] = 3  # dog

    return torch.tensor(new_mask, dtype=torch.long)

In [19]:
# filter VOC to only have the selected classes (aeroplane, sofa and dog)
class VOCMultiClassFiltered(Dataset):
    def __init__(self, root, year='2012', image_set='train', transform=None, target_transform=None):
        self.voc = VOCSegmentation(root=root, year=year, image_set=image_set, download=True) # root: root directory,
        # year: VOC dataset year, image_set: 'train'/ 'val'
        self.transform = transform # image transformations
        self.target_transform = target_transform # mask transformations

        # checking if masks contain these classes
        self.valid_indices = []
        print("Checking for valid images with classes of interest")
        for idx in range(min(3000, len(self.voc))):  # checking not more than 3000 images
            _, target = self.voc[idx]
            mask = np.array(target)
            # looking for aeroplane (0), sofa (17), or dog (11)
            if 1 in mask or 18 in mask or 12 in mask:
                self.valid_indices.append(idx)
                if len(self.valid_indices) % 10 == 0: # log message every 10th time valid image is found
                    print(f"Found {len(self.valid_indices)} valid images so far")

        print(f"Found {len(self.valid_indices)} images with classes of interest")

        # class distribution in valid images
        class_counts = {1: 0, 18: 0, 12: 0}
        for idx in self.valid_indices:
            _, target = self.voc[idx]
            mask = np.array(target)
            for class_id in class_counts.keys():
                if class_id in mask:
                    class_counts[class_id] += 1

        print("Class distribution in valid images:")
        print(f"Aeroplane (1): {class_counts[1]} images")
        print(f"Sofa (18): {class_counts[18]} images")
        print(f"Dog (12): {class_counts[12]} images")

    def __len__(self):
        return len(self.voc)

    def __getitem__(self, idx):
        image, target = self.voc[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target

In [20]:
dataset = VOCMultiClassFiltered(root='./data', transform=transform, target_transform=mask_transform)

# use valid indices that contain our classes, or take a subset if we have enough
if len(dataset.valid_indices) > 500:
    # at least 1000 images
    subset_size = min(1000, len(dataset.valid_indices))
    subset = Subset(dataset, dataset.valid_indices[:subset_size])
else:
    subset = Subset(dataset, range(1000))

Checking for valid images with classes of interest
Found 10 valid images so far
Found 20 valid images so far
Found 30 valid images so far
Found 40 valid images so far
Found 50 valid images so far
Found 60 valid images so far
Found 70 valid images so far
Found 80 valid images so far
Found 90 valid images so far
Found 100 valid images so far
Found 110 valid images so far
Found 120 valid images so far
Found 130 valid images so far
Found 140 valid images so far
Found 150 valid images so far
Found 160 valid images so far
Found 170 valid images so far
Found 180 valid images so far
Found 190 valid images so far
Found 200 valid images so far
Found 210 valid images so far
Found 220 valid images so far
Found 230 valid images so far
Found 240 valid images so far
Found 250 valid images so far
Found 260 valid images so far
Found 270 valid images so far
Found 280 valid images so far
Found 290 valid images so far
Found 291 images with classes of interest
Class distribution in valid images:
Aeroplane 

In [21]:
train_size = int(0.8 * len(subset))
val_size = len(subset) - train_size
train_dataset, val_dataset = random_split(subset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [22]:
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cpu


In [23]:
model = SimpleSegmentationModel().to(device)

# add class weighting to handle class imbalance (background usually dominates)
# gives more importance to the minority classes
class_weights = torch.tensor([0.1, 1.5, 1.5, 1.5]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [24]:
# training with validation after each epoch
best_val_loss = float('inf')
train_losses = []
val_losses = []
patience = 3  # number of epochs to wait after validation loss and train loss divergence before stopping
early_stop_counter = 0
diverging = False

for epoch in range(EPOCHS):
    # training
    model.train()
    total_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Training"):
        images, masks = images.to(device), masks.to(device)

        # checking if masks contain our classes
        mask_classes = torch.unique(masks)
        if len(mask_classes) == 1 and mask_classes[0] == 0:
            print("Warning: This batch only contains background class")
            continue  # skip batch if it only contains background

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} - Validation"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    # store losses for tracking divergence
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    # saving the model if it's the best so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_segmentation_model.pth')
        print(f"Saved new best model with validation loss: {best_val_loss:.4f}")
        early_stop_counter = 0  # Reset counter when we find a better model
    else:
        early_stop_counter += 1
        print(f"Validation loss did not improve for {early_stop_counter} epoch(s)")

    # sheck for divergence (train loss decreasing while val loss increasing)
    if len(train_losses) >= 2:
        if train_losses[-1] < train_losses[-2] and val_losses[-1] > val_losses[-2]:
            diverging = True
            print(f"Warning: Training and validation losses are diverging (possible overfitting)")
        else:
            diverging = False

    # early stopping condition
    if diverging and early_stop_counter >= patience:
        print(f"Early stopping triggered after epoch {epoch+1} due to loss divergence")
        print(f"Using best model with validation loss: {best_val_loss:.4f}")
        break

Epoch 1/15 - Training:   0%|          | 0/400 [00:00<?, ?it/s]

Found sofa in mask


Epoch 1/15 - Training:   0%|          | 1/400 [00:11<1:13:50, 11.10s/it]

Found aeroplane in mask


Epoch 1/15 - Training:   1%|          | 3/400 [00:22<47:42,  7.21s/it]  

Found dog in mask


Epoch 1/15 - Training:   1%|▏         | 5/400 [00:34<42:14,  6.42s/it]

Found aeroplane in mask


Epoch 1/15 - Training:   2%|▏         | 6/400 [00:44<49:58,  7.61s/it]

Found dog in mask


Epoch 1/15 - Training:   2%|▏         | 8/400 [00:55<42:56,  6.57s/it]

Found aeroplane in mask
Found aeroplane in mask


Epoch 1/15 - Training:   4%|▎         | 14/400 [01:06<18:26,  2.87s/it]

Found aeroplane in mask
Found aeroplane in mask
Found dog in mask


Epoch 1/15 - Training:   5%|▍         | 19/400 [01:30<23:40,  3.73s/it]

Found sofa in mask


Epoch 1/15 - Training:   6%|▌         | 22/400 [01:41<23:30,  3.73s/it]

Found sofa in mask
Found dog in mask


Epoch 1/15 - Training:   6%|▌         | 24/400 [01:52<25:56,  4.14s/it]

Found aeroplane in mask
Found sofa in mask


Epoch 1/15 - Training:   8%|▊         | 31/400 [02:04<15:03,  2.45s/it]

Found dog in mask
Found sofa in mask
Found dog in mask


Epoch 1/15 - Training:   9%|▉         | 36/400 [02:25<19:09,  3.16s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  11%|█         | 43/400 [02:36<13:01,  2.19s/it]



Epoch 1/15 - Training:  12%|█▏        | 48/400 [02:36<08:14,  1.41s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  13%|█▎        | 53/400 [02:48<09:50,  1.70s/it]

Found dog in mask


Epoch 1/15 - Training:  14%|█▎        | 54/400 [02:59<14:37,  2.54s/it]

Found dog in mask


Epoch 1/15 - Training:  14%|█▍        | 55/400 [03:10<20:09,  3.51s/it]

Found dog in mask


Epoch 1/15 - Training:  14%|█▍        | 56/400 [03:21<26:15,  4.58s/it]

Found sofa in mask


Epoch 1/15 - Training:  14%|█▍        | 57/400 [03:32<31:59,  5.60s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  15%|█▍        | 59/400 [03:42<31:03,  5.46s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  17%|█▋        | 68/400 [03:54<12:35,  2.28s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  18%|█▊        | 73/400 [04:05<12:22,  2.27s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  19%|█▉        | 75/400 [04:16<15:34,  2.88s/it]

Found sofa in mask


Epoch 1/15 - Training:  19%|█▉        | 77/400 [04:27<18:25,  3.42s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  20%|█▉        | 79/400 [04:39<21:42,  4.06s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  20%|██        | 82/400 [04:50<20:47,  3.92s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  21%|██        | 83/400 [05:00<24:57,  4.72s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  23%|██▎       | 91/400 [05:11<11:55,  2.32s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  24%|██▍       | 95/400 [05:22<12:39,  2.49s/it]

Found sofa in mask


Epoch 1/15 - Training:  25%|██▍       | 99/400 [05:34<12:59,  2.59s/it]

Found aeroplane in mask


Epoch 1/15 - Training:  27%|██▋       | 107/400 [05:45<08:40,  1.78s/it]

Found dog in mask


Epoch 1/15 - Training:  28%|██▊       | 111/400 [05:56<09:57,  2.07s/it]

Found dog in mask
Found aeroplane in mask


Epoch 1/15 - Training:  28%|██▊       | 113/400 [06:07<12:50,  2.69s/it]

Found sofa in mask


Epoch 1/15 - Training:  28%|██▊       | 114/400 [06:17<16:57,  3.56s/it]

Found sofa in mask


Epoch 1/15 - Training:  29%|██▉       | 116/400 [06:18<15:26,  3.26s/it]


KeyboardInterrupt: 

In [None]:
# load best model for evaluation
model.load_state_dict(torch.load('best_segmentation_model.pth'))

In [None]:
def evaluate_model(model, dataloader, device, num_classes):
    model.eval()
    all_preds = []
    all_labels = []

    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            # calculate per-class accuracy
            for c in range(num_classes):
                class_mask = (masks == c)
                class_correct[c] += (preds[class_mask] == c).sum().item()
                class_total[c] += class_mask.sum().item()

            preds = preds.cpu().numpy()
            masks = masks.cpu().numpy()
            all_preds.extend(preds.flatten())
            all_labels.extend(masks.flatten())

    # per-class accuracy
    class_accuracy = []
    for i in range(num_classes):
        if class_total[i] > 0:
            accuracy = class_correct[i] / class_total[i]
            class_accuracy.append(accuracy)
            print(f"Class {CLASSES[i]} accuracy: {accuracy:.4f}")
        else:
            class_accuracy.append(0)
            print(f"Class {CLASSES[i]} has no samples")

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, average='weighted', labels=range(num_classes), zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='weighted', labels=range(num_classes), zero_division=0),
        'f1': f1_score(all_labels, all_preds, average='weighted', labels=range(num_classes), zero_division=0),
        'confusion_matrix': confusion_matrix(all_labels, all_preds, labels=range(num_classes)),
        'class_accuracy': class_accuracy
    }
    return metrics

In [None]:
metrics = evaluate_model(model, val_loader, device, NUM_CLASSES)

print("\nModel Performance Metrics:")
for k, v in metrics.items():
    if k != 'confusion_matrix' and k != 'class_accuracy':
        print(f"{k.capitalize()}: {v:.4f}")

print("\nConfusion Matrix:")
print(metrics['confusion_matrix'])

In [None]:
def visualize_predictions(model, dataset, device, num_samples=3):
    model.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for idx in indices:
        image, true_mask = dataset[idx]
        image_batch = image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image_batch)
            pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

        # de-normalize image for visualization
        image_np = image.permute(1, 2, 0).numpy()
        image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np, 0, 1)

        def decode_mask(mask):
            color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
            for class_id, color in enumerate(COLORS):
                color_mask[mask == class_id] = color
            return color_mask

        # print unique class values for debugging
        true_classes = np.unique(true_mask.numpy())
        pred_classes = np.unique(pred_mask)
        print(f"True mask classes: {true_classes}")
        print(f"Predicted mask classes: {pred_classes}")

        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        plt.imshow(image_np)
        plt.title("Image")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(decode_mask(true_mask.numpy()))
        plt.title("Ground Truth")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        plt.imshow(decode_mask(pred_mask))
        plt.title("Prediction")
        plt.axis("off")

        plt.tight_layout()
        plt.show()

In [None]:
print("\nVisualizing predictions:")
visualize_predictions(model, val_dataset, device, num_samples=3)

In [None]:
!pip install fiftyone

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# download OpenImages
print("Downloading OpenImages with class labels...")
dataset = foz.load_zoo_dataset(
    "open-images-v6",
    split="validation",
    label_types=["detections"], # to get ground truth class labels
    classes=["Dog", "Sofa bed", "Airplane"],
    max_samples=100,
    shuffle=True
)

# keep track of which classes exist in each image
class OpenImagesClassDataset(Dataset):
    def __init__(self, fo_dataset, transform=None):
        self.samples = []
        self.class_mapping = {"Airplane": 1, "Sofa bed": 2, "Dog": 3}
        self.transform = transform

        # get ground truth class labels
        for sample in fo_dataset:
            image_path = sample.filepath

            gt_classes = set()
            if hasattr(sample, "ground_truth") and sample.ground_truth is not None:
                if hasattr(sample.ground_truth, "detections") and sample.ground_truth.detections is not None:
                    for detection in sample.ground_truth.detections:
                        class_name = detection.label
                        if class_name in self.class_mapping:
                            gt_classes.add(self.class_mapping[class_name])

            # store sample data
            self.samples.append({
                'path': image_path,
                'gt_classes': gt_classes,
                'filename': os.path.basename(image_path)
            })

        # images per class
        class_counts = {1: 0, 2: 0, 3: 0}
        for sample in self.samples:
            for class_id in sample['gt_classes']:
                class_counts[class_id] += 1

        print(f"Processed {len(self.samples)} images with ground truth class labels")
        print(f"Class distribution in dataset:")
        print(f"- Airplane: {class_counts[1]} images")
        print(f"- Sofa bed: {class_counts[2]} images")
        print(f"- Dog: {class_counts[3]} images")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image_path = sample['path']
        gt_classes = sample['gt_classes']

        # load and process image
        image = Image.open(image_path).convert('RGB')
        image = image.resize((256, 256), Image.BILINEAR)

        if self.transform:
            image = self.transform(image)

        # convert ground truth to multi-hot encoding
        gt_labels = np.zeros(4)  # [background, airplane, sofa bed, dog]
        for class_id in gt_classes:
            gt_labels[class_id] = 1

        return image, torch.tensor(gt_labels), sample['filename']

# dataset and dataloader
openimages_dataset = OpenImagesClassDataset(dataset, transform=transform)
openimages_loader = DataLoader(openimages_dataset, batch_size=1, shuffle=False)

# 4. Evaluate model's class detection performance
def evaluate_class_detection(model, dataloader, device, num_classes=4):
    model.eval()

    # storage for predictions and ground truth
    all_preds = []
    all_gt = []
    results = []

    with torch.no_grad():
        for image, gt_labels, filename in dataloader:
            image = image.to(device)

            # run model
            output = model(image)
            pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

            # convert segmentation to class detection (if any pixel is predicted as class X, then class X is detected)
            pred_classes = np.zeros(num_classes)
            for class_id in range(num_classes):
                if np.any(pred_mask == class_id):
                    pred_classes[class_id] = 1

            all_preds.append(pred_classes)
            all_gt.append(gt_labels.numpy().squeeze())

            # Store detailed result
            results.append({
                'filename': filename[0],
                'predicted_classes': [i for i in range(num_classes) if pred_classes[i] == 1],
                'ground_truth_classes': [i for i in range(num_classes) if gt_labels[0, i] == 1]
            })

    # convert to np arrays
    all_preds = np.array(all_preds)
    all_gt = np.array(all_gt)

    # per-class metrics
    per_class_metrics = []
    class_names = ['background', 'airplane', 'sofa bed', 'dog']

    for c in range(num_classes):
        class_preds = all_preds[:, c]
        class_gt = all_gt[:, c]

        # calculate metrics (handle case where a class might not be present)
        if np.any(class_gt):
            accuracy = accuracy_score(class_gt, class_preds)
            precision = precision_score(class_gt, class_preds, zero_division=0)
            recall = recall_score(class_gt, class_preds, zero_division=0)
            f1 = f1_score(class_gt, class_preds, zero_division=0)
        else:
            accuracy, precision, recall, f1 = 0, 0, 0, 0

        per_class_metrics.append({
            'class': class_names[c],
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        })

    # aggregate metrics
    aggregate_metrics = {
        'accuracy': accuracy_score(all_gt.flatten(), all_preds.flatten()),
        'precision': precision_score(all_gt.flatten(), all_preds.flatten(), zero_division=0),
        'recall': recall_score(all_gt.flatten(), all_preds.flatten(), zero_division=0),
        'f1': f1_score(all_gt.flatten(), all_preds.flatten(), zero_division=0)
    }

    return {
        'per_class': per_class_metrics,
        'aggregate': aggregate_metrics,
        'details': results
    }

In [9]:
#  evaluation
print("\nEvaluating class detection on OpenImages:")
metrics = evaluate_class_detection(model, openimages_loader, device)

# display results
print("\nAggregate Metrics:")
for metric, value in metrics['aggregate'].items():
    print(f"{metric.capitalize()}: {value:.4f}")

print("\nPer-Class Metrics:")
for class_metrics in metrics['per_class']:
    class_name = class_metrics['class']
    print(f"\nClass: {class_name}")
    for metric, value in class_metrics.items():
        if metric != 'class':
            print(f"  {metric.capitalize()}: {value:.4f}")


Evaluating class detection on OpenImages:


NameError: name 'evaluate_class_detection' is not defined

In [None]:
# visualize some examples
def visualize_class_detection_results(model, dataloader, device, num_samples=5):
    model.eval()
    class_names = ['background', 'airplane', 'sofa bed', 'dog']

    samples = []
    for data in dataloader:
        samples.append(data)
        if len(samples) >= num_samples:
            break

    for image, gt_labels, filename in samples:
        image = image.to(device)

        output = model(image)
        pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

        pred_classes = []
        for class_id in range(1, 4):  # skip background
            if np.any(pred_mask == class_id):
                pred_classes.append(class_names[class_id])

        # get ground truth classes
        gt_classes = []
        for class_id in range(1, 4):
            if gt_labels[0, class_id] == 1:
                gt_classes.append(class_names[class_id])

        img_np = image.squeeze().permute(1, 2, 0).cpu().numpy()
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)

        # colorize segmentation
        color_mask = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)
        for class_id, color in enumerate(COLORS):
            color_mask[pred_mask == class_id] = color

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.imshow(img_np)
        plt.title(f"Image: {filename[0]}\nGround Truth Classes: {', '.join(gt_classes)}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(color_mask)
        plt.title(f"Segmentation\nPredicted Classes: {', '.join(pred_classes)}")
        plt.axis('off')

        match = set(gt_classes) == set(pred_classes)
        plt.suptitle(f"Class Detection: {'Correct' if match else '✗ Incorrect'}",
                     color='green' if match else 'red', fontsize=16)

        plt.tight_layout()
        plt.show()

In [None]:
print("\nVisualizing class detection results:")
visualize_class_detection_results(model, openimages_loader, device)

In [None]:
!pip install -q git+https://github.com/facebookresearch/segment-anything.git
!pip install -q opencv-python matplotlib
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
import torchvision.transforms as T
import cv2
from segment_anything import sam_model_registry, SamPredictor

# load SAM model
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam)

In [None]:
def compare_with_sam(dataset, num_samples=10, model=None, device=None, compute_metrics=True):
    if model is None or device is None:
        print("Model and device must be provided")
        return

    # storage for metrics
    metrics = {
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
        'iou': []
    }

    # get random indices
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # track average metrics
    all_metrics = []

    for idx in indices:
        image, _, filename = dataset[idx]

        image_np = image.permute(1, 2, 0).numpy()
        image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np, 0, 1)
        image_bgr = (image_np[..., ::-1] * 255).astype(np.uint8)  # RGB → BGR for OpenCV

        # SAM prediction
        predictor.set_image(image_bgr)
        input_point = np.array([[128, 128]])  # center point
        input_label = np.array([1])
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True
        )
        sam_mask = masks[0]  # first mask

        # our model prediction
        model.eval()
        with torch.no_grad():
            pred = model(image.unsqueeze(0).to(device))
            pred_mask = torch.argmax(pred.squeeze(), dim=0).cpu().numpy()

        # convert multi-class prediction to binary (foreground/background), non-background (class 0) pixel is considered foreground (1)
        our_binary_mask = (pred_mask > 0).astype(np.uint8)

        # convert SAM mask to binary
        sam_binary_mask = sam_mask.astype(np.uint8)

        # metrics between our model and SAM
        if compute_metrics:
            # flatten masks
            our_flat = our_binary_mask.flatten()
            sam_flat = sam_binary_mask.flatten()

            # metrics
            accuracy = accuracy_score(sam_flat, our_flat)
            precision = precision_score(sam_flat, our_flat, zero_division=0)
            recall = recall_score(sam_flat, our_flat, zero_division=0)
            f1 = f1_score(sam_flat, our_flat, zero_division=0)

            # IoU
            intersection = np.logical_and(our_binary_mask, sam_binary_mask).sum()
            union = np.logical_or(our_binary_mask, sam_binary_mask).sum()
            iou = intersection / union if union > 0 else 0

            # metrics
            metrics['accuracy'].append(accuracy)
            metrics['precision'].append(precision)
            metrics['recall'].append(recall)
            metrics['f1'].append(f1)
            metrics['iou'].append(iou)

            all_metrics.append({
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'iou': iou
            })

        # visualize
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(image_np)
        plt.title(f"Image: {filename}")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(sam_binary_mask, cmap='gray')
        plt.title("Segment Anything Mask")
        plt.axis("off")

        plt.subplot(1, 3, 3)

        decoded = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)
        for class_id, color in enumerate(COLORS):
            decoded[pred_mask == class_id] = color
        plt.imshow(decoded)

        # add metrics to the plot title if available
        if compute_metrics and all_metrics:
            curr_metrics = all_metrics[-1]
            plt.title(f"Model\nF1: {curr_metrics['f1']:.3f}, IoU: {curr_metrics['iou']:.3f}")
        else:
            plt.title("Model Prediction")

        plt.axis("off")

        plt.tight_layout()
        plt.show()

    # Calculate average metrics
    if compute_metrics and metrics['accuracy']:
        avg_metrics = {metric: np.mean(values) for metric, values in metrics.items()}

        # Display results
        print("\nComparison Metrics between our model and SAM:")
        print(f"Accuracy:  {avg_metrics['accuracy']:.4f}")
        print(f"Precision: {avg_metrics['precision']:.4f}")
        print(f"Recall:    {avg_metrics['recall']:.4f}")
        print(f"F1 Score:  {avg_metrics['f1']:.4f}")
        print(f"IoU:       {avg_metrics['iou']:.4f}")


        return avg_metrics

    return None

In [None]:
## import metrics
# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# comparison
print("Comparing the model with SAM on OpenImages:")
comparison_metrics = compare_with_sam(openimages_dataset, num_samples=10, model=model, device=device)

In [None]:
import torchvision.transforms as transforms
import os
# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# define class names and colors
CLASSES = ['background', 'aeroplane', 'sofa', 'dog']
COLORS = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]

# function to process all images in a folder
def process_test_folder(model, device, sam_predictor=None, folder_path="test"):
    # check if folder exists
    if not os.path.exists(folder_path):
        print(f"Error: Folder '{folder_path}' not found.")
        print("Please upload the folder.")
        return None

    # get all image files from the folder
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
    image_files = []

    for file in os.listdir(folder_path):
        ext = os.path.splitext(file)[1].lower()
        if ext in image_extensions:
            image_files.append(os.path.join(folder_path, file))

    if not image_files:
        print(f"No image files found in folder '{folder_path}'")
        return None

    print(f"Found {len(image_files)} images in the folder. Processing...")

    # summary metrics for SAM comparison (if available)
    if sam_predictor is not None:
        sam_metrics = {
            'iou': [],
            'accuracy': [],
            'precision': [],
            'recall': [],
            'f1': []
        }

    # process each image
    for i, image_path in enumerate(image_files):
        print(f"\nProcessing image {i+1}/{len(image_files)}: {os.path.basename(image_path)}")

        try:
            # load the image
            image = Image.open(image_path).convert('RGB')

            # show the uploaded image
            plt.figure(figsize=(8, 6))
            plt.imshow(image)
            plt.title(f"Image: {os.path.basename(image_path)}")
            plt.axis('off')
            plt.show()

            # define the same transforms as used during training
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
            ])

            # apply transforms
            image_tensor = transform(image).unsqueeze(0)  # add batch dimension
            image_tensor = image_tensor.to(device)

            # run inference with our model
            model.eval()
            with torch.no_grad():
                output = model(image_tensor)
                pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

            # process for visualization
            image_resized = image.resize((256, 256))
            img_np = np.array(image_resized) / 255.0

            # create colored segmentation mask
            color_mask = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)
            for class_id, color in enumerate(COLORS):
                color_mask[pred_mask == class_id] = color

            # count pixels per class and calculate percentages
            class_pixels = {}
            total_pixels = pred_mask.size
            for idx, class_name in enumerate(CLASSES):
                count = np.sum(pred_mask == idx)
                percentage = 100 * count / total_pixels
                class_pixels[class_name] = f"{percentage:.1f}%"

            # determine the dominant class (excluding background)
            dominant_class = None
            max_percent = 0
            for class_name, percentage in class_pixels.items():
                percent_value = float(percentage.strip('%'))
                if class_name != 'background' and percent_value > max_percent:
                    max_percent = percent_value
                    dominant_class = class_name

            if sam_predictor is not None:
                # prepare image for SAM
                sam_image = img_np.copy()
                sam_image = (sam_image * 255).astype(np.uint8)

                # set image for SAM
                sam_predictor.set_image(sam_image)

                # generate mask using center point
                input_point = np.array([[128, 128]])  # center point
                input_label = np.array([1])  # foreground

                masks, _, _ = sam_predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=True
                )

                # best mask (first one)
                sam_mask = masks[0]

                # metrics between our model and SAM
                your_binary = (pred_mask > 0).astype(np.uint8)  # Any non-background is 1

                # flatten masks for metrics
                your_flat = your_binary.flatten()
                sam_flat = sam_mask.flatten()

                # metrics
                accuracy = accuracy_score(sam_flat, your_flat)
                precision = precision_score(sam_flat, your_flat, zero_division=0)
                recall = recall_score(sam_flat, your_flat, zero_division=0)
                f1 = f1_score(sam_flat, your_flat, zero_division=0)

                # IoU
                intersection = np.logical_and(your_binary, sam_mask).sum()
                union = np.logical_or(your_binary, sam_mask).sum()
                iou = intersection / union if union > 0 else 0

                # store metrics
                sam_metrics['accuracy'].append(accuracy)
                sam_metrics['precision'].append(precision)
                sam_metrics['recall'].append(recall)
                sam_metrics['f1'].append(f1)
                sam_metrics['iou'].append(iou)

                # visualizations with comparison
                plt.figure(figsize=(15, 5))

                # original image
                plt.subplot(1, 3, 1)
                plt.imshow(img_np)
                plt.title(f"Original: {os.path.basename(image_path)}")
                plt.axis('off')

                # model prediction
                plt.subplot(1, 3, 2)
                plt.imshow(color_mask)
                plt.title(f"Our Model\n{dominant_class or 'Background'}: {class_pixels[dominant_class or 'background']}")
                plt.axis('off')

                # SAM prediction
                plt.subplot(1, 3, 3)
                plt.imshow(sam_mask, cmap='gray')
                plt.title(f"SAM\nIoU: {iou:.3f}, F1: {f1:.3f}")
                plt.axis('off')

                plt.suptitle(f"Model Comparison - Image {i+1}/{len(image_files)}", fontsize=14)
                plt.tight_layout()
                plt.subplots_adjust(top=0.85)
                plt.show()

                # Print metrics
                print(f"\nMetrics comparing our model with SAM:")
                print(f"IoU: {iou:.4f}")
                print(f"Accuracy: {accuracy:.4f}")
                print(f"Precision: {precision:.4f}")
                print(f"Recall: {recall:.4f}")
                print(f"F1 Score: {f1:.4f}")

            else:
                plt.figure(figsize=(12, 4))

                # original image
                plt.subplot(1, 3, 1)
                plt.imshow(img_np)
                plt.title(f"Original: {os.path.basename(image_path)}")
                plt.axis('off')

                # segmentation mask
                plt.subplot(1, 3, 2)
                plt.imshow(color_mask)
                plt.title("Segmentation")
                plt.axis('off')

                # blend of image and mask
                plt.subplot(1, 3, 3)
                blend = img_np.copy()
                # semi-transparent blend
                for i in range(3):
                    blend[:,:,i] = blend[:,:,i] * 0.5 + color_mask[:,:,i] / 255.0 * 0.5
                plt.imshow(blend)
                plt.title("Overlay")
                plt.axis('off')

                plt.suptitle(f"Segmentation Results\n" +
                            " | ".join([f"{cls}: {pct}" for cls, pct in class_pixels.items()]),
                            fontsize=12)

                plt.tight_layout()
                plt.subplots_adjust(top=0.8)
                plt.show()

            # class percentages
            print("\nClass Distribution:")
            for class_name, percentage in class_pixels.items():
                print(f"- {class_name}: {percentage}")

            if dominant_class:
                print(f"\nDominant class: {dominant_class} ({class_pixels[dominant_class]})")
            else:
                print("\nNo dominant class detected besides background")

        except Exception as e:
            print(f"Error processing image '{image_path}': {e}")

    # summary metrics if SAM comparison was performed
    if sam_predictor is not None and sam_metrics['iou']:
        # average metrics
        avg_metrics = {metric: np.mean(values) for metric, values in sam_metrics.items()}

        # summary
        print("\n" + "="*50)
        print("SUMMARY: COMPARISON WITH SAM ACROSS ALL IMAGES")
        print("="*50)
        print(f"Average IoU:       {avg_metrics['iou']:.4f}")
        print(f"Average Accuracy:  {avg_metrics['accuracy']:.4f}")
        print(f"Average Precision: {avg_metrics['precision']:.4f}")
        print(f"Average Recall:    {avg_metrics['recall']:.4f}")
        print(f"Average F1 Score:  {avg_metrics['f1']:.4f}")
        print("="*50)

        return avg_metrics

    return None

In [None]:
process_test_folder(model, device, folder_path="test")

In [None]:
# compare with SAM
process_test_folder(model, device, sam_predictor=predictor, folder_path="test")