# IMPORTS

In [None]:
import sys
import os

# Expose only physical GPU 1 to CUDA in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Make project packages importable relative to this notebook's directory:
#   ".."       -> 'protonet_STOP_bddoia_modules'
#   "../.."    -> 'data'
#   "../../.." -> 'models' and 'datasets'
sys.path.extend(os.path.abspath(rel) for rel in ("..", "../..", "../../.."))

print(sys.path)

In [None]:
import csv
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import datetime
import numpy as np
import setproctitle, socket, uuid

from models import get_model
from models.mnistdpl import MnistDPL
from datasets import get_dataset

from argparse import Namespace
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix

from utils import fprint
from utils.status import progress_bar
from utils.metrics import evaluate_metrics, accuracy_binary
from utils.dpl_loss import ADDMNIST_DPL
from utils.checkpoint import save_model
from torch.utils.data import Dataset, DataLoader

from warmup_scheduler import GradualWarmupScheduler
from datasets.utils.base_dataset import BaseDataset

from backbones.bddoia_protonet import ProtoNetConv1D, PrototypicalLoss

from protonet_STOP_bddoia_modules.arguments import args_dpl 
from protonet_STOP_bddoia_modules.proto_modules.proto_helpers import (
    assert_inputs,
    get_random_classes,
    compute_class_logits_per_batch
)
from protonet_STOP_bddoia_modules.proto_modules.proto_functions import (
    build_stop_filtered_concat_inputs,
    get_prototypical_datasets_inputs,
    train_my_prototypical_network,
)

# SETUP

In [None]:
SEED: int = 3398            # global seed; copied into args.seed in the next cell
UNS_PERCENTAGE: float = 1.0  # probability of keeping each main-model training batch; also part of the save path

In [None]:
# Reuse the shared argument namespace (an alias, not a copy) and pin its seed.
args = args_dpl
args.seed = SEED

# Run metadata for logging: unique job id, start timestamp and host name.
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# Process title "<model>_<buffer_size or 0>_<dataset>".
setproctitle.setproctitle(
    f"{args.model}_{args.buffer_size if 'buffer_size' in args else 0}_{args.dataset}"
)

# Output location for this run's models.
save_folder = "bddoia"
save_model_name = 'dpl'
save_path = os.path.join(
    "..",
    "NEW-outputs",
    save_folder,
    "my_models",
    save_model_name,
    f"episodic-proto-net-pipeline-{UNS_PERCENTAGE}-article-STOP",
)
save_paths = [save_path]

print(f"Seed: {args.seed}")
print(f"Save paths: {save_paths}")

# UTILS

## Test Set Evaluation (lightweight alternative to the full-fledged evaluation notebook)

In [None]:

# * helper function for 'plot_multilabel_confusion_matrix'
def convert_to_categories(elements):
    """Encode each row of a 0/1 matrix as the integer its bits spell out.

    The first column is the most significant bit, e.g. ``[1, 0, 1] -> 5``.

    Args:
        elements: 2-D array-like of 0/1 values, shape ``(n_samples, n_bits)``.

    Returns:
        1-D ``np.ndarray`` with one integer code per row (int64 for up to 63
        bits; arbitrary-precision Python ints beyond that).

    Raises:
        ValueError: if ``elements`` is not 2-D or contains values other than 0/1.
    """
    bits = np.asarray(elements)
    if bits.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {bits.shape}")
    if bits.size and not np.isin(bits, (0, 1)).all():
        raise ValueError("elements must contain only 0/1 values")

    n_bits = bits.shape[1]
    if n_bits >= 64:
        # Too wide for int64: fall back to arbitrary-precision Python ints.
        return np.array(
            [int("".join(map(str, row)), 2) for row in bits.astype(int).tolist()]
        )

    # Dot product with [2**(n-1), ..., 2, 1] encodes every row at once, and
    # also handles zero rows / zero columns without special cases.
    weights = 2 ** np.arange(n_bits - 1, -1, -1, dtype=np.int64)
    return bits.astype(np.int64) @ weights


