In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.optim.swa_utils import AveragedModel, SWALR
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split
from timm.data import Mixup, create_transform
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


In [2]:
class_to_superclass = {
    **dict.fromkeys(range(0, 5), 0),   # aquatic mammals
    **dict.fromkeys(range(5, 10), 1),  # fish
    **dict.fromkeys(range(10, 15), 2),  # flowers
    **dict.fromkeys(range(15, 20), 3),  # food containers
    **dict.fromkeys(range(20, 25), 4),  # fruit and vegetables
    **dict.fromkeys(range(25, 30), 5),  # household electrical devices
    **dict.fromkeys(range(30, 35), 6),  # household furniture
    **dict.fromkeys(range(35, 40), 7),  # insects
    **dict.fromkeys(range(40, 45), 8),  # large carnivores
    **dict.fromkeys(range(45, 50), 9),  # large man-made outdoor things
    **dict.fromkeys(range(50, 55), 10), # large natural outdoor scenes
    **dict.fromkeys(range(55, 60), 11), # large omnivores and herbivores
    **dict.fromkeys(range(60, 65), 12), # medium-sized mammals
    **dict.fromkeys(range(65, 70), 13), # non-insect invertebrates
    **dict.fromkeys(range(70, 75), 14), # people
    **dict.fromkeys(range(75, 80), 15), # reptiles
    **dict.fromkeys(range(80, 85), 16), # small mammals
    **dict.fromkeys(range(85, 90), 17), # trees
    **dict.fromkeys(range(90, 95), 18), # vehicles 1
    **dict.fromkeys(range(95, 100), 19) # vehicles 2
}

In [3]:
superclass_to_group = {
    2: 0, 17: 0, 4: 0,   # Plants/Parts of plants
    18: 1, 19: 1,        # Vehicles
    13: 2, 7: 2,         # Invertebrates
    1: 3, 0: 3,          # Aquatic animals
    8: 4, 11: 4,         # Large animals
    9: 5, 5: 5, 6: 5, 3: 5,  # Man-made articles
    14: 6,               # People
    15: 7, 12: 7, 16: 7, # Normal Terrestrial Animals
    10: 8                # Outdoor scenes
}

In [4]:
def get_superclass_and_group(label):
    superclass = class_to_superclass[label]
    group = superclass_to_group[superclass]
    return superclass, group

In [5]:
# Load Pre-trained ResNet
# resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

# Unfreeze deeper layers (More trainable parameters)
# for param in model.layer3.parameters():
#     param.requires_grad = True
# for param in model.layer4.parameters():
#     param.requires_grad = True
# for param in model.fc.parameters():
#     param.requires_grad = True

# Modify the Fully Connected Layer
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 100)
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 75.7MB/s]


In [6]:
train_transform = create_transform(
    input_size=224, is_training=True, auto_augment='rand-m9-mstd0.5-inc1'
)
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [7]:
trainset_full = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform) #Load the full dataset

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:05<00:00, 30.2MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [9]:
mixup = Mixup(mixup_alpha=0.2, cutmix_alpha=0.3, label_smoothing=0.1, num_classes=100)


