# IMPORTS

In [1]:
import sys
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
sys.path.append(os.path.abspath(".."))       # for 'protonet_STOP_bddoia_modules' folder
sys.path.append(os.path.abspath("../.."))    # for 'data' folder
sys.path.append(os.path.abspath("../../..")) # for 'models' and 'datasets' folders

print(sys.path)

['/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia/notebooks', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python38.zip', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/lib-dynload', '', '/users-1/eleonora/.local/lib/python3.8/site-packages', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/site-packages', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation', '/users-1/eleonora/reasoning-shortcuts/IXShort']


In [2]:
import csv
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import datetime
import numpy as np
import setproctitle, socket, uuid

from typing import List

from models import get_model
from models.mnistdpl import MnistDPL
from datasets import get_dataset

from argparse import Namespace
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix

from utils import fprint
from utils.status import progress_bar
from utils.metrics import evaluate_metrics
from utils.dpl_loss import ADDMNIST_DPL
from utils.checkpoint import save_model
from torch.utils.data import Dataset, DataLoader

from warmup_scheduler import GradualWarmupScheduler

from backbones.bddoia_protonet import PrototypicalLoss
from protonet_bddoia_modules.data_modules.proto_data import build_prototypical_dataloaders  # TODO: use to refactor
from protonet_bddoia_modules.arguments import args_dpl 
from baseline_modules.supervision_modules.build_sup_set_joint import get_augmented_train_loader  # TODO: use to refactor
from protonet_STOP_bddoia_modules.proto_modules.proto_helpers import (
    assert_inputs,
    get_random_classes,
)
from protonet_STOP_bddoia_modules.proto_modules.proto_functions import (
    train_my_prototypical_network,
)

# SETUP

In [3]:
SEED = 1
UNS_PERCENTAGE = 1.0

In [4]:
args = args_dpl
args.seed = SEED

# logging
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# set job name
setproctitle.setproctitle(
    "{}_{}_{}".format(
        args.model,
        args.buffer_size if "buffer_size" in args else 0,
        args.dataset,
    )
)

# saving
save_folder = "bddoia" 
save_model_name = 'dpl'
save_paths = []
save_path = os.path.join("..",
    "notebook-outputs", 
    save_folder, 
    "my_models", 
    save_model_name,
    f"episodic-proto-net-pipeline-{UNS_PERCENTAGE}-PROVA"
)
save_paths.append(save_path)

print("Seed: " + str(args.seed))
print(f"Save paths: {str(save_paths)}")

Seed: 1
Save paths: ['../notebook-outputs/bddoia/my_models/dpl/episodic-proto-net-pipeline-1.0-PROVA']


# UTILS

## Test Set Evaluation

In [5]:

# * helper function for 'plot_multilabel_confusion_matrix'
def convert_to_categories(elements):
    # Convert vector of 0s and 1s to a single binary representation along the first dimension
    binary_rep = np.apply_along_axis(
        lambda x: "".join(map(str, x)), axis=1, arr=elements
    )
    return np.array([int(x, 2) for x in binary_rep])


# * BBDOIA custom confusion matrix for concepts
def plot_multilabel_confusion_matrix(
    y_true, y_pred, class_names, title, save_path=None
):
    y_true_categories = convert_to_categories(y_true.astype(int))
    y_pred_categories = convert_to_categories(y_pred.astype(int))

    to_rtn_cm = confusion_matrix(y_true_categories, y_pred_categories)

    cm = multilabel_confusion_matrix(y_true, y_pred)
    num_classes = len(class_names)
    num_rows = (num_classes + 4) // 5  # Calculate the number of rows needed

    plt.figure(figsize=(20, 4 * num_rows))  # Adjust the figure size

    for i in range(num_classes):
        plt.subplot(num_rows, 5, i + 1)  # Set the subplot position
        plt.imshow(cm[i], interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"Class: {class_names[i]}")
        plt.colorbar()
        tick_marks = np.arange(2)
        plt.xticks(tick_marks, ["0", "1"])
        plt.yticks(tick_marks, ["0", "1"])

        fmt = ".0f"
        thresh = cm[i].max() / 2.0
        for j in range(cm[i].shape[0]):
            for k in range(cm[i].shape[1]):
                plt.text(
                    k,
                    j,
                    format(cm[i][j, k], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i][j, k] > thresh else "black",
                )

        plt.ylabel("True label")
        plt.xlabel("Predicted label")

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.suptitle(title)

    if save_path:
        plt.savefig(f"{save_path}_total.png")
    else:
        plt.show()

    plt.close()

    return to_rtn_cm


# * Concept collapse (Soft)
def compute_coverage(confusion_matrix):
    """Compute the coverage of a confusion matrix.

    Essentially this metric is
    """

    max_values = np.max(confusion_matrix, axis=0)
    clipped_values = np.clip(max_values, 0, 1)

    # Redefinition of soft coverage
    coverage = np.sum(clipped_values) / len(clipped_values)

    return coverage


# * BDDOIA custom confusion matrix for actions
def plot_actions_confusion_matrix(c_true, c_pred, title, save_path=None):

    my_scenarios = {
        "forward": [slice(0, 3), slice(0, 3)],  
        "stop": [slice(3, 9), slice(3, 9)],
        "left": [slice(9, 11), slice(18,20)],
        "right": [slice(12, 17), slice(12,17)],
    }

    to_rtn = {}

    # Plot confusion matrix for each scenario
    for scenario, indices in my_scenarios.items():

        g_true = convert_to_categories(c_true[:, indices[0]].astype(int))
        c_pred_scenario = convert_to_categories(c_pred[:, indices[1]].astype(int))

        # Compute confusion matrix
        cm = confusion_matrix(g_true, c_pred_scenario)

        # Plot confusion matrix
        plt.figure()
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"{title} - {scenario}")
        plt.colorbar()

        n_classes = c_true[:, indices[0]].shape[1]

        tick_marks = np.arange(2**n_classes)
        plt.xticks(tick_marks, ["" for _ in range(len(tick_marks))])
        plt.yticks(tick_marks, ["" for _ in range(len(tick_marks))])

        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.tight_layout()

        # Save or show plot
        if save_path:
            plt.savefig(f"{save_path}_{scenario}.png")
        else:
            plt.show()

        to_rtn.update({scenario: cm})

        plt.close()

    return to_rtn

## Other Utils

In [6]:

# * method used to check if all encoder parameters are registered in the model's optimizer
def check_optimizer_params(model):
    """Check that all encoder parameters are registered in the optimizer."""
    # Get all encoder parameters
    encoder_params = []
    for i in range(21):
        encoder = model.encoder[i]
        for name, param in encoder.named_parameters():
            if not param.requires_grad:
                continue  # skip frozen params
            encoder_params.append((f"encoder_{i}.{name}", param))

    # Get all parameters in the optimizer
    opt_param_ids = set(id(p) for group in model.opt.param_groups for p in group['params'])

    # Check each encoder param is in the optimizer
    missing = [(name, p.shape) for name, p in encoder_params if id(p) not in opt_param_ids]

    if missing:
        print("⚠️ The following parameters are missing from the optimizer:")
        for name, shape in missing:
            print(f"  - {name}: {shape}")
        raise RuntimeError("Some encoder parameters are not registered in the optimizer.")
    else:
        print("✅ All encoder parameters are correctly registered in the optimizer.")

In [7]:

# * semi-deterministic variant of get_random_classes where the positive examples are always the same for a given class index
def get_per_class_support_set(proto_datasets:Dataset, pos_examples:dict, class_idx:int, device:str, debug=False):
    pos_list = pos_examples[class_idx]  
    support_embeddings_pos = torch.stack([ ex['images_embeddings_raw'].unsqueeze(0) for ex in pos_list ], dim=0).to(device)
    support_labels_pos = torch.ones(len(pos_list), dtype=torch.long, device=device)
    num_pos_labels = support_labels_pos.sum().item()
    if debug:
        print(f"Class {class_idx}: {support_labels_pos.shape} embeddings, {support_labels_pos.shape} labels (all 1)")

    proto_labels = proto_datasets[class_idx].labels
    proto_data = proto_datasets[class_idx].embeddings
    
    mask = proto_labels == 0
    proto_data_neg = proto_data[mask]
    proto_labels_neg = proto_labels[mask]
    support_embeddings_neg, support_labels_neg = get_random_classes(
        proto_data_neg, proto_labels_neg, n_support=num_pos_labels, n_classes=1
    )
    if debug:
        print("Support embeddings shape: ", support_embeddings_neg.shape)
        print("Support labels shape: ", support_labels_neg.shape)
    
    assert torch.all(support_labels_neg == 0), "support_labels contains non-zero entries"

    support_embeddings_combined = torch.cat([support_embeddings_pos, support_embeddings_neg], dim=0)
    support_labels_combined = torch.cat([support_labels_pos, support_labels_neg], dim=0)
    if debug:
        print("Combined support embeddings shape:", support_embeddings_combined.shape)
        print("Combined support labels shape:", support_labels_combined.shape)

    return support_embeddings_combined, support_labels_combined

# ANNOTATIONS DATASET & BATCH SAMPLER

In [8]:
class ProtoDataset(Dataset):
    def __init__(self, embeddings, labels):
        assert embeddings.shape[0] == labels.shape[0]
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]