# * BDDOIA custom confusion matrix for concepts
def plot_multilabel_confusion_matrix(
    y_true, y_pred, class_names, title, save_path=None
):
    """Plot a 2x2 confusion matrix per concept and return the joint matrix.

    The returned matrix compares whole concept vectors: each row of ``y_true``
    and ``y_pred`` (cast to int) is encoded as one integer with
    ``convert_to_categories`` before calling sklearn's ``confusion_matrix``.
    The figure lays out one annotated per-concept matrix (from
    ``multilabel_confusion_matrix``) per entry of ``class_names`` on a grid
    five panels wide.

    Args:
        y_true: ground-truth concept matrix (samples x concepts).
        y_pred: predicted concept matrix, same layout.
        class_names: one display name per concept column.
        title: figure super-title.
        save_path: if truthy, the figure is saved to ``f"{save_path}_total.png"``;
            otherwise it is shown.

    Returns:
        The joint confusion matrix over encoded concept vectors.
    """
    joint_cm = confusion_matrix(
        convert_to_categories(y_true.astype(int)),
        convert_to_categories(y_pred.astype(int)),
    )

    per_class_cms = multilabel_confusion_matrix(y_true, y_pred)
    n_cols = 5
    n_rows = -(-len(class_names) // n_cols)  # ceiling division

    plt.figure(figsize=(20, 4 * n_rows))

    for idx, name in enumerate(class_names):
        plt.subplot(n_rows, n_cols, idx + 1)
        mat = per_class_cms[idx]
        plt.imshow(mat, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"Class: {name}")
        plt.colorbar()
        ticks = np.arange(2)
        plt.xticks(ticks, ["0", "1"])
        plt.yticks(ticks, ["0", "1"])

        # Annotate every cell; dark cells get white text for contrast.
        half_max = mat.max() / 2.0
        for row, col in np.ndindex(mat.shape):
            cell = mat[row, col]
            plt.text(
                col,
                row,
                format(cell, ".0f"),
                ha="center",
                va="center",
                color="white" if cell > half_max else "black",
            )

        plt.ylabel("True label")
        plt.xlabel("Predicted label")

    plt.tight_layout()
    plt.suptitle(title)

    if not save_path:
        plt.show()
    else:
        plt.savefig(f"{save_path}_total.png")

    plt.close()

    return joint_cm


# * Concept collapse (Soft)
def compute_coverage(confusion_matrix):
    """Compute the soft coverage of a confusion matrix.

    For every column, take its largest entry, clip it to ``[0, 1]`` and
    average over all columns. With a raw count matrix, any column holding at
    least one sample contributes 1, so the result is the fraction of columns
    that are ever hit. With a matrix whose entries already lie in ``[0, 1]``,
    each column contributes its largest entry instead (hence "soft").

    Args:
        confusion_matrix: 2-D array-like. Assumes sklearn layout (rows = true
            labels, columns = predicted labels) — TODO confirm with callers.
            The name shadows sklearn's ``confusion_matrix`` only inside this
            function; it is kept for keyword callers.

    Returns:
        A float in ``[0, 1]``.

    Raises:
        ValueError: if the input is not 2-D or has no rows or no columns.
    """
    matrix = np.asarray(confusion_matrix)
    if matrix.ndim != 2 or matrix.size == 0:
        raise ValueError(
            f"expected a non-empty 2-D confusion matrix, got shape {matrix.shape}"
        )

    max_values = np.max(matrix, axis=0)
    clipped_values = np.clip(max_values, 0, 1)

    # Redefinition of soft coverage: mean of the clipped per-column maxima.
    return np.sum(clipped_values) / len(clipped_values)


# * BDDOIA custom confusion matrix for actions
def plot_actions_confusion_matrix(c_true, c_pred, title, save_path=None):
    """Plot and return one confusion matrix per action's concept group.

    The concept columns are split into four groups: forward (0-2), stop (3-8),
    left (9-14) and right (15-20). For each group, every sample's 0/1 concept
    sub-vector is encoded as an integer with ``convert_to_categories``. A
    confusion matrix between the true and predicted codes is then built with
    sklearn.

    Args:
        c_true: 2-D array of ground-truth concepts (samples x concepts); cast to int.
        c_pred: 2-D array of predicted concepts with the same column layout.
        title: prefix of every figure title (``"{title} - {scenario}"``).
        save_path: if truthy, each figure is written to
            ``f"{save_path}_{scenario}.png"``; otherwise it is shown.

    Returns:
        dict mapping scenario name to its confusion matrix. sklearn only keeps
        codes that occur in the true or predicted codes, so row/column ``k``
        is the k-th smallest observed code, not code ``k``.
    """
    # Concept columns belonging to each action (same slice for truth and prediction).
    scenarios = {
        "forward": slice(0, 3),
        "stop": slice(3, 9),
        "left": slice(9, 15),
        "right": slice(15, 21),
    }

    to_rtn = {}

    for scenario, cols in scenarios.items():
        g_true = convert_to_categories(c_true[:, cols].astype(int))
        c_pred_scenario = convert_to_categories(c_pred[:, cols].astype(int))

        cm = confusion_matrix(g_true, c_pred_scenario)

        plt.figure()
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"{title} - {scenario}")
        plt.colorbar()

        # One blank tick per matrix column/row. Sizing the ticks by
        # 2**n_concepts would exceed the matrix whenever some codes never
        # occur. Matplotlib then expands the view limits to fit those ticks,
        # shrinking the image into a corner of the axes.
        x_ticks = np.arange(cm.shape[1])
        y_ticks = np.arange(cm.shape[0])
        plt.xticks(x_ticks, [""] * len(x_ticks))
        plt.yticks(y_ticks, [""] * len(y_ticks))

        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}_{scenario}.png")
        else:
            plt.show()

        to_rtn[scenario] = cm

        plt.close()

    return to_rtn

# DATASET & BATCH SAMPLER