In [10]:
results = {}
def train_and_eval(split_size, criterion):
    train_size = int(split_size * len(trainset_full))
    val_size = len(trainset_full) - train_size
    trainset, valset = random_split(trainset_full, [train_size, val_size])

    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
    testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

    num_epochs = 30
    best_accuracy = 0.0

    optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=5e-4)
    scheduler = OneCycleLR(optimizer, max_lr=0.0005, epochs=num_epochs, steps_per_epoch=len(trainloader), pct_start=0.1)
    scaler = torch.amp.GradScaler('cuda')
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=0.0001)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_class, correct_superclass, correct_group, total = 0, 0, 0, 0

        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            if np.random.rand() < 0.5:

             images, labels = mixup(images, labels)   # Get mixed image and target
             if labels.shape[0] != images.shape[0]:  # Ensure batch size consistency
                 labels = labels[:images.shape[0], :]
             true_classes =labels.argmax(dim=1)

            else:
             true_classes = labels

            optimizer.zero_grad()
            with torch.autocast('cuda'):
                outputs = model(images)
                if labels.dim() == 2:
                     loss = SoftTargetCrossEntropy()(outputs, labels)
                else:
                     loss = criterion(outputs, labels.long())
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)

            correct_class += (predicted == true_classes).sum().item()
            superclasses, groups = zip(*[get_superclass_and_group(tc.item()) for tc in true_classes])
            predicted_superclasses = [class_to_superclass[p.item()] for p in predicted]
            predicted_groups = [superclass_to_group[sc] for sc in predicted_superclasses]

            correct_superclass += sum(ps == sc for ps, sc in zip(predicted_superclasses, superclasses))
            correct_group += sum(pg == g for pg, g in zip(predicted_groups, groups))
            total += images.size(0)

        epoch_loss = running_loss / len(trainloader)
        # if epoch%5==0:
        print(f"Epoch [{epoch+1}/{num_epochs}]\nLoss: {epoch_loss:.4f}, Class Acc: {100 * correct_class / total:.2f}%, "
              f"Superclass Acc: {100 * correct_superclass / total:.2f}%, Group Acc: {100 * correct_group / total:.2f}%")

        # Validation
        model.eval()
        val_loss, val_correct_class, val_correct_superclass, val_correct_group, val_total = 0, 0, 0, 0, 0
        with torch.no_grad():
            for images, labels in valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels.long())
                val_loss += loss.item()

                predicted = outputs.argmax(dim=1)
                superclasses, groups = zip(*[get_superclass_and_group(l.item()) for l in labels])
                predicted_superclasses = [class_to_superclass[p.item()] for p in predicted]
                predicted_groups = [superclass_to_group[sc] for sc in predicted_superclasses]

                val_correct_class += (predicted == labels).sum().item()
                val_correct_superclass += sum(ps == sc for ps, sc in zip(predicted_superclasses, superclasses))
                val_correct_group += sum(pg == g for pg, g in zip(predicted_groups, groups))
                val_total += labels.size(0)

        # if epoch%4==0:
        print(f"Val Loss: {val_loss / len(valloader):.4f}, Class Acc: {100 * val_correct_class / val_total:.2f}%, "
              f"Superclass Acc: {100 * val_correct_superclass / val_total:.2f}%, Group Acc: {100 * val_correct_group / val_total:.2f}%")

        scheduler.step()
        if val_correct_class / val_total > best_accuracy:
            best_accuracy = val_correct_class / val_total
            torch.save(model.state_dict(), 'best_model.pth')

        if epoch >= 20:
            swa_model.update_parameters(model)
            swa_scheduler.step()

    with torch.no_grad():
        if any(isinstance(layer, nn.BatchNorm2d) for layer in model.modules()):
               swa_model.cpu()
               torch.optim.swa_utils.update_bn(trainloader, swa_model)
               swa_model.to(device)
    torch.save(swa_model.state_dict(), 'swa_model.pth')
    print("SWA Model Saved!")
    swa_model.to(device)

    # Testing with SWA Model
    swa_model.eval()
    y_true, y_pred, y_true_superclass, y_pred_superclass, y_true_group, y_pred_group = [], [], [], [], [], []
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = torch.stack([swa_model(images) for _ in range(5)]).mean(0)
            predicted = outputs.argmax(dim=1).cpu().numpy()
            labels = labels.cpu().numpy()

            for true_class, predicted_class in zip(labels, predicted):
                superclass, group = get_superclass_and_group(true_class)
                predicted_superclass = class_to_superclass[predicted_class]
                predicted_group = superclass_to_group[predicted_superclass]

                y_true.append(true_class)
                y_pred.append(predicted_class)
                y_true_superclass.append(superclass)
                y_pred_superclass.append(predicted_superclass)
                y_true_group.append(group)
                y_pred_group.append(predicted_group)
    cm_class = confusion_matrix(y_true, y_pred)
    cm_superclass = confusion_matrix(y_true_superclass, y_pred_superclass)
    cm_group = confusion_matrix(y_true_group, y_pred_group)
    test_accuracy_class= round(accuracy_score(y_true, y_pred),4)
    test_accuracy_superclass = round(accuracy_score(y_true_superclass, y_pred_superclass),4)
    test_accuracy_group = round(accuracy_score(y_true_group, y_pred_group),4)
    print(f"Test Accuracy (Class): {test_accuracy_class:.4f}")
    print(f"Test Accuracy (Superclass): {test_accuracy_superclass:.4f}")
    print(f"Test Accuracy (Group): {test_accuracy_group:.4f}")
    results[train_size]= {
        'cm_class': cm_class,
        'cm_superclass': cm_superclass,
        'cm_group': cm_group,
        'test_accuracy_class': test_accuracy_class,
        'test_accuracy_superclass': test_accuracy_superclass,
        'test_accuracy_group': test_accuracy_group,
        'best_accuracy': best_accuracy
    }