class PrototypicalBatchSampler(object):
    """
    Yields a batch of indices for episodic training.
    At each iteration, it randomly selects 'classes_per_it' classes and then picks
    'num_samples' samples for each selected class.
    """
    def __init__(self, labels, classes_per_it, num_samples, iterations):
        """
        Args:
            labels (array-like): 1D array or list of labels for the target task.
                                 This should be either the shape labels or the colour labels.
            classes_per_it (int): Number of random classes for each iteration.
            num_samples (int): Number of samples per class (support + query) in each episode.
            iterations (int): Number of iterations (episodes) per epoch.
        """
        self.labels = np.array(labels)
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations
        
        self.classes, self.counts = np.unique(self.labels, return_counts=True)
        self.classes = torch.LongTensor(self.classes)

        # Create an index matrix of shape (num_classes, max_samples_in_class)
        max_count = max(self.counts)
        self.indexes = np.empty((len(self.classes), max_count), dtype=int)
        self.indexes.fill(-1)
        self.indexes = torch.LongTensor(self.indexes)
        self.numel_per_class = torch.zeros(len(self.classes), dtype=torch.long)

        # Fill in the matrix with indices for each class.
        for idx, label in enumerate(self.labels):
            # Find the row corresponding to this label
            class_idx = (self.classes == label).nonzero(as_tuple=False).item()
            # Find the next available column (where the value is -1)
            pos = (self.indexes[class_idx] == -1).nonzero(as_tuple=False)[0].item()
            self.indexes[class_idx, pos] = idx
            self.numel_per_class[class_idx] += 1

    def __iter__(self):
        """
        Yield a batch of indices for each episode.
        """
        spc = self.sample_per_class
        cpi = self.classes_per_it

        for _ in range(self.iterations):
            batch = torch.LongTensor(cpi * spc)
            # Randomly choose 'classes_per_it' classes
            c_idxs = torch.randperm(len(self.classes))[:cpi]
            for i, class_idx in enumerate(c_idxs):
                s = slice(i * spc, (i + 1) * spc)

                n_avail = self.numel_per_class[class_idx]
                if spc <= n_avail:
                    # enough examples → sample without replacement
                    perm = torch.randperm(n_avail)
                    sample_idxs = perm[:spc]
                else:
                    # too few → sample with replacement
                    sample_idxs = torch.randint(0, n_avail, (spc,), dtype=torch.long)

                batch[s] = self.indexes[class_idx, sample_idxs]

            # Shuffle the batch indices
            batch = batch[torch.randperm(len(batch))]
            yield batch

    def __len__(self):
        return self.iterations

# UNSUPERVISED DATA AND MODEL LOADING

In [9]:
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()

encoder, decoder = dataset.get_backbone()
assert isinstance(encoder, tuple) and len(encoder) == 21, "encoder must be a tuple of 21 elements"

# & Main model
model = get_model(args, encoder, decoder, n_images, c_split)
model.start_optim(args)
check_optimizer_params(model)
loss = model.get_loss(args)

print(dataset)
print("Using Dataset: ", dataset)
print("Using Model: ", model)
print("Using Loss: ", loss)

unsup_train_loader, unsup_val_loader, unsup_test_loader = dataset.get_data_loaders(args=args)

Available datasets: ['mnmath', 'xor', 'clipboia', 'shortmnist', 'restrictedmnist', 'minikandinsky', 'presddoia', 'prekandinsky', 'sddoia', 'clipkandinsky', 'addmnist', 'clipshortmnist', 'boia_original', 'boia_original_embedded', 'clipsddoia', 'boia', 'kandinsky', 'halfmnist']
[PROTO-INFO] Using Prototypical Networks as backbone
Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'shieldedmnist', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 

## Check Optimizer

In [10]:
all_params = sum(p.numel() for p in model.opt.param_groups[0]['params'])
pnets_params = sum(p.numel() for enc in model.encoder for p in enc.parameters())
mlp_params = sum(p.numel() for p in model.mlp.parameters()) if args.expressive_model else 0

expected_total = pnets_params + mlp_params

assert all_params == expected_total, (
    f"Mismatch in optimizer parameters!\n"
    f"- Backbone params: {pnets_params:,}\n"
    f"- MLP params: {mlp_params:,} (expressive_model={args.expressive_model})\n"
    f"- Total in optimizer: {all_params:,}\n"
    f"→ Expected total: {expected_total:,}"
)

# FETCHING DATA ANNOTATIONS

## Build positive annotation set for each class

In [11]:

# * 1 POSITIVE EXAMPLES COLLECTION =====
pos_examples = {cls_idx: [] for cls_idx in range(21)}
target_per_class = 6    # Desired number of positives per class
debug = True

# Loop over dataset until we collect target_per_class for each class
for batch_idx, batch in enumerate(unsup_train_loader):
    raw_embs = torch.stack(batch['embeddings_raw']).to(model.device)
    attrs = torch.stack(batch['attr_labels']).to(model.device)  # shape [B,21]
    batch_size = attrs.size(0)

    for b in range(batch_size):
        attr_vector = attrs[b].clone().cpu()
        for cls in torch.nonzero(attr_vector).flatten().tolist():
            if len(pos_examples[cls]) >= target_per_class:
                continue
            example = {
                'source_id': (batch_idx, b),
                'images_embeddings_raw': raw_embs[b].detach().cpu().clone(),
                'attr_labels': attr_vector,
                'is_positive': True
            }
            if debug:
                for key, value in example.items():
                    if torch.is_tensor(value):
                        print(f"{key}: {value.shape}")
                    elif isinstance(value, list) and len(value) and torch.is_tensor(value[0]):
                        print(f"{key}: list of {len(value)} tensors, first shape: {value[0].shape}")
                    else:
                        print(f"{key}: {type(value)}")
                debug = False
            pos_examples[cls].append(example)

    if all(len(pos_examples[c]) >= target_per_class for c in range(21)):    break

source_id: <class 'tuple'>
images_embeddings_raw: torch.Size([2048])
attr_labels: torch.Size([21])
is_positive: <class 'bool'>


## Augment positive sets while building negative ones

In [12]:

# * 2: NEGATIVE EXAMPLES & MULTI-LABEL AUGMENTATION =====
neg_examples = {cls_idx: [] for cls_idx in range(21)}

# Allow multi-label augmentation: add any example with attr_labels[i]==1 to pos_examples[i]
for cls in range(21):
    seen_ids = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in range(21):
        if other_cls == cls:
            continue
        for ex in pos_examples[other_cls]:
            if ex['attr_labels'][cls] == 1 and ex['source_id'] not in seen_ids:
                new_ex = ex.copy()
                new_ex['is_positive'] = True
                pos_examples[cls].append(new_ex)
                seen_ids.add(ex['source_id'])

# Build negatives: any example that has attr_labels[i]==0 but appears in any pos_examples of other classes
for cls in range(21):
    seen_ids_pos = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in range(21):
        if other_cls == cls:
            continue
        for ex in pos_examples[other_cls]:
            if ex['attr_labels'][cls] == 0 and ex['source_id'] not in seen_ids_pos:
                neg_ex = ex.copy()
                neg_ex['is_positive'] = False
                neg_examples[cls].append(neg_ex)

# Ensure no overlap between pos and neg
for cls in range(21):
    assert not set(ex['source_id'] for ex in neg_examples[cls]) & set(ex['source_id'] for ex in pos_examples[cls]), \
        f"Overlap in pos/neg for class {cls}"


## Turn annotations into tensors

In [13]:

# * 3: BUILD EMBEDDING TENSORS AND LABELS =====
dataset_per_class = {}
for cls in range(21):
    examples = pos_examples[cls] + neg_examples[cls]
    emb_list, label_list, extended_label_list = [], [], []
    for ex in examples:
        emb_list.append(ex['images_embeddings_raw'].unsqueeze(0))
        label_list.append(1 if ex['is_positive'] else 0)
        extended_label_list.append(ex['attr_labels'])

    embeddings_tensor = torch.stack(emb_list).to(model.device)                  # [N,1,2048]
    labels_tensor = torch.tensor(label_list, device=model.device)               # [N]
    extended_labels_tensor = torch.stack(extended_label_list).to(model.device)  # [N,21]
    dataset_per_class[cls] = {'embeddings': embeddings_tensor, 'labels': labels_tensor, 'extended_labels': extended_labels_tensor}

        
for cls in range(21):
    print(
        f"Class {cls}: embeddings shape = {dataset_per_class[cls]['embeddings'].shape}, "
        f"labels shape = {dataset_per_class[cls]['labels'].shape}, "
        f"extended labels shape = {dataset_per_class[cls]['extended_labels'].shape}"
    )