In [None]:
class ProtoDataset(Dataset):
    """Map-style dataset pairing precomputed embeddings with their labels.

    Item ``idx`` is ``(embeddings[idx], labels[idx])``. In this notebook it is
    wrapped in a ``DataLoader`` driven by a ``PrototypicalBatchSampler``.
    """

    def __init__(self, embeddings, labels):
        """
        Args:
            embeddings: indexable object with ``.shape``; the first dimension is
                the sample axis.
            labels: indexable object with ``.shape`` and the same first dimension.

        Raises:
            ValueError: if ``embeddings`` and ``labels`` differ in length.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if embeddings.shape[0] != labels.shape[0]:
            raise ValueError(
                f"embeddings and labels must have the same length, got "
                f"{embeddings.shape[0]} and {labels.shape[0]}"
            )
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]


class PrototypicalBatchSampler(object):
    """
    Yields a batch of indices for episodic training.

    At each iteration it randomly selects up to ``classes_per_it`` classes, then
    picks ``num_samples`` indices for each selected class. Sampling is without
    replacement when the class has enough examples and with replacement
    otherwise. The batch is shuffled before being yielded.
    """

    def __init__(self, labels, classes_per_it, num_samples, iterations):
        """
        Args:
            labels (array-like): 1D array or list of labels for the target task.
            classes_per_it (int): Number of random classes for each iteration.
                If fewer distinct classes exist, all of them are used.
            num_samples (int): Number of samples per class (support + query) in each episode.
            iterations (int): Number of iterations (episodes) per epoch.

        Raises:
            ValueError: if ``labels`` is empty.
        """
        self.labels = np.array(labels)
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations

        classes, self.counts = np.unique(self.labels, return_counts=True)
        if len(classes) == 0:
            raise ValueError("labels must not be empty")
        self.classes = torch.as_tensor(classes).long()

        # Index matrix of shape (num_classes, max_samples_in_class); row c holds
        # the dataset indices of class c in ascending order, padded with -1.
        max_count = int(self.counts.max())
        self.indexes = torch.full((len(classes), max_count), -1, dtype=torch.long)
        self.numel_per_class = torch.zeros(len(classes), dtype=torch.long)

        for class_idx, cls in enumerate(classes):
            members = np.flatnonzero(self.labels == cls)
            self.indexes[class_idx, : len(members)] = torch.from_numpy(members).long()
            self.numel_per_class[class_idx] = len(members)

    def __iter__(self):
        """
        Yield a batch of indices for each episode.
        """
        spc = self.sample_per_class
        n_classes = len(self.classes)
        # Never ask for more classes than exist. Otherwise part of the batch
        # tensor would never be filled and garbage indices would be yielded.
        cpi = min(self.classes_per_it, n_classes)

        for _ in range(self.iterations):
            batch = torch.empty(cpi * spc, dtype=torch.long)
            # Randomly choose 'cpi' classes
            c_idxs = torch.randperm(n_classes)[:cpi]
            for i, class_idx in enumerate(c_idxs):
                s = slice(i * spc, (i + 1) * spc)

                n_avail = int(self.numel_per_class[class_idx])
                if spc <= n_avail:
                    # enough examples → sample without replacement
                    sample_idxs = torch.randperm(n_avail)[:spc]
                else:
                    # too few → sample with replacement
                    sample_idxs = torch.randint(0, n_avail, (spc,), dtype=torch.long)

                batch[s] = self.indexes[class_idx, sample_idxs]

            # Shuffle the batch indices
            yield batch[torch.randperm(len(batch))]

    def __len__(self):
        return self.iterations

# UNSUPERVISED DATA AND MODEL LOADING

In [None]:
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()
encoder, decoder = dataset.get_backbone()

# & Main model (its default optimizer is replaced in the next cell)
model = get_model(args, encoder, decoder, n_images, c_split)
model.start_optim(args)

# & Prototypical Networks: six experts with identical architecture, created in
# this order. The trailing numbers appear to be the concept ids each expert is
# trained on (they match the ids passed to get_prototypical_datasets_inputs in train).
(
    traffic_lights_model,  # 3
    traffic_signs_model,   # 4
    car_model,             # 5
    pedestrians_model,     # 6
    rider_model,           # 7
    others_model,          # 8
) = (ProtoNetConv1D(in_dim=3072).to(model.device) for _ in range(6))

loss = model.get_loss(args)

print(dataset)
for _caption, _component in (
    ("Using Dataset: ", dataset),
    ("Using backbone: ", encoder),
    ("Using Model: ", model),
    ("Using Loss: ", loss),
    ("Using Traffic Lights Model: ", traffic_lights_model),
    ("Using Traffic Signs Model: ", traffic_signs_model),
    ("Using Car Model: ", car_model),
    ("Using Pedestrians Model: ", pedestrians_model),
    ("Using Rider Model: ", rider_model),
    ("Using Others Model: ", others_model),
):
    print(_caption, _component)
del _caption, _component

unsup_train_loader, unsup_val_loader, unsup_test_loader = dataset.get_data_loaders(args=args)

In [None]:
# override the default optimizer to include all the expert models
# Replace the optimizer created by `model.start_optim` with a single Adam over the
# main model AND the six expert prototypical networks. Parameters are collected
# in this fixed order: main model, then experts 3..8.
# NOTE(review): `train` also builds a separate Adam per expert, so expert
# parameters belong to two optimizers — confirm this is intended.
model.opt = torch.optim.Adam(
    [
        param
        for module in (
            model,                 # the main model
            traffic_lights_model,  # 3
            traffic_signs_model,   # 4
            car_model,             # 5
            pedestrians_model,     # 6
            rider_model,           # 7
            others_model,          # 8
        )
        for param in module.parameters()
    ],
    args.lr,
    weight_decay=args.weight_decay,
)

# TRAINING

## Main Loop

In [None]:
def train(
        model: MnistDPL, 
        traffic_lights_model: ProtoNetConv1D,# 3
        traffic_signs_model: ProtoNetConv1D, # 4
        car_model: ProtoNetConv1D, # 5 
        pedestrians_model: ProtoNetConv1D, # 6
        rider_model: ProtoNetConv1D, # 7
        others_model: ProtoNetConv1D,# 8
        _loss: ADDMNIST_DPL,
        save_path: str, 
        train_loader: DataLoader,
        val_loader: DataLoader,
        args: Namespace,
        seed: int = 0,
        debug=False,
    ) -> float:

    # for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False
    
    # early stopping
    best_f1_stop_concepts = 0.0
    epochs_no_improve = 0
    
    # scheduler & warmup (not used) for main model
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    # Optimizers & Schedulers for PNets
    traffic_lights_optimizer = torch.optim.Adam(traffic_lights_model.parameters()) # 3
    traffic_lights_scheduler = torch.optim.lr_scheduler.StepLR(traffic_lights_optimizer, step_size=10, gamma=0.5)
    traffic_signs_model_optimizer = torch.optim.Adam(traffic_signs_model.parameters()) # 4
    traffic_signs_model_scheduler = torch.optim.lr_scheduler.StepLR(traffic_signs_model_optimizer, step_size=10, gamma=0.5)
    car_model_optimizer = torch.optim.Adam(car_model.parameters()) # 5
    car_model_scheduler = torch.optim.lr_scheduler.StepLR(car_model_optimizer, step_size=10, gamma=0.5)
    pedestrians_model_optimizer = torch.optim.Adam(pedestrians_model.parameters()) # 6
    pedestrians_model_scheduler = torch.optim.lr_scheduler.StepLR(pedestrians_model_optimizer, step_size=10, gamma=0.5)
    rider_model_optimizer = torch.optim.Adam(rider_model.parameters()) # 7
    rider_model_scheduler = torch.optim.lr_scheduler.StepLR(rider_model_optimizer, step_size=10, gamma=0.5)
    others_model_optimizer = torch.optim.Adam(others_model.parameters()) # 8
    others_model_scheduler = torch.optim.lr_scheduler.StepLR(others_model_optimizer, step_size=10, gamma=0.5)
    
    fprint("\n--- Start of Training ---\n")
    model.to(model.device)
    model.opt.zero_grad()
    model.opt.step()

    # used to fetch the first random batch of data to train the prototypical networks
    compute_prototypical_batch = True
        
    # & Training start
    for epoch in range(args.n_epochs):
        print(f"Epoch {epoch + 1}/{args.n_epochs}")

        model.train()
        ys, y_true, cs, cs_true, batch = None, None, None, None, 0
        pNet_loss = PrototypicalLoss(n_support=args.num_support)
                
        execute = True  # when to train the prototypical networks (once per epoch)
        for i, batch in enumerate(train_loader):
            # ------------------ original embeddings
            images_embeddings = torch.stack(batch['embeddings']).to(model.device)
            attr_labels = torch.stack(batch['attr_labels']).to(model.device)
            class_labels = torch.stack(batch['class_labels'])[:,:-1].to(model.device) # exclude the last column
            # ------------------ my extracted features
            images_embeddings_raw = torch.stack(batch['embeddings_raw']).to(model.device)
            detected_rois = batch['rois']
            detected_rois_feats = batch['roi_feats']
            detection_labels = batch['detection_labels']
            detection_scores = batch['detection_scores']
            assert_inputs(images_embeddings, attr_labels, class_labels,
                   detected_rois_feats, detected_rois, detection_labels,
                   detection_scores, images_embeddings_raw)
            # ------------------ combined object + scene features
            
            # & Build Prototypical Inputs
            inputs = build_stop_filtered_concat_inputs(
                images_embeddings_raw, 
                detected_rois, 
                detected_rois_feats, 
                detection_labels, 
                detection_scores
            )

            # ? List of length batch size, each tensor of shape (num_rois, 3072).
            concat_inputs_traffic_lights = inputs[0] # 3
            concat_input_traffic_signs = inputs[1] # 4
            concat_input_cars = inputs[2] # 5
            concat_input_pedestrians = inputs[3] # 6
            concat_input_riders = inputs[4] # 7
            concat_input_others = inputs[5] # 8
            
            # --------------------------------------
            # ^ PROTOTYPICAL NETWORK TRAINING
            # --------------------------------------
            if execute:

                if compute_prototypical_batch:
                    # & Build Prototypical Datasets
                    # * [traffic lights]
                    proto_data_traffic_lights, proto_labels_traffic_lights, sz_traffic_lights = get_prototypical_datasets_inputs(
                            concat_inputs_traffic_lights, 
                            3, 
                            attr_labels, 
                            model.device
                        )
                    proto_dataset_traffic_lights = ProtoDataset(proto_data_traffic_lights, proto_labels_traffic_lights)
                    if (proto_labels_traffic_lights == 1).sum().item() == 0:    continue

                    # * [traffic signs]
                    proto_data_traffic_signs, proto_labels_traffic_signs, sz_traffic_signs = get_prototypical_datasets_inputs(
                            concat_input_traffic_signs, 
                            4, 
                            attr_labels,
                            model.device
                    )
                    proto_dataset_traffic_signs = ProtoDataset(proto_data_traffic_signs, proto_labels_traffic_signs)
                    if (proto_labels_traffic_signs == 1).sum().item() == 0:    continue

                    # * [cars]
                    proto_data_cars, proto_labels_cars, sz_cars = get_prototypical_datasets_inputs(
                            concat_input_cars, 
                            5, 
                            attr_labels, 
                            model.device
                    )
                    proto_dataset_cars = ProtoDataset(proto_data_cars, proto_labels_cars)
                    if (proto_labels_cars == 1).sum().item() == 0:    continue
                    
                    # * [pedestrians]
                    proto_data_pedestrians, proto_labels_pedestrians, sz_pedestrians = get_prototypical_datasets_inputs(
                            concat_input_pedestrians, 
                            6, 
                            attr_labels, 
                            model.device
                    )
                    proto_dataset_pedestrians = ProtoDataset(proto_data_pedestrians, proto_labels_pedestrians)
                    if (proto_labels_pedestrians == 1).sum().item() == 0:    continue
                    
                    # * [riders]
                    proto_data_riders, proto_labels_riders, sz_riders = get_prototypical_datasets_inputs(
                            concat_input_riders, 
                            7, 
                            attr_labels, 
                            model.device
                    )
                    proto_dataset_riders = ProtoDataset(proto_data_riders, proto_labels_riders)
                    if (proto_labels_riders == 1).sum().item() == 0:    continue
                    
                    # * [others]
                    proto_data_others, proto_labels_others, sz_others = get_prototypical_datasets_inputs(
                            concat_input_others, 
                            8, 
                            attr_labels, 
                            model.device
                    )
                    proto_dataset_others = ProtoDataset(proto_data_others, proto_labels_others)
                    if (proto_labels_others == 1).sum().item() == 0:    continue
                
                # & Build Prototypical Batch Samplers and Episodic DataLoaders
                # * [traffic lights]
                proto_sampler_traffic_lights = PrototypicalBatchSampler(
                    labels = proto_labels_traffic_lights.cpu().numpy(),
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_traffic_lights = DataLoader(proto_dataset_traffic_lights, batch_sampler=proto_sampler_traffic_lights)
                
                # * [traffic signs]
                proto_sampler_traffic_signs = PrototypicalBatchSampler(
                    labels = proto_labels_traffic_signs.cpu().numpy(),
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_traffic_signs = DataLoader(proto_dataset_traffic_signs, batch_sampler=proto_sampler_traffic_signs)
                if debug:   print("Number of batches in Traffic Signs DataLoader:", len(proto_loader_traffic_signs))

                # * [cars]
                proto_sampler_cars = PrototypicalBatchSampler(
                    labels = proto_labels_cars.cpu().numpy(),
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_cars = DataLoader(proto_dataset_cars, batch_sampler=proto_sampler_cars)
                if debug:   print("Number of batches in Cars DataLoader:", len(proto_loader_cars))

                # * [pedestrians]
                proto_sampler_pedestrians = PrototypicalBatchSampler(
                    labels = proto_labels_pedestrians.cpu().numpy(),
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_pedestrians = DataLoader(proto_dataset_pedestrians, batch_sampler=proto_sampler_pedestrians)
                if debug:   print("Number of batches in Pedestrians DataLoader:", len(proto_loader_pedestrians))
                
                # * [riders]
                proto_sampler_riders = PrototypicalBatchSampler(
                    labels = proto_labels_riders.cpu().numpy(),
                    classes_per_it = args.classes_per_it,

                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_riders = DataLoader(proto_dataset_riders, batch_sampler=proto_sampler_riders)
                if debug:   print("Number of batches in Riders DataLoader:", len(proto_loader_riders))
                
                # * [others]
                proto_sampler_others = PrototypicalBatchSampler(
                    labels = proto_labels_others.cpu().numpy(), 
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
                proto_loader_others = DataLoader(proto_dataset_others, batch_sampler=proto_sampler_others)
                if debug:   print("Number of batches in Others DataLoader:", len(proto_loader_others))

                print(f"Number of proto_labels_traffic_lights that are 0: {(proto_labels_traffic_lights == 0).sum().item()}, 1: {(proto_labels_traffic_lights == 1).sum().item()}")
                print(f"Number of proto_labels_traffic_signs that are 0: {(proto_labels_traffic_signs == 0).sum().item()}, 1: {(proto_labels_traffic_signs == 1).sum().item()}")
                print(f"Number of proto_labels_pedestrians that are 0: {(proto_labels_pedestrians == 0).sum().item()}, 1: {(proto_labels_pedestrians == 1).sum().item()}")
                print(f"Number of proto_labels_cars that are 0: {(proto_labels_cars == 0).sum().item()}, 1: {(proto_labels_cars == 1).sum().item()}")
                print(f"Number of proto_labels_riders that are 0: {(proto_labels_riders == 0).sum().item()}, 1: {(proto_labels_riders == 1).sum().item()}")
                print(f"Number of proto_labels_others that are 0: {(proto_labels_others == 0).sum().item()}, 1: {(proto_labels_others == 1).sum().item()}")

                # & Train Prototypical Networks
                traffic_lights_model.train()
                traffic_signs_model.train()
                car_model.train()
                pedestrians_model.train()
                rider_model.train()
                others_model.train()
                for e in range(args.proto_epochs):
                    epoch_loss_traffic_lights, epoch_acc_traffic_lights = train_my_prototypical_network(
                        proto_loader_traffic_lights, args.iterations, traffic_lights_model, traffic_lights_optimizer, pNet_loss
                    )
                    epoch_loss_traffic_signs, epoch_acc_traffic_signs = train_my_prototypical_network(
                        proto_loader_traffic_signs, args.iterations, traffic_signs_model, traffic_signs_model_optimizer, pNet_loss
                    )
                    epoch_loss_cars, epoch_acc_cars = train_my_prototypical_network(
                        proto_loader_cars, args.iterations, car_model, car_model_optimizer, pNet_loss
                    )
                    epoch_loss_pedestrians, epoch_acc_pedestrians = train_my_prototypical_network(
                        proto_loader_pedestrians, args.iterations, pedestrians_model, pedestrians_model_optimizer, pNet_loss
                    )
                    epoch_loss_riders, epoch_acc_riders = train_my_prototypical_network(
                        proto_loader_riders, args.iterations, rider_model, rider_model_optimizer, pNet_loss
                    )
                    epoch_loss_others, epoch_acc_others = train_my_prototypical_network(
                        proto_loader_others, args.iterations, others_model, others_model_optimizer, pNet_loss
                    )

                avg_loss_tf = sum(epoch_loss_traffic_lights) / len(epoch_loss_traffic_lights)
                avg_acc_tf  = sum(epoch_acc_traffic_lights)  / len(epoch_acc_traffic_lights)
                print(f"Traffic Lights Features  - Avg Loss: {avg_loss_tf:.4f} | Avg Acc: {avg_acc_tf:.4f}")

                avg_loss_ts = sum(epoch_loss_traffic_signs) / len(epoch_loss_traffic_signs)
                avg_acc_ts  = sum(epoch_acc_traffic_signs)  / len(epoch_acc_traffic_signs)
                print(f"Traffic Signs Features  - Avg Loss: {avg_loss_ts:.4f} | Avg Acc: {avg_acc_ts:.4f}")

                avg_loss_c = sum(epoch_loss_cars) / len(epoch_loss_cars)
                avg_acc_c  = sum(epoch_acc_cars)  / len(epoch_acc_cars)
                print(f"Cars Features  - Avg Loss: {avg_loss_c:.4f} | Avg Acc: {avg_acc_c:.4f}")

                avg_loss_p = sum(epoch_loss_pedestrians) / len(epoch_loss_pedestrians)
                avg_acc_p  = sum(epoch_acc_pedestrians)  / len(epoch_acc_pedestrians)
                print(f"Pedestrians Features  - Avg Loss: {avg_loss_p:.4f} | Avg Acc: {avg_acc_p:.4f}")

                avg_loss_r = sum(epoch_loss_riders) / len(epoch_loss_riders)
                avg_acc_r  = sum(epoch_acc_riders)  / len(epoch_acc_riders)
                print(f"Riders Features  - Avg Loss: {avg_loss_r:.4f} | Avg Acc: {avg_acc_r:.4f}")

                avg_loss_o = sum(epoch_loss_others) / len(epoch_loss_others)
                avg_acc_o  = sum(epoch_acc_others)  / len(epoch_acc_others)
                print(f"Others Features  - Avg Loss: {avg_loss_o:.4f} | Avg Acc: {avg_acc_o:.4f}")

                traffic_lights_scheduler.step()
                traffic_signs_model_scheduler.step()
                car_model_scheduler.step()
                pedestrians_model_scheduler.step()
                rider_model_scheduler.step()
                others_model_scheduler.step()
                
                execute = False
                compute_prototypical_batch = False

            # --------------------------------------
            # ^ MAIN MODEL TRAINING
            # --------------------------------------
            else:
                if random.random() > UNS_PERCENTAGE:
                    continue  # Skip this batch with probability (1 - percentage)

                # & Use the expert Prototypical Networks to compute logits for their classes of expertise
                traffic_lights_model.eval()
                traffic_signs_model.eval()
                car_model.eval()
                pedestrians_model.eval()
                rider_model.eval()
                others_model.eval()

                support_embeddings_traffic_lights, support_labels_traffic_lights = get_random_classes(
                    proto_dataset_traffic_lights.embeddings, proto_dataset_traffic_lights.labels, sz_traffic_lights, 2
                )
                support_embeddings_traffic_signs, support_labels_traffic_signs = get_random_classes(
                    proto_dataset_traffic_signs.embeddings, proto_dataset_traffic_signs.labels, sz_traffic_signs, 2
                )
                support_embeddings_cars, support_labels_cars = get_random_classes(
                    proto_dataset_cars.embeddings, proto_dataset_cars.labels, sz_cars, 2
                )
                support_embeddings_pedestrians, support_labels_pedestrians = get_random_classes(
                    proto_dataset_pedestrians.embeddings, proto_dataset_pedestrians.labels, sz_pedestrians, 2
                )
                support_embeddings_riders, support_labels_riders = get_random_classes(
                    proto_dataset_riders.embeddings, proto_dataset_riders.labels, sz_riders, 2
                )
                support_embeddings_others, support_labels_others = get_random_classes(
                    proto_dataset_others.embeddings, proto_dataset_others.labels, sz_others, 2
                )

                with torch.no_grad():
                    logits_tfs = compute_class_logits_per_batch(concat_inputs_traffic_lights, 
                                support_embeddings_traffic_lights, 
                                support_labels_traffic_lights, 
                                traffic_lights_model,
                            )
                    logits_ts = compute_class_logits_per_batch(concat_input_traffic_signs,
                                support_embeddings_traffic_signs, 
                                support_labels_traffic_signs, 
                                traffic_signs_model,
                            )
                    logits_cars = compute_class_logits_per_batch(concat_input_cars,
                                support_embeddings_cars, 
                                support_labels_cars, 
                                car_model,
                            )
                    logits_peds = compute_class_logits_per_batch(concat_input_pedestrians,
                                support_embeddings_pedestrians, 
                                support_labels_pedestrians, 
                                pedestrians_model,
                            )
                    logits_rid = compute_class_logits_per_batch(concat_input_riders,
                                support_embeddings_riders, 
                                support_labels_riders, 
                                rider_model,
                            )
                    logits_oth = compute_class_logits_per_batch(concat_input_others,
                                support_embeddings_others, 
                                support_labels_others, 
                                others_model,
                            )

                    assert logits_tfs.shape == (len(concat_inputs_traffic_lights),), f"Unexpected logits shape: {logits_tfs.shape}"
                    assert logits_ts.shape == (len(concat_input_traffic_signs),), f"Unexpected logits shape: {logits_ts.shape}"
                    assert logits_cars.shape == (len(concat_input_cars),), f"Unexpected logits shape: {logits_cars.shape}"
                    assert logits_peds.shape == (len(concat_input_pedestrians),), f"Unexpected logits shape: {logits_peds.shape}"
                    assert logits_rid.shape == (len(concat_input_riders),), f"Unexpected logits shape: {logits_rid.shape}"
                    assert logits_oth.shape == (len(concat_input_others),), f"Unexpected logits shape: {logits_oth.shape}"

                # & Main Standard Training
                traffic_lights_model.train()
                traffic_signs_model.train()
                car_model.train()
                pedestrians_model.train()
                rider_model.train()
                others_model.train()

                out_dict = model(images_embeddings, 
                            logits_tfs=logits_tfs, 
                            logits_ts=logits_ts,
                            logits_cars=logits_cars,
                            logits_peds=logits_peds,
                            logits_rid=logits_rid,
                            logits_oth=logits_oth, 
                        )  
                
                out_dict.update({"LABELS": class_labels, "CONCEPTS": attr_labels})
                
                model.opt.zero_grad()
                loss, losses = _loss(out_dict, args)

                loss.backward()
                model.opt.step()

                if ys is None:
                    ys = out_dict["YS"]
                    y_true = out_dict["LABELS"]
                    cs = out_dict["pCS"]
                    cs_true = out_dict["CONCEPTS"]
                else:
                    ys = torch.concatenate((ys, out_dict["YS"]), dim=0)
                    y_true = torch.concatenate((y_true, out_dict["LABELS"]), dim=0)
                    cs = torch.concatenate((cs, out_dict["pCS"]), dim=0)
                    cs_true = torch.concatenate((cs_true, out_dict["CONCEPTS"]), dim=0)

                progress_bar(i, len(train_loader) - 9, epoch, loss.item())
                
        # --------------------------------------
        # ^ Evaluation phase
        # --------------------------------------    
        y_pred = torch.argmax(ys, dim=-1)
        #print("Argmax predictions have shape: ", y_pred.shape)

        acc, f1 = accuracy_binary(ys, y_true)
        print("\n Train Label acc: ", acc, "Train Label f1", f1,)
        
        # & Compute the support set for evaluating the Prototypical Networks + Main Model
        model.eval()
        traffic_lights_model.eval()
        traffic_signs_model.eval()
        car_model.eval()
        pedestrians_model.eval()
        rider_model.eval()
        others_model.eval()

        support_embeddings_traffic_lights, support_labels_traffic_lights = get_random_classes(
            proto_dataset_traffic_lights.embeddings, proto_dataset_traffic_lights.labels, sz_traffic_lights, 2
        )
        support_embeddings_traffic_signs, support_labels_traffic_signs = get_random_classes(
            proto_dataset_traffic_signs.embeddings, proto_dataset_traffic_signs.labels, sz_traffic_signs, 2
        )
        support_embeddings_cars, support_labels_cars = get_random_classes(
            proto_dataset_cars.embeddings, proto_dataset_cars.labels, sz_cars, 2
        )
        support_embeddings_pedestrians, support_labels_pedestrians = get_random_classes(
            proto_dataset_pedestrians.embeddings, proto_dataset_pedestrians.labels, sz_pedestrians, 2
        )
        support_embeddings_riders, support_labels_riders = get_random_classes(
            proto_dataset_riders.embeddings, proto_dataset_riders.labels, sz_riders, 2
        )
        support_embeddings_others, support_labels_others = get_random_classes(
            proto_dataset_others.embeddings, proto_dataset_others.labels, sz_others, 2
        )
        my_metrics = evaluate_metrics(
                            model=model, 
                            loader=val_loader,
                            args=args,
                            support_embeddings_traffic_lights=support_embeddings_traffic_lights,
                            support_labels_traffic_lights=support_labels_traffic_lights,
                            support_embeddings_traffic_signs=support_embeddings_traffic_signs,
                            support_labels_traffic_signs=support_labels_traffic_signs,
                            support_embeddings_cars=support_embeddings_cars,
                            support_labels_cars=support_labels_cars,
                            support_embeddings_pedestrians=support_embeddings_pedestrians,
                            support_labels_pedestrians=support_labels_pedestrians,
                            support_embeddings_riders=support_embeddings_riders,
                            support_labels_riders=support_labels_riders,
                            support_embeddings_others=support_embeddings_others,
                            support_labels_others=support_labels_others,
                            traffic_lights_model=traffic_lights_model,
                            traffic_signs_model=traffic_signs_model,
                            car_model=car_model,
                            pedestrians_model=pedestrians_model,
                            rider_model=rider_model,
                            others_model=others_model,
                            eval_concepts=['traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others'],
                        )

        loss = my_metrics[0]
        cacc = my_metrics[1]
        yacc = my_metrics[2]
        f1_stop_concepts = my_metrics[83]

        # update at end of the epoch
        if epoch < args.warmup_steps:   w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1, "F1 C", f1_stop_concepts)
        
        if not args.tuning and f1_stop_concepts > best_f1_stop_concepts:
            print("Saving...")
            # Update best F1 score
            best_f1_stop_concepts = f1_stop_concepts

            # Save the best model
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with STOP F1(C) score: {best_f1_stop_concepts}")
            if epoch == 0:
                torch.save(model.state_dict(), save_path.replace(".pth", "_backup.pth"))
                print("Saved backup copy of the model.")

        elif f1_stop_concepts <= best_f1_stop_concepts:
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


    fprint("\n--- End of Training ---\n")

    # return all the prototypical datasets used to train the prototypical networks
    all_prototypical_datasets = {
        "traffic_lights": proto_dataset_traffic_lights,
        "traffic_signs": proto_dataset_traffic_signs,
        "cars": proto_dataset_cars,
        "pedestrians": proto_dataset_pedestrians,
        "riders": proto_dataset_riders,
        "others": proto_dataset_others
    }

    # return all the sizes of the prototypical datasets
    all_sz = {
        "traffic_lights": sz_traffic_lights,
        "traffic_signs": sz_traffic_signs,
        "cars": sz_cars,
        "pedestrians": sz_pedestrians,
        "riders": sz_riders,
        "others": sz_others
    }
    
    return best_f1_stop_concepts, all_prototypical_datasets, all_sz

## Run Training

In [None]:
print(f"*** Training model with seed {args.seed}")
print("Chosen device:", model.device)

# exist_ok=True already tolerates an existing directory, so no separate
# os.path.exists() check is needed (that check was also racy).
os.makedirs(save_path, exist_ok=True)

# NOTE: despite its name, `save_folder` is the checkpoint *file* path inside
# `save_path`; `train` saves the best model there and the evaluation cells
# below load and evaluate from it.
save_folder = os.path.join(save_path, f"{save_model_name}_{args.seed}.pth")
print("Saving model in folder: ", save_folder)

# `train` returns the best STOP concept F1 together with the prototypical
# datasets and their support sizes, which the evaluation cells reuse.
best_f1_c, all_prototypical_datasets, all_sz = train(
    model=model,
    # ^ Prototypical Networks (start)
    traffic_lights_model=traffic_lights_model,  # 3
    traffic_signs_model=traffic_signs_model,  # 4
    car_model=car_model,  # 5
    pedestrians_model=pedestrians_model,  # 6
    rider_model=rider_model,  # 7
    others_model=others_model,  # 8
    # ^ Prototypical Networks (end)
    train_loader=unsup_train_loader,
    val_loader=unsup_val_loader,
    save_path=save_folder,
    _loss=loss,
    args=args,
    seed=SEED,
)
# Persist `model` as it is after training (the last epoch that ran), which is
# not necessarily the best checkpoint written to `save_folder` by `train`.
save_model(model, args, args.seed)  # save the model parameters
print(f"*** Finished training model with seed {args.seed} and best F1 score {best_f1_c}")

print("Training finished.")

# TESTING

## Evaluation Routine

In [None]:
# Unpack what `train` returned: one prototypical dataset and one support-set
# size per concept, in a fixed concept order.
(
    proto_dataset_traffic_lights,
    proto_dataset_traffic_signs,
    proto_dataset_cars,
    proto_dataset_pedestrians,
    proto_dataset_riders,
    proto_dataset_others,
) = [
    all_prototypical_datasets[concept_key]
    for concept_key in ("traffic_lights", "traffic_signs", "cars", "pedestrians", "riders", "others")
]

(
    sz_traffic_lights,
    sz_traffic_signs,
    sz_cars,
    sz_pedestrians,
    sz_riders,
    sz_others,
) = [
    all_sz[concept_key]
    for concept_key in ("traffic_lights", "traffic_signs", "cars", "pedestrians", "riders", "others")
]

# NOTE(review): the datasets are only retrieved from `train`'s return value here,
# not saved to disk, despite what the first message says.
print("Prototypical datasets built and saved.")
print("Prototypical datasets sizes: ")
for caption, support_size in (
    ("Traffic Lights: ", sz_traffic_lights),
    ("Traffic Signs: ", sz_traffic_signs),
    ("Cars: ", sz_cars),
    ("Pedestrians: ", sz_pedestrians),
    ("Riders: ", sz_riders),
    ("Others: ", sz_others),
):
    print(caption, support_size)

In [None]:
def evaluate_my_model(model: MnistDPL, 
        save_path: str, 
        test_loader: DataLoader,
    ):
    """Evaluate the main model together with the six prototypical concept networks.

    Reads from notebook-level globals:
      - the prototypical networks (``traffic_lights_model`` ... ``others_model``);
      - their prototypical datasets (``proto_dataset_*``) and support sizes (``sz_*``);
      - ``args``.

    One random support set per concept is sampled with ``get_random_classes``.
    Both ``evaluate_metrics`` passes below share it: the metric pass and the
    prediction pass (``last=True``).

    Args:
        model: main model to evaluate. It is put in eval mode together with the
            prototypical networks.
        save_path: checkpoint path of ``model``. Metrics are appended to
            ``<save_path without trailing .pth>_metrics.log``. The path is also
            forwarded to the confusion-matrix plotting helpers.
        test_loader: loader over the split to evaluate.

    Returns:
        None. Results are appended to the metrics log and printed.
    """
    model.eval()
    for proto_net in (
        traffic_lights_model,
        traffic_signs_model,
        car_model,
        pedestrians_model,
        rider_model,
        others_model,
    ):
        proto_net.eval()

    # Sample one support set per concept. Keep this order stable:
    # get_random_classes presumably draws from a global RNG, so reordering the
    # calls would change which supports a seeded run samples.
    support_kwargs = {}
    for concept, proto_dataset, support_size in (
        ("traffic_lights", proto_dataset_traffic_lights, sz_traffic_lights),
        ("traffic_signs", proto_dataset_traffic_signs, sz_traffic_signs),
        ("cars", proto_dataset_cars, sz_cars),
        ("pedestrians", proto_dataset_pedestrians, sz_pedestrians),
        ("riders", proto_dataset_riders, sz_riders),
        ("others", proto_dataset_others, sz_others),
    ):
        embeddings, labels = get_random_classes(
            proto_dataset.embeddings, proto_dataset.labels, support_size, 2
        )
        support_kwargs[f"support_embeddings_{concept}"] = embeddings
        support_kwargs[f"support_labels_{concept}"] = labels

    # Keyword arguments shared by both evaluate_metrics passes, so both use the
    # same models and the same support sets.
    eval_kwargs = {
        "model": model,
        "loader": test_loader,
        "args": args,
        **support_kwargs,
        "traffic_lights_model": traffic_lights_model,
        "traffic_signs_model": traffic_signs_model,
        "car_model": car_model,
        "pedestrians_model": pedestrians_model,
        "rider_model": rider_model,
        "others_model": others_model,
    }
    eval_concepts = ("traffic_lights", "traffic_signs", "cars", "pedestrians", "riders", "others")

    # A fresh list per call, as before, in case evaluate_metrics mutates it.
    my_metrics = evaluate_metrics(**eval_kwargs, eval_concepts=list(eval_concepts))

    # Layout of my_metrics as consumed below:
    #   1 = ACC C, 2 = ACC Y, 3 = F1 Y;
    #   13 metrics per concept starting at 4, 17, 30, 43, 56, 69;
    #   13 aggregated concept metrics starting at 82.
    cacc = my_metrics[1]
    yacc = my_metrics[2]
    f1 = my_metrics[3]

    # Replace only a trailing ".pth". str.replace(".pth", ...) would rewrite
    # every occurrence in the path. For a path without ".pth" it would return
    # save_path unchanged, which appends the metrics text to the checkpoint itself.
    root, extension = os.path.splitext(save_path)
    metrics_log_path = (root if extension == ".pth" else save_path) + "_metrics.log"

    # Row captions, in the order the 13 per-group metrics appear in my_metrics.
    metric_rows = (
        "F1 - Binary:   ",
        "F1 - Macro:    ",
        "F1 - Micro:    ",
        "F1 - Weighted: ",
        "Precision - Binary:   ",
        "Precision - Macro:    ",
        "Precision - Micro:    ",
        "Precision - Weighted: ",
        "Recall - Binary:   ",
        "Recall - Macro:    ",
        "Recall - Micro:    ",
        "Recall - Weighted: ",
        "Balanced Accuracy: ",
    )

    with open(metrics_log_path, "a") as log_file:
        log_file.write(f"ACC C: {cacc}, ACC Y: {yacc}, F1 Y: {f1}\n\n")

        def write_metrics(header, offset):
            """Write `header`, then the 13 metrics my_metrics[offset:offset + 13], then a blank line."""
            log_file.write(f"{header}\n")
            for position, caption in enumerate(metric_rows):
                log_file.write(f"  {caption}{my_metrics[offset + position]:.4f}\n")
            log_file.write("\n")

        for class_name, offset in (
            ("Traffic Light", 4),
            ("Traffic Sign", 17),
            ("Car", 30),
            ("Pedestrian", 43),
            ("Rider", 56),
            ("Other", 69),
        ):
            write_metrics(class_name.upper(), offset)
        write_metrics("EVAL CONCEPTS (Aggregated)", 82)

    # Second pass: with last=True, evaluate_metrics returns ground truths and
    # predictions, which are used for the confusion matrices below.
    y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = evaluate_metrics(
        **eval_kwargs, eval_concepts=list(eval_concepts), last=True
    )

    y_labels = ["stop", "forward", "left", "right"]
    # Concepts 3-8 are the ones backed by the prototypical networks (see the
    # "# 3" ... "# 8" annotations on the train() call).
    # NOTE(review): "letf_solid_line" looks like a typo for "left_solid_line".
    # It is kept as-is because it is passed to the plotting helpers.
    concept_labels = [
        "green_light",      
        "follow",           
        "road_clear",       
        "red_light",        # ! 3
        "traffic_sign",     # ! 4
        "car",              # ! 5
        "person",           # ! 6
        "rider",            # ! 7
        "other_obstacle",   # ! 8
        "left_lane",
        "left_green_light",
        "left_follow",
        "no_left_lane",
        "left_obstacle",
        "letf_solid_line",
        "right_lane",
        "right_green_light",
        "right_follow",
        "no_right_lane",
        "right_obstacle",
        "right_solid_line",
    ]

    plot_multilabel_confusion_matrix(y_true, y_pred, y_labels, "Labels", save_path=save_path)
    cfs = plot_actions_confusion_matrix(c_true, c_pred, "Concepts", save_path=save_path)
    cf = plot_multilabel_confusion_matrix(c_true, c_pred, concept_labels, "Concepts", save_path=save_path)
    print("Concept collapse", 1 - compute_coverage(cf))

    # One "Concept collapse" line (1 - coverage) per confusion matrix returned
    # by plot_actions_confusion_matrix.
    with open(metrics_log_path, "a") as log_file:
        for key, value in cfs.items():
            log_file.write(f"Concept collapse: {key}, {1 - compute_coverage(value):.4f}\n")
            log_file.write("\n")

    print("Evaluation metrics saved to:", metrics_log_path)
    fprint("\n--- End of Evaluation ---\n")

## Run Evaluation

In [None]:
# Build a fresh model with the same configuration that was used for training.
model = get_model(args, encoder, decoder, n_images, c_split)

# Restore the checkpoint that `train` writes to `save_folder` (the best STOP
# concept F1 seen during training). To evaluate a different checkpoint,
# override `save_folder` first, e.g.:
# save_folder = '../NEW-outputs/bddoia/my_models/dpl/episodic-proto-net-pipeline-1.0-article-STOP/dpl_200_backup.pth'
restored_state = torch.load(save_folder)
model.load_state_dict(restored_state)

# Score the restored model on the test split; logs and plots are keyed off save_folder.
evaluate_my_model(model, save_folder, unsup_test_loader)