In [None]:
train_test_splits = {0.7,0.8,0.9}

for split_size in train_test_splits:
  print(f"\n\nSplit size {int(split_size*100)} - {int(100-split_size*100)}")
  criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
  train_and_eval(split_size, criterion)




Split size 70 - 30
Epoch [1/30]
Loss: 4.5990, Class Acc: 1.42%, Superclass Acc: 5.26%, Group Acc: 12.97%
Val Loss: 4.5802, Class Acc: 3.86%, Superclass Acc: 7.21%, Group Acc: 15.41%
Epoch [2/30]
Loss: 4.5687, Class Acc: 3.34%, Superclass Acc: 6.99%, Group Acc: 14.28%
Val Loss: 4.5430, Class Acc: 7.74%, Superclass Acc: 10.86%, Group Acc: 17.87%
Epoch [3/30]
Loss: 4.5267, Class Acc: 5.39%, Superclass Acc: 9.00%, Group Acc: 16.21%
Val Loss: 4.4823, Class Acc: 10.81%, Superclass Acc: 14.06%, Group Acc: 21.33%
Epoch [4/30]
Loss: 4.4719, Class Acc: 7.37%, Superclass Acc: 10.83%, Group Acc: 17.98%
Val Loss: 4.4214, Class Acc: 13.37%, Superclass Acc: 16.67%, Group Acc: 23.45%
Epoch [5/30]
Loss: 4.4163, Class Acc: 8.50%, Superclass Acc: 12.03%, Group Acc: 19.32%
Val Loss: 4.3499, Class Acc: 13.84%, Superclass Acc: 17.40%, Group Acc: 23.87%
Epoch [6/30]
Loss: 4.3606, Class Acc: 9.29%, Superclass Acc: 12.69%, Group Acc: 19.49%
Val Loss: 4.2785, Class Acc: 14.52%, Superclass Acc: 17.71%, Group A

In [None]:

print("\n--- Final Results ---")
for train_size, metrics in results.items():
    print(f"\nTrain Size: {train_size}")
    print(f"Test Accuracy (Class): {metrics['test_accuracy_class']:.4f}")
    print(f"Test Accuracy (Superclass): {metrics['test_accuracy_superclass']:.4f}")
    print(f"Test Accuracy (Group): {metrics['test_accuracy_group']:.4f}")
    print(f"Best Accuracy: {metrics['best_accuracy']:.4f}")

    # Function to plot confusion matrix
    def plot_confusion_matrix(cm, title):
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", annot_kws={"size": 10})
        plt.xticks(rotation=45, ha="right")
        plt.title(f"{title} (Train Size: {train_size})", fontsize=14)
        plt.xlabel("Predicted Label", fontsize=12)
        plt.ylabel("True Label", fontsize=12)
        plt.tight_layout()
        plt.show()

    # Plot confusion matrices
    plot_confusion_matrix(metrics['cm_class'], "Confusion Matrix (Class)")
    plot_confusion_matrix(metrics['cm_superclass'], "Confusion Matrix (Superclass)")
    plot_confusion_matrix(metrics['cm_group'], "Confusion Matrix (Group)")


# Bonus Task