Class 0: embeddings shape = torch.Size([184, 1, 2048]), labels shape = torch.Size([184]), extended labels shape = torch.Size([184, 21])
Class 1: embeddings shape = torch.Size([206, 1, 2048]), labels shape = torch.Size([206]), extended labels shape = torch.Size([206, 21])
Class 2: embeddings shape = torch.Size([188, 1, 2048]), labels shape = torch.Size([188]), extended labels shape = torch.Size([188, 21])
Class 3: embeddings shape = torch.Size([204, 1, 2048]), labels shape = torch.Size([204]), extended labels shape = torch.Size([204, 21])
Class 4: embeddings shape = torch.Size([219, 1, 2048]), labels shape = torch.Size([219]), extended labels shape = torch.Size([219, 21])
Class 5: embeddings shape = torch.Size([228, 1, 2048]), labels shape = torch.Size([228]), extended labels shape = torch.Size([228, 21])
Class 6: embeddings shape = torch.Size([230, 1, 2048]), labels shape = torch.Size([230]), extended labels shape = torch.Size([230, 21])
Class 7: embeddings shape = torch.Size([208, 1, 

## *X-Model*: dataloader istantiation

In [14]:

# * 4: BUILD THE DATALOADER FOR THE EXPRESSIVE MODEL =====
if args.expressive_model:
    all_embeddings = torch.cat([per_class['embeddings'] for per_class in dataset_per_class.values()], dim=0)
    all_labels = torch.cat([per_class['extended_labels'] for per_class in dataset_per_class.values()], dim=0)
    all_embeddings = all_embeddings.cpu()
    all_labels     = all_labels.cpu()
    dataset = ProtoDataset(all_embeddings, all_labels)
    x_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
else:
    x_loader = None

## *PNets*: Sanity checks

In [15]:

# * 5: CONSISTENCY CHECKS FOR PNETs DATA =====
for cls in range(21):
    emb = dataset_per_class[cls]['embeddings']
    lab = dataset_per_class[cls]['labels']
    # Exact count assertion
    expected_total = len(pos_examples[cls]) + len(neg_examples[cls])
    assert emb.size(0) == expected_total, f"Class {cls} count mismatch: {emb.size(0)} vs {expected_total}"
    # Label values correct
    assert set(lab.tolist()) <= {0,1}, f"Invalid labels for class {cls}"
    # Check that each positive and negative example has correct label
    pos_count = len(pos_examples[cls])
    neg_count = len(neg_examples[cls])
    # positives should be labeled 1 in the first pos_count entries
    for idx in range(pos_count):
        assert lab[idx].item() == 1, f"Positive at wrong pos for class {cls}, idx {idx}"
    # negatives should be labeled 0 in the next neg_count entries
    for idx in range(neg_count):
        assert lab[pos_count + idx].item() == 0, f"Negative at wrong pos for class {cls}, idx {pos_count + idx}"

print("Dataset per class built with explicit flags and no overlaps.")

Dataset per class built with explicit flags and no overlaps.


## *PNets*: Istantiate Batch Samplers for each class

In [16]:

# * 6 : CREATE EPISODIC DATALOADERS FOR PNETs =====
proto_datasets = {}
proto_dataloaders = {}

for cls in range(21):
    proto_data = dataset_per_class[cls]['embeddings']
    proto_labels = dataset_per_class[cls]['labels']
    proto_datasets[cls] = ProtoDataset(proto_data, proto_labels)
    proto_sampler = PrototypicalBatchSampler(
                    labels = proto_labels.cpu().numpy(),
                    classes_per_it = args.classes_per_it,
                    num_samples = args.num_samples,
                    iterations = args.iterations,
                )
    proto_dataloaders[cls] = DataLoader(proto_datasets[cls], batch_sampler=proto_sampler)


# Labels count check for proto_datasets[cls] and proto_dataloaders[cls]
for cls in range(21):
    # Dataset label count
    label_counter_dataset = Counter(proto_datasets[cls].labels.cpu().tolist())
    print(f"Class {cls} - Dataset Label 0 count: {label_counter_dataset[0]}, Label 1 count: {label_counter_dataset[1]}")

    # Dataloader label count
    label_counter_loader = Counter()
    for batch in proto_dataloaders[cls]:
        _, labels = batch
        label_counter_loader.update(labels.tolist())
    print(f"Class {cls} - Dataloader Label 0 count: {label_counter_loader[0]}, Label 1 count: {label_counter_loader[1]}")


# Final Sanity Check
for cls in range(21):
    print(f"Class {cls}: Dataset size = {len(proto_datasets[cls])}, Dataloader batches = {len(proto_dataloaders[cls])}")
    assert len(proto_dataloaders[cls]) == args.iterations, \
        f"Class {cls}: Expected {args.iterations} batches, got {len(proto_dataloaders[cls])}"
    for batch in proto_dataloaders[cls]:
        embeddings, labels = batch
        #print("Batch Embeddings Shape:", embeddings.shape, "Labels Shape:", labes.shape)
        assert embeddings.shape == ((args.num_support + args.num_query) * args.classes_per_it, 1, 2048), \
            f"Embeddings shape mismatch: {embeddings.shape}"
        assert labels.shape == ((args.num_support + args.num_query) * args.classes_per_it,), \
            f"Labels shape mismatch: {labels.shape}"

Class 0 - Dataset Label 0 count: 162, Label 1 count: 22
Class 0 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 1 - Dataset Label 0 count: 193, Label 1 count: 13
Class 1 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 2 - Dataset Label 0 count: 169, Label 1 count: 19
Class 2 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 3 - Dataset Label 0 count: 188, Label 1 count: 16
Class 3 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 4 - Dataset Label 0 count: 209, Label 1 count: 10
Class 4 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 5 - Dataset Label 0 count: 222, Label 1 count: 6
Class 5 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 6 - Dataset Label 0 count: 224, Label 1 count: 6
Class 6 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 7 - Dataset Label 0 count: 192, Label 1 count: 16
Class 7 - Dataloader Label 0 count: 1000, Label 1 count: 1000
Class 8 - Dataset Label 0 count: 229, Label 1 count: 6
Cla

# TRAINING

In [17]:
def train(
        model: MnistDPL, 
        _loss: ADDMNIST_DPL,
        save_path: str, 
        proto_datasets: dict,
        proto_dataloaders: dict,
        train_loader: DataLoader,
        val_loader: DataLoader,
        x_loader: DataLoader,
        args: Namespace,
        seed: int = 0,
        eval_concepts: List[str] = ['green_lights', 'follow_traffic', 'road_clear',
        'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
        'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
        'on_right_turn_lane', 'traffic_light_right', 'front_car_right', 
        'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
        'on_left_turn_lane', 'traffic_light_left', 'front_car_left']
    ) -> float:

    # for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False
    
    # early stopping
    best_f1_cacc = 0.0
    epochs_no_improve = 0

    # scheduler & warmup (if used) for main model
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)
    
    # --------------------------------------
    # ^ 1. PROTOTYPICAL NETWORKS SCHEDULERS & OPTIMIZERS
    # --------------------------------------
    proto_opts, proto_schs = {}, {}
    for c, name in enumerate(eval_concepts):
        opt = torch.optim.Adam(model.encoder[c].parameters())
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
        proto_opts[c], proto_schs[c] = opt, sch

    fprint("\n--- Start of Training ---\n")
    for b in range(len(model.encoder)):
        model.encoder[b].train()
        model.encoder[b].to(model.device)
            
    pNet_loss = PrototypicalLoss(n_support=args.num_support)
    for epoch in range(args.n_epochs):

        for e in range(args.proto_epochs):
        # --------------------------------------
        # ^ 2. PROTOTYPICAL NETWORKS TRAINING
        # --------------------------------------
            print('----------------------------------')
            print('--- Prototypical Networks Training ---')                        
            print(f"Prototypical Networks Training Epoch {e + 1}/{args.proto_epochs}")
            losses, accs = {}, {}
            for k, name in enumerate(eval_concepts):
                dl = proto_dataloaders[k]
                opt, sch = proto_opts[k], proto_schs[k]
                l, a = train_my_prototypical_network(
                    dl, args.iterations,
                    model.encoder[k], opt, pNet_loss
                )
                losses[name], accs[name] = l, a
                sch.step()

            for name in eval_concepts:
                avg_l = sum(losses[name]) / len(losses[name])
                avg_a = sum(accs[name])   / len(accs[name])
                print(f"  {name:>25s} - Loss {avg_l:.4f} | Acc {avg_a:.4f}")

        # --------------------------------------
        # ^ 3. X-MODEL TRAINING
        # --------------------------------------
        if args.expressive_model:
            print('------------------')
            print('--- X-Model Training ---')
            model.to(model.device)
            model.train()
            for i, batch in enumerate(x_loader):
                batch_embeds, batch_labels = batch
                batch_embeds = batch_embeds.to(model.device)
                batch_labels = batch_labels.to(model.device)
                assert batch_labels.shape[1] == model.n_facts, (
                    f"batch_labels shape is {batch_labels.shape}, expected (batch_size, {model.n_facts})"
                )

                support_emb_dict = {}
                for j in range(model.n_facts):
                    emb_s, lab_s = get_per_class_support_set(
                        proto_datasets, pos_examples, class_idx=j, device=model.device
                    )
                    support_emb_dict[j] = (emb_s, lab_s)

                model.opt.zero_grad()
                out_dict = model(batch_embeds, support_emb_dict)
                concept_predictions = out_dict["CS"]
                loss = F.binary_cross_entropy(concept_predictions, batch_labels.float())

                loss.backward()
                model.opt.step()

                progress_bar(i, len(x_loader), epoch, loss.item())

        # --------------------------------------
        # ^ 4. MAIN MODEL TRAINING
        # --------------------------------------
        print('------------------')
        print('--- Main Model Training ---')    
        print(f"Main Model Training Epoch {epoch + 1}/{args.n_epochs}")  
        ys, y_true, cs, cs_true, batch = None, None, None, None, 0
        for i, batch in enumerate(train_loader):
            # ------------------ original embeddings
            images_embeddings = torch.stack(batch['embeddings']).to(model.device)
            attr_labels = torch.stack(batch['attr_labels']).to(model.device)
            class_labels = torch.stack(batch['class_labels'])[:,:-1].to(model.device) # exclude the last column
            # ------------------ my extracted features
            images_embeddings_raw = torch.stack(batch['embeddings_raw']).to(model.device)
            detected_rois = batch['rois']
            detected_rois_feats = batch['roi_feats']
            detection_labels = batch['detection_labels']
            detection_scores = batch['detection_scores']
            assert_inputs(images_embeddings, attr_labels, class_labels,
                    detected_rois_feats, detected_rois, detection_labels,
                    detection_scores, images_embeddings_raw)
            
            support_emb_dict = {}
            for j in range(model.n_facts):
                emb_s, lab_s = get_per_class_support_set(
                    proto_datasets, pos_examples, class_idx=j, device=model.device
                )
                support_emb_dict[j] = (emb_s, lab_s)

            if random.random() > UNS_PERCENTAGE:
                continue  # Skip this batch with probability (1 - percentage)
            
            out_dict = model(images_embeddings_raw, support_emb_dict)
            out_dict.update({"LABELS": class_labels, "CONCEPTS": attr_labels})
                
            model.opt.zero_grad()
            loss, losses = _loss(out_dict, args)
            loss.backward()
            model.opt.step()

            if ys is None:
                    ys = out_dict["YS"]
                    y_true = out_dict["LABELS"]
                    cs = out_dict["pCS"]
                    cs_true = out_dict["CONCEPTS"]
            else:
                ys = torch.concatenate((ys, out_dict["YS"]), dim=0)
                y_true = torch.concatenate((y_true, out_dict["LABELS"]), dim=0)
                cs = torch.concatenate((cs, out_dict["pCS"]), dim=0)
                cs_true = torch.concatenate((cs_true, out_dict["CONCEPTS"]), dim=0)

            progress_bar(i, len(train_loader), epoch, loss.item())
            
        # --------------------------------------
        # ^ 5. Evaluation phase
        # --------------------------------------
        model.eval()
        for b in range(len(model.encoder)): 
            model.encoder[b].eval()

        if debug:
            y_pred = torch.argmax(ys, dim=-1)
            print("Argmax predictions have shape: ", y_pred.shape)

        my_metrics = evaluate_metrics(model, val_loader, args, 
                    support_emb_dict=support_emb_dict,
                    eval_concepts=eval_concepts,)

        loss = my_metrics[0]
        cacc = my_metrics[1]
        yacc = my_metrics[2]
        f1_y = my_metrics[3]
       
        # update at end of the epoch
        if epoch < args.warmup_steps:   w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1_y)
        
        if not args.tuning and cacc > best_f1_cacc:
            print("Saving...")
            # Update best F1 score
            best_f1_cacc = cacc
            epochs_no_improve = 0

            # Save the best model
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with cacc score: {best_f1_cacc}")
            
        elif cacc <= best_f1_cacc:
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    fprint("\n--- End of Training ---\n")

    return best_f1_cacc, support_emb_dict
       

## Run Training

In [18]:
print(f"*** Training model with seed {args.seed}")
print("Chosen device:", model.device)
if not os.path.exists(save_path): os.makedirs(save_path, exist_ok=True)
save_folder = os.path.join(save_path, f"{save_model_name}_{args.seed}.pth")
print("Saving model in folder: ", save_folder)

best_f1_c, support_emb_dict = train(
        model=model,
        proto_datasets=proto_datasets,
        proto_dataloaders=proto_dataloaders,
        train_loader=unsup_train_loader,
        val_loader=unsup_val_loader,
        x_loader=x_loader,
        save_path=save_folder,
        _loss=loss,
        args=args,
        seed=SEED,
)
save_model(model, args, args.seed)  # save the model parameters
print(f"*** Finished training model with seed {args.seed} and best CACC score {best_f1_c}")

print("Training finished.")

*** Training model with seed 1
Chosen device: cuda
Saving model in folder:  ../notebook-outputs/bddoia/my_models/dpl/episodic-proto-net-pipeline-1.0-PROVA/dpl_1.pth

--- Start of Training ---

----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 63.40it/s]
100%|██████████| 100/100 [00:00<00:00, 101.46it/s]
100%|██████████| 100/100 [00:00<00:00, 101.21it/s]
100%|██████████| 100/100 [00:01<00:00, 97.59it/s]
100%|██████████| 100/100 [00:00<00:00, 103.54it/s]
100%|██████████| 100/100 [00:00<00:00, 118.64it/s]
100%|██████████| 100/100 [00:01<00:00, 98.95it/s]
100%|██████████| 100/100 [00:00<00:00, 115.87it/s]
100%|██████████| 100/100 [00:00<00:00, 101.32it/s]
100%|██████████| 100/100 [00:01<00:00, 87.57it/s]
100%|██████████| 100/100 [00:00<00:00, 101.05it/s]
100%|██████████| 100/100 [00:00<00:00, 103.67it/s]
100%|██████████| 100/100 [00:00<00:00, 112.08it/s]
100%|██████████| 100/100 [00:00<00:00, 114.34it/s]
100%|██████████| 100/100 [00:00<00:00, 114.04it/s]
100%|██████████| 100/100 [00:00<00:00, 123.55it/s]
100%|██████████| 100/100 [00:00<00:00, 111.78it/s]
100%|██████████| 100/100 [00:00<00:00, 107.71it/s]
100%|██████████| 100/100 [00:00<00:00, 114.54it/s]
100%|██████████| 100/100 [00:01<00:

               green_lights - Loss 10.6145 | Acc 0.5560
             follow_traffic - Loss 15.5993 | Acc 0.5320
                 road_clear - Loss 13.5651 | Acc 0.5310
             traffic_lights - Loss 11.1642 | Acc 0.5620
              traffic_signs - Loss 5.8869 | Acc 0.8950
                       cars - Loss 12.2961 | Acc 0.6700
                pedestrians - Loss 12.6790 | Acc 0.5660
                     riders - Loss 13.9757 | Acc 0.5560
                     others - Loss 9.6395 | Acc 0.7510
               no_lane_left - Loss 11.6033 | Acc 0.5580
         obstacle_left_lane - Loss 12.4582 | Acc 0.5410
            solid_left_line - Loss 15.5246 | Acc 0.5820
         on_right_turn_lane - Loss 10.4804 | Acc 0.6950
        traffic_light_right - Loss 14.0604 | Acc 0.5470
            front_car_right - Loss 15.2369 | Acc 0.5350
              no_lane_right - Loss 12.9419 | Acc 0.5310
        obstacle_right_lane - Loss 14.3094 | Acc 0.5260
           solid_right_line - Loss 12.6107 | Acc 0

[ 09-04 | 15:37 ] epoch 0: |██████████████████████████████████████████████████| loss: 4.35422659

  ACC C 63.46657458278868   ACC Y 49.18981481481481 F1 Y 46.39600776538742
Saving...
Saved best model with cacc score: 63.46657458278868
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 112.93it/s]
100%|██████████| 100/100 [00:00<00:00, 100.39it/s]
100%|██████████| 100/100 [00:00<00:00, 120.97it/s]
100%|██████████| 100/100 [00:00<00:00, 104.92it/s]
100%|██████████| 100/100 [00:01<00:00, 99.80it/s]
100%|██████████| 100/100 [00:00<00:00, 101.58it/s]
100%|██████████| 100/100 [00:00<00:00, 106.71it/s]
100%|██████████| 100/100 [00:00<00:00, 106.63it/s]
100%|██████████| 100/100 [00:01<00:00, 96.92it/s]
100%|██████████| 100/100 [00:01<00:00, 98.71it/s]
100%|██████████| 100/100 [00:00<00:00, 109.86it/s]
100%|██████████| 100/100 [00:00<00:00, 111.08it/s]
100%|██████████| 100/100 [00:00<00:00, 106.69it/s]
100%|██████████| 100/100 [00:00<00:00, 107.74it/s]
100%|██████████| 100/100 [00:00<00:00, 109.67it/s]
100%|██████████| 100/100 [00:00<00:00, 103.56it/s]
100%|██████████| 100/100 [00:00<00:00, 118.46it/s]
100%|██████████| 100/100 [00:00<00:00, 105.16it/s]
100%|██████████| 100/100 [00:00<00:00, 100.24it/s]
100%|██████████| 100/100 [00:01<00

               green_lights - Loss 1.2327 | Acc 0.5080
             follow_traffic - Loss 0.9592 | Acc 0.5510
                 road_clear - Loss 0.7108 | Acc 0.5330
             traffic_lights - Loss 1.6761 | Acc 0.6060
              traffic_signs - Loss 3.5405 | Acc 0.8850
                       cars - Loss 1.9818 | Acc 0.6240
                pedestrians - Loss 0.7690 | Acc 0.5380
                     riders - Loss 1.1029 | Acc 0.5720
                     others - Loss 1.2769 | Acc 0.7420
               no_lane_left - Loss 0.7531 | Acc 0.5280
         obstacle_left_lane - Loss 1.0211 | Acc 0.5320
            solid_left_line - Loss 1.0349 | Acc 0.5560
         on_right_turn_lane - Loss 1.3093 | Acc 0.6100
        traffic_light_right - Loss 1.5525 | Acc 0.5760
            front_car_right - Loss 1.7040 | Acc 0.5160
              no_lane_right - Loss 1.0729 | Acc 0.5140
        obstacle_right_lane - Loss 1.5471 | Acc 0.5490
           solid_right_line - Loss 1.6266 | Acc 0.4470
          

[ 09-04 | 15:40 ] epoch 1: |██████████████████████████████████████████████████| loss: 2.92285562

  ACC C 56.58206707901425   ACC Y 47.413917824074076 F1 Y 27.59282355646202
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 97.45it/s]
100%|██████████| 100/100 [00:00<00:00, 103.95it/s]
100%|██████████| 100/100 [00:00<00:00, 116.45it/s]
100%|██████████| 100/100 [00:00<00:00, 111.54it/s]
100%|██████████| 100/100 [00:00<00:00, 111.26it/s]
100%|██████████| 100/100 [00:00<00:00, 106.54it/s]
100%|██████████| 100/100 [00:01<00:00, 94.90it/s]
100%|██████████| 100/100 [00:00<00:00, 103.75it/s]
100%|██████████| 100/100 [00:00<00:00, 105.62it/s]
100%|██████████| 100/100 [00:00<00:00, 100.99it/s]
100%|██████████| 100/100 [00:00<00:00, 101.99it/s]
100%|██████████| 100/100 [00:01<00:00, 97.96it/s]
100%|██████████| 100/100 [00:00<00:00, 107.47it/s]
100%|██████████| 100/100 [00:00<00:00, 116.92it/s]
100%|██████████| 100/100 [00:00<00:00, 108.92it/s]
100%|██████████| 100/100 [00:00<00:00, 112.48it/s]
100%|██████████| 100/100 [00:01<00:00, 94.09it/s]
100%|██████████| 100/100 [00:00<00:00, 109.51it/s]
100%|██████████| 100/100 [00:00<00:00, 107.65it/s]
100%|██████████| 100/100 [00:00<00:

               green_lights - Loss 0.8888 | Acc 0.5950
             follow_traffic - Loss 1.3564 | Acc 0.5600
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.8224 | Acc 0.5410
              traffic_signs - Loss 1.1097 | Acc 0.8820
                       cars - Loss 0.5831 | Acc 0.7430
                pedestrians - Loss 0.6933 | Acc 0.5240
                     riders - Loss 0.8184 | Acc 0.5790
                     others - Loss 0.4893 | Acc 0.7590
               no_lane_left - Loss 0.6931 | Acc 0.5010
         obstacle_left_lane - Loss 0.6349 | Acc 0.7090
            solid_left_line - Loss 0.6146 | Acc 0.6970
         on_right_turn_lane - Loss 1.0646 | Acc 0.5580
        traffic_light_right - Loss 0.6794 | Acc 0.5830
            front_car_right - Loss 0.7056 | Acc 0.5160
              no_lane_right - Loss 0.7624 | Acc 0.5460
        obstacle_right_lane - Loss 0.8405 | Acc 0.5400
           solid_right_line - Loss 0.7626 | Acc 0.4470
          

[ 09-04 | 15:43 ] epoch 2: |██████████████████████████████████████████████████| loss: 2.78687882

  ACC C 58.204504185252716   ACC Y 49.576822916666664 F1 Y 34.36952659886376
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 121.37it/s]
100%|██████████| 100/100 [00:00<00:00, 116.16it/s]
100%|██████████| 100/100 [00:00<00:00, 111.45it/s]
100%|██████████| 100/100 [00:00<00:00, 102.46it/s]
100%|██████████| 100/100 [00:01<00:00, 93.06it/s]
100%|██████████| 100/100 [00:01<00:00, 99.03it/s]
100%|██████████| 100/100 [00:00<00:00, 132.41it/s]
100%|██████████| 100/100 [00:00<00:00, 112.88it/s]
100%|██████████| 100/100 [00:00<00:00, 102.18it/s]
100%|██████████| 100/100 [00:00<00:00, 115.84it/s]
100%|██████████| 100/100 [00:00<00:00, 104.05it/s]
100%|██████████| 100/100 [00:01<00:00, 92.09it/s]
100%|██████████| 100/100 [00:00<00:00, 114.97it/s]
100%|██████████| 100/100 [00:00<00:00, 101.52it/s]
100%|██████████| 100/100 [00:00<00:00, 100.88it/s]
100%|██████████| 100/100 [00:00<00:00, 105.41it/s]
100%|██████████| 100/100 [00:00<00:00, 106.96it/s]
100%|██████████| 100/100 [00:01<00:00, 92.98it/s]
100%|██████████| 100/100 [00:00<00:00, 100.19it/s]
100%|██████████| 100/100 [00:00<00:

               green_lights - Loss 0.7335 | Acc 0.5720
             follow_traffic - Loss 0.8327 | Acc 0.5810
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 1.4037 | Acc 0.5340
              traffic_signs - Loss 0.7298 | Acc 0.8470
                       cars - Loss 0.3835 | Acc 0.9020
                pedestrians - Loss 0.6911 | Acc 0.5060
                     riders - Loss 0.7887 | Acc 0.5370
                     others - Loss 0.4235 | Acc 0.7910
               no_lane_left - Loss 0.6931 | Acc 0.5010
         obstacle_left_lane - Loss 0.8234 | Acc 0.6790
            solid_left_line - Loss 0.4460 | Acc 0.8410
         on_right_turn_lane - Loss 0.8089 | Acc 0.5940
        traffic_light_right - Loss 0.6915 | Acc 0.5330
            front_car_right - Loss 1.5813 | Acc 0.5150
              no_lane_right - Loss 0.6947 | Acc 0.4790
        obstacle_right_lane - Loss 0.8020 | Acc 0.5460
           solid_right_line - Loss 0.6981 | Acc 0.4740
          

[ 09-04 | 15:46 ] epoch 3: |██████████████████████████████████████████████████| loss: 2.83612776

  ACC C 57.800100247065224   ACC Y 50.0 F1 Y 35.45214426371678
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 150.27it/s]
100%|██████████| 100/100 [00:00<00:00, 151.03it/s]
100%|██████████| 100/100 [00:00<00:00, 151.57it/s]
100%|██████████| 100/100 [00:00<00:00, 151.67it/s]
100%|██████████| 100/100 [00:00<00:00, 151.63it/s]
100%|██████████| 100/100 [00:00<00:00, 151.88it/s]
100%|██████████| 100/100 [00:00<00:00, 151.44it/s]
100%|██████████| 100/100 [00:00<00:00, 151.52it/s]
100%|██████████| 100/100 [00:00<00:00, 151.64it/s]
100%|██████████| 100/100 [00:00<00:00, 151.63it/s]
100%|██████████| 100/100 [00:00<00:00, 150.99it/s]
100%|██████████| 100/100 [00:00<00:00, 150.75it/s]
100%|██████████| 100/100 [00:01<00:00, 96.58it/s]
100%|██████████| 100/100 [00:01<00:00, 95.16it/s]
100%|██████████| 100/100 [00:01<00:00, 99.44it/s]
100%|██████████| 100/100 [00:00<00:00, 101.73it/s]
100%|██████████| 100/100 [00:00<00:00, 106.66it/s]
100%|██████████| 100/100 [00:00<00:00, 110.07it/s]
100%|██████████| 100/100 [00:00<00:00, 138.35it/s]
100%|██████████| 100/100 [00:00<00

               green_lights - Loss 0.7822 | Acc 0.6000
             follow_traffic - Loss 0.9221 | Acc 0.5540
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.8922 | Acc 0.5230
              traffic_signs - Loss 0.4779 | Acc 0.8730
                       cars - Loss 0.1782 | Acc 0.9430
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.7243 | Acc 0.6450
                     others - Loss 0.3907 | Acc 0.8210
               no_lane_left - Loss 0.6931 | Acc 0.5040
         obstacle_left_lane - Loss 0.9511 | Acc 0.5470
            solid_left_line - Loss 0.1783 | Acc 0.9600
         on_right_turn_lane - Loss 0.7991 | Acc 0.5830
        traffic_light_right - Loss 0.6682 | Acc 0.5870
            front_car_right - Loss 0.7167 | Acc 0.5160
              no_lane_right - Loss 0.6935 | Acc 0.4680
        obstacle_right_lane - Loss 1.2980 | Acc 0.4860
           solid_right_line - Loss 0.6941 | Acc 0.4610
          

[ 09-04 | 15:49 ] epoch 4: |██████████████████████████████████████████████████| loss: 2.78131747

  ACC C 56.960290835963356   ACC Y 54.38006365740741 F1 Y 39.206821733887665
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 113.34it/s]
100%|██████████| 100/100 [00:00<00:00, 123.76it/s]
100%|██████████| 100/100 [00:00<00:00, 123.99it/s]
100%|██████████| 100/100 [00:00<00:00, 126.83it/s]
100%|██████████| 100/100 [00:00<00:00, 125.43it/s]
100%|██████████| 100/100 [00:00<00:00, 124.14it/s]
100%|██████████| 100/100 [00:00<00:00, 125.88it/s]
100%|██████████| 100/100 [00:00<00:00, 126.40it/s]
100%|██████████| 100/100 [00:00<00:00, 125.86it/s]
100%|██████████| 100/100 [00:00<00:00, 126.03it/s]
100%|██████████| 100/100 [00:00<00:00, 101.57it/s]
100%|██████████| 100/100 [00:01<00:00, 92.16it/s]
100%|██████████| 100/100 [00:00<00:00, 106.50it/s]
100%|██████████| 100/100 [00:00<00:00, 105.21it/s]
100%|██████████| 100/100 [00:01<00:00, 96.70it/s]
100%|██████████| 100/100 [00:01<00:00, 88.35it/s]
100%|██████████| 100/100 [00:01<00:00, 90.60it/s]
100%|██████████| 100/100 [00:01<00:00, 89.87it/s]
100%|██████████| 100/100 [00:01<00:00, 90.94it/s]
100%|██████████| 100/100 [00:01<00:00

               green_lights - Loss 0.7483 | Acc 0.5390
             follow_traffic - Loss 0.7541 | Acc 0.5950
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.8107 | Acc 0.5370
              traffic_signs - Loss 0.4060 | Acc 0.8900
                       cars - Loss 0.0282 | Acc 0.9930
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.5745 | Acc 0.7460
                     others - Loss 0.1438 | Acc 0.9510
               no_lane_left - Loss 0.6931 | Acc 0.5000
         obstacle_left_lane - Loss 0.6694 | Acc 0.6470
            solid_left_line - Loss 0.1136 | Acc 0.9610
         on_right_turn_lane - Loss 0.8741 | Acc 0.6360
        traffic_light_right - Loss 0.9718 | Acc 0.6030
            front_car_right - Loss 0.6931 | Acc 0.5210
              no_lane_right - Loss 0.6931 | Acc 0.4980
        obstacle_right_lane - Loss 0.7214 | Acc 0.4760
           solid_right_line - Loss 0.6932 | Acc 0.4500
          

[ 09-04 | 15:52 ] epoch 5: |██████████████████████████████████████████████████| loss: 2.77803755

  ACC C 61.200673547055985   ACC Y 55.045572916666664 F1 Y 37.8323209540063
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 117.86it/s]
100%|██████████| 100/100 [00:00<00:00, 108.35it/s]
100%|██████████| 100/100 [00:00<00:00, 108.69it/s]
100%|██████████| 100/100 [00:00<00:00, 104.69it/s]
100%|██████████| 100/100 [00:01<00:00, 94.32it/s]
100%|██████████| 100/100 [00:00<00:00, 100.14it/s]
100%|██████████| 100/100 [00:01<00:00, 99.45it/s]
100%|██████████| 100/100 [00:01<00:00, 86.47it/s]
100%|██████████| 100/100 [00:01<00:00, 85.30it/s]
100%|██████████| 100/100 [00:01<00:00, 85.99it/s]
100%|██████████| 100/100 [00:01<00:00, 91.77it/s]
100%|██████████| 100/100 [00:00<00:00, 105.36it/s]
100%|██████████| 100/100 [00:00<00:00, 105.07it/s]
100%|██████████| 100/100 [00:00<00:00, 105.22it/s]
100%|██████████| 100/100 [00:00<00:00, 104.89it/s]
100%|██████████| 100/100 [00:00<00:00, 103.52it/s]
100%|██████████| 100/100 [00:00<00:00, 132.07it/s]
100%|██████████| 100/100 [00:00<00:00, 151.55it/s]
100%|██████████| 100/100 [00:00<00:00, 151.68it/s]
100%|██████████| 100/100 [00:00<00:00

               green_lights - Loss 0.7253 | Acc 0.5830
             follow_traffic - Loss 0.6917 | Acc 0.5810
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.7284 | Acc 0.5560
              traffic_signs - Loss 0.2473 | Acc 0.9320
                       cars - Loss 0.0300 | Acc 0.9890
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.4228 | Acc 0.8160
                     others - Loss 0.0401 | Acc 0.9890
               no_lane_left - Loss 0.6931 | Acc 0.5040
         obstacle_left_lane - Loss 0.4477 | Acc 0.8770
            solid_left_line - Loss 0.0406 | Acc 0.9840
         on_right_turn_lane - Loss 0.8156 | Acc 0.6400
        traffic_light_right - Loss 0.6933 | Acc 0.4740
            front_car_right - Loss 0.6847 | Acc 0.5270
              no_lane_right - Loss 0.6931 | Acc 0.4730
        obstacle_right_lane - Loss 0.7304 | Acc 0.4770
           solid_right_line - Loss 0.6931 | Acc 0.4470
          

[ 09-04 | 15:55 ] epoch 6: |██████████████████████████████████████████████████| loss: 2.67471957

  ACC C 61.90131836467319   ACC Y 55.812355324074076 F1 Y 39.29279148315151
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 84.85it/s]
100%|██████████| 100/100 [00:00<00:00, 100.67it/s]
100%|██████████| 100/100 [00:01<00:00, 88.91it/s]
100%|██████████| 100/100 [00:01<00:00, 99.95it/s]
100%|██████████| 100/100 [00:00<00:00, 103.68it/s]
100%|██████████| 100/100 [00:00<00:00, 102.78it/s]
100%|██████████| 100/100 [00:00<00:00, 102.41it/s]
100%|██████████| 100/100 [00:01<00:00, 96.26it/s]
100%|██████████| 100/100 [00:01<00:00, 89.16it/s]
100%|██████████| 100/100 [00:00<00:00, 102.24it/s]
100%|██████████| 100/100 [00:00<00:00, 103.66it/s]
100%|██████████| 100/100 [00:00<00:00, 101.80it/s]
100%|██████████| 100/100 [00:00<00:00, 100.18it/s]
100%|██████████| 100/100 [00:01<00:00, 99.40it/s]
100%|██████████| 100/100 [00:01<00:00, 97.36it/s]
100%|██████████| 100/100 [00:01<00:00, 82.87it/s]
100%|██████████| 100/100 [00:01<00:00, 98.80it/s]
100%|██████████| 100/100 [00:01<00:00, 99.66it/s]
100%|██████████| 100/100 [00:01<00:00, 98.62it/s]
100%|██████████| 100/100 [00:00<00:00, 115

               green_lights - Loss 0.7058 | Acc 0.6110
             follow_traffic - Loss 0.6905 | Acc 0.5710
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.7830 | Acc 0.5430
              traffic_signs - Loss 0.1283 | Acc 0.9630
                       cars - Loss 0.0665 | Acc 0.9890
                pedestrians - Loss 0.6924 | Acc 0.5310
                     riders - Loss 0.2765 | Acc 0.9070
                     others - Loss 0.0500 | Acc 0.9790
               no_lane_left - Loss 0.6931 | Acc 0.5020
         obstacle_left_lane - Loss 0.0567 | Acc 0.9810
            solid_left_line - Loss 0.0026 | Acc 1.0000
         on_right_turn_lane - Loss 0.7370 | Acc 0.5570
        traffic_light_right - Loss 0.6932 | Acc 0.4740
            front_car_right - Loss 0.6923 | Acc 0.4970
              no_lane_right - Loss 0.6931 | Acc 0.4900
        obstacle_right_lane - Loss 0.6953 | Acc 0.4670
           solid_right_line - Loss 0.6931 | Acc 0.4430
          

[ 09-04 | 15:58 ] epoch 7: |██████████████████████████████████████████████████| loss: 3.00546217

  ACC C 57.209684782558014   ACC Y 55.67491319444444 F1 Y 39.8304512379514
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 99.94it/s]
100%|██████████| 100/100 [00:00<00:00, 107.37it/s]
100%|██████████| 100/100 [00:01<00:00, 90.06it/s]
100%|██████████| 100/100 [00:01<00:00, 85.81it/s]
100%|██████████| 100/100 [00:01<00:00, 96.53it/s]
100%|██████████| 100/100 [00:00<00:00, 102.99it/s]
100%|██████████| 100/100 [00:01<00:00, 89.54it/s]
100%|██████████| 100/100 [00:01<00:00, 97.69it/s]
100%|██████████| 100/100 [00:01<00:00, 92.31it/s]
100%|██████████| 100/100 [00:00<00:00, 100.75it/s]
100%|██████████| 100/100 [00:00<00:00, 114.57it/s]
100%|██████████| 100/100 [00:01<00:00, 86.34it/s]
100%|██████████| 100/100 [00:01<00:00, 92.99it/s]
100%|██████████| 100/100 [00:00<00:00, 106.80it/s]
100%|██████████| 100/100 [00:00<00:00, 104.15it/s]
100%|██████████| 100/100 [00:00<00:00, 109.86it/s]
100%|██████████| 100/100 [00:00<00:00, 108.69it/s]
100%|██████████| 100/100 [00:00<00:00, 111.61it/s]
100%|██████████| 100/100 [00:00<00:00, 111.49it/s]
100%|██████████| 100/100 [00:00<00:00, 1

               green_lights - Loss 0.6380 | Acc 0.6430
             follow_traffic - Loss 0.6909 | Acc 0.5700
                 road_clear - Loss 0.6931 | Acc 0.5000
             traffic_lights - Loss 0.7377 | Acc 0.5200
              traffic_signs - Loss 0.1563 | Acc 0.9630
                       cars - Loss 0.0574 | Acc 0.9830
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.1416 | Acc 0.9460
                     others - Loss 0.0617 | Acc 0.9880
               no_lane_left - Loss 0.6931 | Acc 0.5050
         obstacle_left_lane - Loss 0.0518 | Acc 0.9840
            solid_left_line - Loss 0.1310 | Acc 0.9640
         on_right_turn_lane - Loss 0.7680 | Acc 0.5560
        traffic_light_right - Loss 0.6932 | Acc 0.4670
            front_car_right - Loss 0.6889 | Acc 0.5180
              no_lane_right - Loss 0.6931 | Acc 0.4810
        obstacle_right_lane - Loss 0.6932 | Acc 0.4680
           solid_right_line - Loss 0.6931 | Acc 0.4960
          

[ 09-04 | 16:01 ] epoch 8: |██████████████████████████████████████████████████| loss: 2.77759433

  ACC C 56.17972993188434   ACC Y 56.95891203703704 F1 Y 40.99537510223115
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 88.67it/s]
100%|██████████| 100/100 [00:01<00:00, 99.44it/s]
100%|██████████| 100/100 [00:00<00:00, 113.56it/s]
100%|██████████| 100/100 [00:00<00:00, 108.84it/s]
100%|██████████| 100/100 [00:00<00:00, 105.50it/s]
100%|██████████| 100/100 [00:01<00:00, 96.52it/s]
100%|██████████| 100/100 [00:00<00:00, 108.11it/s]
100%|██████████| 100/100 [00:00<00:00, 107.82it/s]
100%|██████████| 100/100 [00:00<00:00, 115.34it/s]
100%|██████████| 100/100 [00:00<00:00, 114.82it/s]
100%|██████████| 100/100 [00:00<00:00, 108.11it/s]
100%|██████████| 100/100 [00:00<00:00, 122.20it/s]
100%|██████████| 100/100 [00:00<00:00, 107.72it/s]
100%|██████████| 100/100 [00:00<00:00, 102.62it/s]
100%|██████████| 100/100 [00:00<00:00, 113.59it/s]
100%|██████████| 100/100 [00:00<00:00, 109.40it/s]
100%|██████████| 100/100 [00:00<00:00, 107.60it/s]
100%|██████████| 100/100 [00:00<00:00, 102.06it/s]
100%|██████████| 100/100 [00:01<00:00, 99.54it/s]
100%|██████████| 100/100 [00:01<00:

               green_lights - Loss 0.6399 | Acc 0.6330
             follow_traffic - Loss 0.6928 | Acc 0.5590
                 road_clear - Loss 0.6931 | Acc 0.5010
             traffic_lights - Loss 0.6927 | Acc 0.6220
              traffic_signs - Loss 0.0345 | Acc 0.9930
                       cars - Loss 0.1798 | Acc 0.9490
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0575 | Acc 0.9890
                     others - Loss 0.0384 | Acc 0.9940
               no_lane_left - Loss 0.6931 | Acc 0.5070
         obstacle_left_lane - Loss 0.0191 | Acc 0.9940
            solid_left_line - Loss 0.0719 | Acc 0.9810
         on_right_turn_lane - Loss 0.7221 | Acc 0.5930
        traffic_light_right - Loss 0.6932 | Acc 0.4570
            front_car_right - Loss 0.7237 | Acc 0.5180
              no_lane_right - Loss 0.6931 | Acc 0.5040
        obstacle_right_lane - Loss 0.6931 | Acc 0.4920
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:04 ] epoch 9: |██████████████████████████████████████████████████| loss: 2.79806471

  ACC C 57.172482212384544   ACC Y 55.660445601851855 F1 Y 37.99096508778635
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 105.49it/s]
100%|██████████| 100/100 [00:00<00:00, 101.69it/s]
100%|██████████| 100/100 [00:00<00:00, 112.82it/s]
100%|██████████| 100/100 [00:00<00:00, 114.88it/s]
100%|██████████| 100/100 [00:01<00:00, 88.33it/s]
100%|██████████| 100/100 [00:01<00:00, 96.89it/s]
100%|██████████| 100/100 [00:00<00:00, 111.19it/s]
100%|██████████| 100/100 [00:01<00:00, 96.30it/s]
100%|██████████| 100/100 [00:01<00:00, 96.23it/s]
100%|██████████| 100/100 [00:00<00:00, 100.71it/s]
100%|██████████| 100/100 [00:00<00:00, 102.89it/s]
100%|██████████| 100/100 [00:00<00:00, 108.46it/s]
100%|██████████| 100/100 [00:00<00:00, 116.38it/s]
100%|██████████| 100/100 [00:00<00:00, 100.82it/s]
100%|██████████| 100/100 [00:00<00:00, 107.94it/s]
100%|██████████| 100/100 [00:00<00:00, 111.66it/s]
100%|██████████| 100/100 [00:00<00:00, 104.48it/s]
100%|██████████| 100/100 [00:00<00:00, 108.24it/s]
100%|██████████| 100/100 [00:01<00:00, 91.59it/s]
100%|██████████| 100/100 [00:00<00:0

               green_lights - Loss 0.7110 | Acc 0.5990
             follow_traffic - Loss 0.6930 | Acc 0.5550
                 road_clear - Loss 0.6931 | Acc 0.4960
             traffic_lights - Loss 0.6161 | Acc 0.6840
              traffic_signs - Loss 0.0350 | Acc 0.9900
                       cars - Loss 0.0008 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0388 | Acc 0.9870
                     others - Loss 0.0002 | Acc 1.0000
               no_lane_left - Loss 0.6931 | Acc 0.5070
         obstacle_left_lane - Loss 0.0065 | Acc 0.9990
            solid_left_line - Loss 0.0049 | Acc 1.0000
         on_right_turn_lane - Loss 0.7980 | Acc 0.6430
        traffic_light_right - Loss 0.6931 | Acc 0.4630
            front_car_right - Loss 1.3939 | Acc 0.5240
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5010
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:07 ] epoch 10: |██████████████████████████████████████████████████| loss: 2.71448493

  ACC C 57.1153011586931   ACC Y 58.06206597222222 F1 Y 41.55334004354282
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:01<00:00, 94.27it/s]
100%|██████████| 100/100 [00:01<00:00, 92.03it/s]
100%|██████████| 100/100 [00:01<00:00, 95.85it/s]
100%|██████████| 100/100 [00:00<00:00, 106.07it/s]
100%|██████████| 100/100 [00:00<00:00, 107.45it/s]
100%|██████████| 100/100 [00:00<00:00, 109.20it/s]
100%|██████████| 100/100 [00:00<00:00, 109.69it/s]
100%|██████████| 100/100 [00:00<00:00, 101.09it/s]
100%|██████████| 100/100 [00:00<00:00, 108.34it/s]
100%|██████████| 100/100 [00:00<00:00, 109.25it/s]
100%|██████████| 100/100 [00:00<00:00, 109.61it/s]
100%|██████████| 100/100 [00:00<00:00, 110.44it/s]
100%|██████████| 100/100 [00:00<00:00, 110.53it/s]
100%|██████████| 100/100 [00:00<00:00, 109.51it/s]
100%|██████████| 100/100 [00:00<00:00, 108.50it/s]
100%|██████████| 100/100 [00:00<00:00, 110.17it/s]
100%|██████████| 100/100 [00:00<00:00, 105.53it/s]
100%|██████████| 100/100 [00:00<00:00, 109.06it/s]
100%|██████████| 100/100 [00:00<00:00, 108.92it/s]
100%|██████████| 100/100 [00:00<00

               green_lights - Loss 0.6815 | Acc 0.5840
             follow_traffic - Loss 0.6931 | Acc 0.5520
                 road_clear - Loss 0.6931 | Acc 0.4990
             traffic_lights - Loss 0.3670 | Acc 0.8630
              traffic_signs - Loss 0.0382 | Acc 0.9920
                       cars - Loss 0.0004 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0203 | Acc 0.9960
                     others - Loss 0.0048 | Acc 0.9980
               no_lane_left - Loss 0.6931 | Acc 0.5020
         obstacle_left_lane - Loss 0.0036 | Acc 1.0000
            solid_left_line - Loss 0.0006 | Acc 1.0000
         on_right_turn_lane - Loss 0.9110 | Acc 0.6250
        traffic_light_right - Loss 0.6931 | Acc 0.4950
            front_car_right - Loss 1.2246 | Acc 0.5460
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5030
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:10 ] epoch 11: |██████████████████████████████████████████████████| loss: 2.72844195

  ACC C 54.960318737559845   ACC Y 57.51229745370371 F1 Y 41.27989070145826
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 101.39it/s]
100%|██████████| 100/100 [00:00<00:00, 102.31it/s]
100%|██████████| 100/100 [00:00<00:00, 103.24it/s]
100%|██████████| 100/100 [00:00<00:00, 102.47it/s]
100%|██████████| 100/100 [00:00<00:00, 101.76it/s]
100%|██████████| 100/100 [00:00<00:00, 102.05it/s]
100%|██████████| 100/100 [00:01<00:00, 99.97it/s]
100%|██████████| 100/100 [00:00<00:00, 100.69it/s]
100%|██████████| 100/100 [00:01<00:00, 98.66it/s]
100%|██████████| 100/100 [00:00<00:00, 100.95it/s]
100%|██████████| 100/100 [00:00<00:00, 100.93it/s]
100%|██████████| 100/100 [00:00<00:00, 101.43it/s]
100%|██████████| 100/100 [00:01<00:00, 99.92it/s]
100%|██████████| 100/100 [00:01<00:00, 98.90it/s]
100%|██████████| 100/100 [00:00<00:00, 101.23it/s]
100%|██████████| 100/100 [00:01<00:00, 96.42it/s]
100%|██████████| 100/100 [00:01<00:00, 97.32it/s]
100%|██████████| 100/100 [00:01<00:00, 99.70it/s]
100%|██████████| 100/100 [00:00<00:00, 100.24it/s]
100%|██████████| 100/100 [00:00<00:00,

               green_lights - Loss 0.6716 | Acc 0.6210
             follow_traffic - Loss 0.6931 | Acc 0.5370
                 road_clear - Loss 0.6931 | Acc 0.4980
             traffic_lights - Loss 0.3239 | Acc 0.8960
              traffic_signs - Loss 0.0378 | Acc 0.9880
                       cars - Loss 0.0003 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0065 | Acc 0.9980
                     others - Loss 0.0003 | Acc 1.0000
               no_lane_left - Loss 0.6931 | Acc 0.5080
         obstacle_left_lane - Loss 0.0015 | Acc 1.0000
            solid_left_line - Loss 0.0169 | Acc 0.9980
         on_right_turn_lane - Loss 0.8031 | Acc 0.6390
        traffic_light_right - Loss 0.6931 | Acc 0.4810
            front_car_right - Loss 1.0649 | Acc 0.5110
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5000
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:13 ] epoch 12: |██████████████████████████████████████████████████| loss: 2.78028512

  ACC C 55.95307101806005   ACC Y 58.22120949074073 F1 Y 42.25888128820598
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 119.48it/s]
100%|██████████| 100/100 [00:00<00:00, 116.69it/s]
100%|██████████| 100/100 [00:00<00:00, 119.92it/s]
100%|██████████| 100/100 [00:01<00:00, 88.06it/s]
100%|██████████| 100/100 [00:00<00:00, 102.31it/s]
100%|██████████| 100/100 [00:00<00:00, 105.04it/s]
100%|██████████| 100/100 [00:01<00:00, 96.68it/s]
100%|██████████| 100/100 [00:01<00:00, 86.93it/s]
100%|██████████| 100/100 [00:00<00:00, 115.56it/s]
100%|██████████| 100/100 [00:00<00:00, 115.53it/s]
100%|██████████| 100/100 [00:00<00:00, 115.63it/s]
100%|██████████| 100/100 [00:00<00:00, 115.46it/s]
100%|██████████| 100/100 [00:00<00:00, 115.35it/s]
100%|██████████| 100/100 [00:00<00:00, 115.58it/s]
100%|██████████| 100/100 [00:00<00:00, 115.56it/s]
100%|██████████| 100/100 [00:00<00:00, 115.83it/s]
100%|██████████| 100/100 [00:00<00:00, 115.65it/s]
100%|██████████| 100/100 [00:00<00:00, 115.55it/s]
100%|██████████| 100/100 [00:00<00:00, 115.29it/s]
100%|██████████| 100/100 [00:00<00

               green_lights - Loss 0.7869 | Acc 0.6430
             follow_traffic - Loss 0.6925 | Acc 0.6000
                 road_clear - Loss 0.6931 | Acc 0.5030
             traffic_lights - Loss 0.1009 | Acc 0.9650
              traffic_signs - Loss 0.0296 | Acc 0.9910
                       cars - Loss 0.0001 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0145 | Acc 0.9960
                     others - Loss 0.0002 | Acc 1.0000
               no_lane_left - Loss 0.6931 | Acc 0.5320
         obstacle_left_lane - Loss 0.0014 | Acc 1.0000
            solid_left_line - Loss 0.0015 | Acc 0.9990
         on_right_turn_lane - Loss 0.7690 | Acc 0.6450
        traffic_light_right - Loss 0.6931 | Acc 0.4910
            front_car_right - Loss 0.9659 | Acc 0.5250
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5000
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:16 ] epoch 13: |██████████████████████████████████████████████████| loss: 2.98752889

  ACC C 55.375745064682434   ACC Y 58.83608217592593 F1 Y 42.63313744682659
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 113.39it/s]
100%|██████████| 100/100 [00:00<00:00, 104.12it/s]
100%|██████████| 100/100 [00:00<00:00, 103.57it/s]
100%|██████████| 100/100 [00:00<00:00, 104.07it/s]
100%|██████████| 100/100 [00:00<00:00, 103.68it/s]
100%|██████████| 100/100 [00:00<00:00, 102.89it/s]
100%|██████████| 100/100 [00:00<00:00, 103.56it/s]
100%|██████████| 100/100 [00:00<00:00, 104.51it/s]
100%|██████████| 100/100 [00:00<00:00, 103.05it/s]
100%|██████████| 100/100 [00:00<00:00, 100.02it/s]
100%|██████████| 100/100 [00:00<00:00, 101.99it/s]
100%|██████████| 100/100 [00:00<00:00, 103.11it/s]
100%|██████████| 100/100 [00:00<00:00, 101.96it/s]
100%|██████████| 100/100 [00:00<00:00, 101.95it/s]
100%|██████████| 100/100 [00:00<00:00, 104.42it/s]
100%|██████████| 100/100 [00:00<00:00, 102.55it/s]
100%|██████████| 100/100 [00:00<00:00, 102.75it/s]
100%|██████████| 100/100 [00:00<00:00, 102.60it/s]
100%|██████████| 100/100 [00:00<00:00, 103.03it/s]
100%|██████████| 100/100 [00:01

               green_lights - Loss 0.6604 | Acc 0.6660
             follow_traffic - Loss 0.6926 | Acc 0.5660
                 road_clear - Loss 0.6931 | Acc 0.4950
             traffic_lights - Loss 0.0584 | Acc 0.9790
              traffic_signs - Loss 0.0238 | Acc 0.9930
                       cars - Loss 0.0008 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0026 | Acc 1.0000
                     others - Loss 0.0004 | Acc 1.0000
               no_lane_left - Loss 0.6931 | Acc 0.5570
         obstacle_left_lane - Loss 0.0354 | Acc 0.9860
            solid_left_line - Loss 0.0010 | Acc 1.0000
         on_right_turn_lane - Loss 0.8264 | Acc 0.6200
        traffic_light_right - Loss 0.6931 | Acc 0.5270
            front_car_right - Loss 1.4503 | Acc 0.5270
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5000
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:18 ] epoch 14: |██████████████████████████████████████████████████| loss: 2.69927359

  ACC C 56.007497012615204   ACC Y 59.75477430555556 F1 Y 44.576167305103795