In [None]:
class SeverityWeightedLoss(nn.Module):
    def __init__(self, base_loss=nn.CrossEntropyLoss()):
        super().__init__()
        self.base_loss = base_loss

    def forward(self, outputs, true_classes):
        predicted = outputs.argmax(dim=1)

        # Compute severity levels
        superclasses, groups = zip(*[get_superclass_and_group(tc.item()) for tc in true_classes])
        predicted_superclasses = [class_to_superclass[p.item()] for p in predicted]
        predicted_groups = [superclass_to_group[sc] for sc in predicted_superclasses]

        severity = torch.tensor(
            [0 if p == t else
             1 if ps == sc else
             2 if pg == g else
             3 for p, t, ps, sc, pg, g in zip(predicted, true_classes, predicted_superclasses, superclasses, predicted_groups, groups)],
            dtype=torch.float32,
            device=device
        )

        # Define severity-based weight (higher severity = higher penalty)
        severity_weights = 1 + (severity / 3)  # Normalized to [1, 2]

        # Compute weighted loss
        base_loss = self.base_loss(outputs, true_classes.long())
        weighted_loss = (base_loss * severity_weights).mean()
        return weighted_loss



In [None]:
def bonus_train_and_eval(split_size):
    criterion = SeverityWeightedLoss()

    train_size = int(split_size * len(trainset_full))
    val_size = len(trainset_full) - train_size
    trainset, valset = random_split(trainset_full, [train_size, val_size])

    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
    testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

    num_epochs = 30
    best_accuracy = 0.0

    optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=5e-4)
    scheduler = OneCycleLR(optimizer, max_lr=0.0005, epochs=num_epochs, steps_per_epoch=len(trainloader), pct_start=0.1)
    scaler = torch.amp.GradScaler('cuda')
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=0.0001)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_class, correct_superclass, correct_group, total = 0, 0, 0, 0

        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            if np.random.rand() < 0.5:
                images, labels = mixup(images, labels)
                if labels.shape[0] != images.shape[0]:
                    labels = labels[:images.shape[0], :]
                true_classes = labels.argmax(dim=1)
            else:
                true_classes = labels

            optimizer.zero_grad()
            with torch.autocast('cuda'):
                outputs = model(images)
                loss = criterion(outputs, true_classes)  # ✅ Training uses severity-based loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)

            correct_class += (predicted == true_classes).sum().item()
            superclasses, groups = zip(*[get_superclass_and_group(tc.item()) for tc in true_classes])
            predicted_superclasses = [class_to_superclass[p.item()] for p in predicted]
            predicted_groups = [superclass_to_group[sc] for sc in predicted_superclasses]

            correct_superclass += sum(ps == sc for ps, sc in zip(predicted_superclasses, superclasses))
            correct_group += sum(pg == g for pg, g in zip(predicted_groups, groups))
            total += images.size(0)

        epoch_loss = running_loss / len(trainloader)
        if epoch%5==0:
         print(f"Epoch [{epoch+1}/{num_epochs}]\nLoss: {epoch_loss:.4f}, Class Acc: {100 * correct_class / total:.2f}%, "
              f"Superclass Acc: {100 * correct_superclass / total:.2f}%, Group Acc: {100 * correct_group / total:.2f}%")

        # Validation
        model.eval()
        val_loss, val_correct_class, val_correct_superclass, val_correct_group, val_total = 0, 0, 0, 0, 0
        with torch.no_grad():
            for images, labels in valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels.long())  # ✅ Severity loss used in validation
                val_loss += loss.item()

                predicted = outputs.argmax(dim=1)
                superclasses, groups = zip(*[get_superclass_and_group(l.item()) for l in labels])
                predicted_superclasses = [class_to_superclass[p.item()] for p in predicted]
                predicted_groups = [superclass_to_group[sc] for sc in predicted_superclasses]

                val_correct_class += (predicted == labels).sum().item()
                val_correct_superclass += sum(ps == sc for ps, sc in zip(predicted_superclasses, superclasses))
                val_correct_group += sum(pg == g for pg, g in zip(predicted_groups, groups))
                val_total += labels.size(0)
        if epoch%4==0:
           print(f"Val Loss: {val_loss / len(valloader):.4f}, Class Acc: {100 * val_correct_class / val_total:.2f}%, "
              f"Superclass Acc: {100 * val_correct_superclass / val_total:.2f}%, Group Acc: {100 * val_correct_group / val_total:.2f}%")

        scheduler.step()
        if val_correct_class / val_total > best_accuracy:
            best_accuracy = val_correct_class / val_total
            torch.save(model.state_dict(), 'best_model.pth')

        if epoch >= 20:
            swa_model.update_parameters(model)
            swa_scheduler.step()

    with torch.no_grad():
        if any(isinstance(layer, nn.BatchNorm2d) for layer in model.modules()):
               swa_model.cpu()
               torch.optim.swa_utils.update_bn(trainloader, swa_model)
               swa_model.to(device)
    torch.save(swa_model.state_dict(), 'swa_model.pth')
    print("SWA Model Saved!")
    swa_model.to(device)

    # Testing with SWA Model
    swa_model.eval()
    y_true, y_pred, y_true_superclass, y_pred_superclass, y_true_group, y_pred_group = [], [], [], [], [], []
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = torch.stack([swa_model(images) for _ in range(5)]).mean(0)
            predicted = outputs.argmax(dim=1).cpu().numpy()
            labels = labels.cpu().numpy()

            for true_class, predicted_class in zip(labels, predicted):
                superclass, group = get_superclass_and_group(true_class)
                predicted_superclass = class_to_superclass[predicted_class]
                predicted_group = superclass_to_group[predicted_superclass]

                y_true.append(true_class)
                y_pred.append(predicted_class)
                y_true_superclass.append(superclass)
                y_pred_superclass.append(predicted_superclass)
                y_true_group.append(group)
                y_pred_group.append(predicted_group)

    cm_class = confusion_matrix(y_true, y_pred)
    cm_superclass = confusion_matrix(y_true_superclass, y_pred_superclass)
    cm_group = confusion_matrix(y_true_group, y_pred_group)
    test_accuracy_class= round(accuracy_score(y_true, y_pred),4)
    test_accuracy_superclass = round(accuracy_score(y_true_superclass, y_pred_superclass),4)
    test_accuracy_group = round(accuracy_score(y_true_group, y_pred_group),4)
    results[train_size]= {
        'cm_class': cm_class,
        'cm_superclass': cm_superclass,
        'cm_group': cm_group,
        'test_accuracy_class': test_accuracy_class,
        'test_accuracy_superclass': test_accuracy_superclass,
        'test_accuracy_group': test_accuracy_group,
        'best_accuracy': best_accuracy
    }