----------------------------------
--- Prototypical Networks Training ---
Prototypical Networks Training Epoch 1/1


100%|██████████| 100/100 [00:00<00:00, 112.92it/s]
100%|██████████| 100/100 [00:00<00:00, 106.59it/s]
100%|██████████| 100/100 [00:01<00:00, 94.29it/s]
100%|██████████| 100/100 [00:01<00:00, 91.28it/s]
100%|██████████| 100/100 [00:01<00:00, 98.03it/s]
100%|██████████| 100/100 [00:01<00:00, 98.21it/s]
100%|██████████| 100/100 [00:01<00:00, 99.88it/s]
100%|██████████| 100/100 [00:01<00:00, 91.20it/s]
100%|██████████| 100/100 [00:00<00:00, 104.19it/s]
100%|██████████| 100/100 [00:00<00:00, 104.98it/s]
100%|██████████| 100/100 [00:00<00:00, 104.96it/s]
100%|██████████| 100/100 [00:00<00:00, 102.78it/s]
100%|██████████| 100/100 [00:00<00:00, 102.08it/s]
100%|██████████| 100/100 [00:01<00:00, 91.36it/s]
100%|██████████| 100/100 [00:01<00:00, 88.11it/s]
100%|██████████| 100/100 [00:01<00:00, 85.12it/s]
100%|██████████| 100/100 [00:01<00:00, 96.73it/s]
100%|██████████| 100/100 [00:00<00:00, 103.16it/s]
100%|██████████| 100/100 [00:00<00:00, 103.29it/s]
100%|██████████| 100/100 [00:00<00:00, 10

               green_lights - Loss 0.6838 | Acc 0.6400
             follow_traffic - Loss 0.6927 | Acc 0.6150
                 road_clear - Loss 0.6931 | Acc 0.4910
             traffic_lights - Loss 0.0462 | Acc 0.9890
              traffic_signs - Loss 0.0151 | Acc 0.9960
                       cars - Loss 0.0005 | Acc 1.0000
                pedestrians - Loss 0.6931 | Acc 0.5000
                     riders - Loss 0.0031 | Acc 1.0000
                     others - Loss 0.0005 | Acc 1.0000
               no_lane_left - Loss 0.6931 | Acc 0.5610
         obstacle_left_lane - Loss 0.0006 | Acc 1.0000
            solid_left_line - Loss 0.0010 | Acc 1.0000
         on_right_turn_lane - Loss 0.8034 | Acc 0.6480
        traffic_light_right - Loss 0.6931 | Acc 0.5530
            front_car_right - Loss 0.7163 | Acc 0.4850
              no_lane_right - Loss 0.6931 | Acc 0.5000
        obstacle_right_lane - Loss 0.6931 | Acc 0.5000
           solid_right_line - Loss 0.6931 | Acc 0.5000
          

[ 09-04 | 16:21 ] epoch 15: |██████████████████████████████████████████████████| loss: 2.75822139

  ACC C 62.414573629697166   ACC Y 59.450954861111114 F1 Y 44.30103823512157
Early stopping triggered after 16 epochs.

--- End of Training ---

*** Finished training model with seed 1 and best CACC score 63.46657458278868
Training finished.


# EVALUATION

In [19]:
# * Evaluate the model and save metrics
def evaluate_my_model(model: MnistDPL, 
        save_path: str, 
        test_loader: DataLoader,
        eval_concepts: List[str],
        args: Namespace,
        support_embeddings=None,
    ):
    
    if args.model == 'probddoiadpl':
        assert support_embeddings is not None, "Support embeddings must be provided for probddoiadpl model evaluation."
        my_metrics = evaluate_metrics(model, test_loader, args,
                        support_emb_dict=support_embeddings, 
                        eval_concepts=eval_concepts
        )
    else:
        my_metrics = evaluate_metrics(model, test_loader, args, 
                        eval_concepts=eval_concepts
        )

    loss = my_metrics[0]
    cacc = my_metrics[1]
    yacc = my_metrics[2]
    f1_y = my_metrics[3]
    f1_micro = my_metrics[4]
    f1_weight = my_metrics[5]
    f1_bin = my_metrics[6]

    metrics_log_path = save_path.replace(".pth", "_metrics.log")
    
    all_concepts = [ 'Green Traffic Light', 'Follow Traffic', 'Road Is Clear',
        'Red Traffic Light', 'Traffic Sign', 'Obstacle Car', 'Obstacle Pedestrian', 'Obstacle Rider', 'Obstacle Others',
        'No Lane On The Left',  'Obstacle On The Left Lane',  'Solid Left Line',
                'On The Right Turn Lane', 'Traffic Light Allows Right', 'Front Car Turning Right', 
        'No Lane On The Right', 'Obstacle On The Right Lane', 'Solid Right Line',
                'On The Left Turn Lane',  'Traffic Light Allows Left',  'Front Car Turning Left' 
    ]
    aggregated_metrics = [
            'F1 - Binary', 'F1 - Macro', 'F1 - Micro', 'F1 - Weighted',
            'Precision - Binary', 'Precision - Macro', 'Precision - Micro', 'Precision - Weighted',
            'Recall - Binary', 'Recall - Macro', 'Recall - Micro', 'Recall - Weighted',
            'Balanced Accuracy'
    ]

    sums = [0.0] * len(aggregated_metrics)
    num_concepts = len(all_concepts)
    with open(metrics_log_path, "a") as log_file:
        log_file.write(f"ACC C: {cacc}, ACC Y: {yacc}\n\n")
        log_file.write(f"F1 Y - Macro: {f1_y}, F1 Y - Micro: {f1_micro}, F1 Y - Weighted: {f1_weight}, F1 Y - Binary: {f1_bin} \n\n")

        def write_metrics(class_name, offset):
            print(f"Reporting Metrics for {class_name} in {metrics_log_path}")
            log_file.write(f"{class_name.upper()}\n")
            for idx, metric_name in enumerate(aggregated_metrics):
                value = my_metrics[offset + idx]
                sums[idx] += value
                log_file.write(f"  {metric_name:<18} {value:.4f}\n")
            log_file.write("\n")

        i = 7
        for concept in all_concepts:
            write_metrics(concept, i)
            i += len(aggregated_metrics)

        log_file.write("**MEAN ACROSS ALL CONCEPTS**\n")
        for idx, metric_name in enumerate(aggregated_metrics):
            mean_value = sums[idx] / num_concepts
            log_file.write(f"  {metric_name:<18} {mean_value:.4f}\n")
        log_file.write("\n")


    assert len(my_metrics) == 7 + len(all_concepts) * len(aggregated_metrics), \
        f"Expected {7 + len(all_concepts) * len(aggregated_metrics)} metrics, but got {len(my_metrics)}"
    
    if args.model == 'probddoiadpl':
        y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = (
            evaluate_metrics(model, test_loader, args,
                        support_emb_dict=support_embeddings, 
                        eval_concepts=eval_concepts,
                        last=True
                )
        )
    else:
        y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = (
            evaluate_metrics(model, test_loader, args,
                        eval_concepts=eval_concepts,
                        last=True
                )
        )
    
    y_labels = ["stop", "forward", "left", "right"]
    concept_labels = [
        "green_light",      
        "follow",           
        "road_clear",       
        "red_light",        
        "traffic_sign",     
        "car",              
        "person",           
        "rider",            
        "other_obstacle",   
        "left_lane",
        "left_green_light",
        "left_follow",
        "no_left_lane",
        "left_obstacle",
        "letf_solid_line",
        "right_lane",
        "right_green_light",
        "right_follow",
        "no_right_lane",
        "right_obstacle",
        "right_solid_line",
    ]

    plot_multilabel_confusion_matrix(y_true, y_pred, y_labels, "Labels", save_path=save_path)
    cfs = plot_actions_confusion_matrix(c_true, c_pred, "Concepts", save_path=save_path)
    cf = plot_multilabel_confusion_matrix(c_true, c_pred, concept_labels, "Concepts", save_path=save_path)
    print("Concept collapse", 1 - compute_coverage(cf))

    with open(metrics_log_path, "a") as log_file:
        for key, value in cfs.items():
            log_file.write(f"Concept collapse: {key}, {1 - compute_coverage(value):.4f}\n")
            log_file.write("\n")

    fprint("\n--- End of Evaluation ---\n")

## Run Evaluation

In [20]:
# Initialize the model object
model = get_model(args, encoder, decoder, n_images, c_split)

# Load the model state dictionary into the model object
model_state_dict = torch.load(save_folder)
model.load_state_dict(model_state_dict)

# Evaluate the model
evaluate_my_model(
    model=model, 
    save_path=save_folder, 
    test_loader=unsup_test_loader,
    eval_concepts=['green_lights', 'follow_traffic', 'road_clear',
        'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
        'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
        'on_right_turn_lane', 'traffic_light_right', 'front_car_right',
        'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
        'on_left_turn_lane', 'traffic_light_left', 'front_car_left'],
    args=args,
    support_embeddings=support_emb_dict
)

Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 'probddoiadpl', 'mnistpcbmsl', 'mnistpcbmltn', 'kanddplsingledisj', 'mnistsl', 'kandslsinglejoint', 'mnistdplrec', 'cvae', 'cext', 'mnmathnn', 'promnistsl']
Reporting Metrics for Green Traffic Light in ../notebook-outputs/bddoia/my_models/dpl/episodic-proto-net-pipeline-1.0-PROVA/dpl_1_metrics.log
Reporting Metrics for Follow Traffic in ../noteb