In [None]:
train_test_splits = {0.7,0.8,0.9}

for split_size in train_test_splits:
  print(f"\n\nSplit size {int(split_size*100)} - {int(100-split_size*100)}")
  bonus_train_and_eval(split_size)


In [None]:

print("\n--- Final Results ---")
for train_size, metrics in results.items():
    print(f"\nTrain Size: {train_size}")
    print(f"Test Accuracy (Class): {metrics['test_accuracy_class']:.4f}")
    print(f"Test Accuracy (Superclass): {metrics['test_accuracy_superclass']:.4f}")
    print(f"Test Accuracy (Group): {metrics['test_accuracy_group']:.4f}")
    print(f"Best Accuracy: {metrics['best_accuracy']:.4f}")

    # Function to plot confusion matrix
    def plot_confusion_matrix(cm, title):
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", annot_kws={"size": 10})
        plt.xticks(rotation=45, ha="right")
        plt.title(f"{title} (Train Size: {train_size})", fontsize=14)
        plt.xlabel("Predicted Label", fontsize=12)
        plt.ylabel("True Label", fontsize=12)
        plt.tight_layout()
        plt.show()

    # Plot confusion matrices
    plot_confusion_matrix(metrics['cm_class'], "Confusion Matrix (Class)")
    plot_confusion_matrix(metrics['cm_superclass'], "Confusion Matrix (Superclass)")
    plot_confusion_matrix(metrics['cm_group'], "Confusion Matrix (Group)")
