# IMPORTS

In [None]:
import sys
import os

# Expose only GPU index 2 to CUDA; this is set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Make the parent directories importable (appended in this order, nearest first).
sys.path.extend(
    os.path.abspath(relative_dir)
    for relative_dir in (
        "..",        # for 'protonet_STOP_bddoia_modules' folder
        "../..",     # for 'data' folder
        "../../..",  # for 'models' and 'datasets' folders
    )
)

print(sys.path)

['/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia/notebooks', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python38.zip', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/lib-dynload', '', '/users-1/eleonora/.local/lib/python3.8/site-packages', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/site-packages', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation', '/users-1/eleonora/reasoning-shortcuts/IXShort']


In [2]:
import csv
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import datetime
import numpy as np
import setproctitle, socket, uuid

from models import get_model
from models.mnistdpl import MnistDPL
from datasets import get_dataset

from argparse import Namespace
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix

from utils import fprint
from utils.status import progress_bar
from utils.metrics import evaluate_metrics
from utils.dpl_loss import ADDMNIST_DPL
from utils.checkpoint import save_model
from torch.utils.data import Dataset, DataLoader

from warmup_scheduler import GradualWarmupScheduler
from baseline_modules.arguments import args_dpl 
from backbones.bddoia_protonet import ProtoNetConv1D, PrototypicalLoss
from protonet_STOP_bddoia_modules.proto_modules.proto_helpers import assert_inputs

# SETUP

In [3]:
# Random seed for this run; it is copied into `args.seed` by the following setup cell.
SEED = 0
# NOTE(review): not referenced by any cell visible here — presumably the fraction of
# unsupervised data to use; confirm it is still needed.
UNS_PERCENTAGE = 1.0

In [4]:
# Start from the shared DPL baseline arguments and pin the notebook-level seed.
args = args_dpl
args.seed = SEED

# logging: tag the run with a unique job id, a timestamp and the host name
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# set job name: "<model>_<buffer_size or 0>_<dataset>"
setproctitle.setproctitle(
    f"{args.model}_{args.buffer_size if 'buffer_size' in args else 0}_{args.dataset}"
)

# saving: ../NEW-outputs/<save_folder>/baseline/<save_model_name>/baseline-disj
save_folder = "bddoia"
save_model_name = "dpl"
save_path = os.path.join(
    "..", "NEW-outputs", save_folder, "baseline", save_model_name, "baseline-disj"
)
save_paths = [save_path]

print(f"Seed: {args.seed}")
print(f"Save paths: {save_paths}")

Seed: 0
Save paths: ['../NEW-outputs/bddoia/baseline/dpl/baseline-disj']


# UTILS

## Test Set Evaluation (alternative to full-fledged notebook eval)

In [5]:

# * helper function for the confusion-matrix plotting routines
def convert_to_categories(elements):
    """Collapse each row of a 2-D array of 0/1 integers into one category id.

    Every row is read as a binary number whose most significant bit is the
    first column, e.g. ``[1, 0, 1] -> 5``.

    Args:
        elements: 2-D integer array (rows are samples) containing only 0s and 1s.

    Returns:
        1-D ``np.ndarray`` of ints, one category id per row.
    """

    def _row_to_bitstring(row):
        # "101" for the row [1, 0, 1]
        return "".join(str(bit) for bit in row)

    bitstrings = np.apply_along_axis(_row_to_bitstring, 1, elements)
    return np.array([int(bits, 2) for bits in bitstrings])


# * BDDOIA custom confusion matrix for concepts
def plot_multilabel_confusion_matrix(
    y_true, y_pred, class_names, title, save_path=None
):
    """Plot one 2x2 confusion matrix per class and return the joint matrix.

    Args:
        y_true: 2-D multi-hot array of ground-truth labels (rows are samples).
        y_pred: 2-D multi-hot array of predictions, same layout as ``y_true``.
        class_names: one display name per column; sets the number of subplots.
        title: figure super-title.
        save_path: if truthy, the figure is written to ``f"{save_path}_total.png"``;
            otherwise it is shown interactively.

    Returns:
        The confusion matrix computed over whole rows, where every multi-hot
        row is first collapsed into a single category id.
    """
    # Joint view: each multi-hot row becomes one category.
    joint_cm = confusion_matrix(
        convert_to_categories(y_true.astype(int)),
        convert_to_categories(y_pred.astype(int)),
    )

    # Per-class view: one binary confusion matrix per column.
    per_class_cm = multilabel_confusion_matrix(y_true, y_pred)

    n_cols = 5
    num_classes = len(class_names)
    n_rows = (num_classes + 4) // 5  # ceil(num_classes / 5)

    plt.figure(figsize=(20, 4 * n_rows))

    for idx, class_name in enumerate(class_names):
        matrix = per_class_cm[idx]

        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(matrix, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"Class: {class_name}")
        plt.colorbar()

        binary_ticks = np.arange(2)
        plt.xticks(binary_ticks, ["0", "1"])
        plt.yticks(binary_ticks, ["0", "1"])

        # Cells darker than half of the maximum get white text for contrast.
        half_max = matrix.max() / 2.0
        for row, col in np.ndindex(*matrix.shape):
            count = matrix[row, col]
            plt.text(
                col,
                row,
                f"{count:.0f}",
                ha="center",
                va="center",
                color="white" if count > half_max else "black",
            )

        plt.ylabel("True label")
        plt.xlabel("Predicted label")

    plt.tight_layout()  # keep subplots from overlapping
    plt.suptitle(title)

    if save_path:
        plt.savefig(f"{save_path}_total.png")
    else:
        plt.show()

    plt.close()

    return joint_cm


# * Concept collapse (Soft)
def compute_coverage(confusion_matrix):
    """Compute the (soft) coverage of a confusion matrix.

    For every column the largest entry is taken and clipped to ``[0, 1]``;
    the coverage is the mean of these clipped values. For a matrix of integer
    counts this is the fraction of columns that contain at least one non-zero
    entry.

    Args:
        confusion_matrix: 2-D array of counts.

    Returns:
        Coverage as a float in ``[0, 1]``.
    """
    column_peaks = np.max(confusion_matrix, axis=0)
    column_hits = np.clip(column_peaks, 0, 1)

    # Redefinition of soft coverage
    return np.sum(column_hits) / len(column_hits)


# * BDDOIA custom confusion matrix for actions
def plot_actions_confusion_matrix(c_true, c_pred, title, save_path=None):
    """Plot one concept confusion matrix per action scenario.

    For each scenario a slice of concept columns is taken from ``c_true`` and
    from ``c_pred``, every row of each slice is collapsed into a single
    category id, and the confusion matrix between the two is drawn.

    Args:
        c_true: 2-D multi-hot array of ground-truth concepts (rows are samples).
        c_pred: 2-D multi-hot array of predicted concepts.
        title: prefix of each figure title (``"<title> - <scenario>"``).
        save_path: if truthy, each figure is written to
            ``f"{save_path}_{scenario}.png"``; otherwise it is shown.

    Returns:
        Dict mapping scenario name to its confusion matrix.
    """
    # scenario -> (ground-truth columns, predicted columns)
    # NOTE(review): "left" compares true columns 9:11 with predicted columns
    # 18:20, and "right" covers 12:17 — confirm these ranges are intended.
    scenario_columns = {
        "forward": (slice(0, 3), slice(0, 3)),
        "stop": (slice(3, 9), slice(3, 9)),
        "left": (slice(9, 11), slice(18, 20)),
        "right": (slice(12, 17), slice(12, 17)),
    }

    matrices = {}

    for scenario, (true_cols, pred_cols) in scenario_columns.items():
        true_block = c_true[:, true_cols]
        true_ids = convert_to_categories(true_block.astype(int))
        pred_ids = convert_to_categories(c_pred[:, pred_cols].astype(int))

        cm = confusion_matrix(true_ids, pred_ids)

        plt.figure()
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"{title} - {scenario}")
        plt.colorbar()

        # One (unlabelled) tick per possible bit pattern of the true slice.
        n_classes = true_block.shape[1]
        tick_marks = np.arange(2**n_classes)
        plt.xticks(tick_marks, [""] * len(tick_marks))
        plt.yticks(tick_marks, [""] * len(tick_marks))

        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}_{scenario}.png")
        else:
            plt.show()

        matrices[scenario] = cm

        plt.close()

    return matrices

## Other Utils

In [6]:

# * method used to check if all encoder parameters are registered in the model's optimizer
def check_optimizer_params(model):
    """Check that all trainable encoder parameters are registered in the optimizer.

    Every element of ``model.encoder`` is inspected (however many there are),
    and parameters with ``requires_grad=False`` are ignored.

    Args:
        model: object exposing ``encoder`` (an iterable of modules) and ``opt``
            (an optimizer with ``param_groups``).

    Raises:
        RuntimeError: if at least one trainable encoder parameter is not part
            of any optimizer parameter group.
    """
    # Collect (qualified name, parameter) for every trainable encoder parameter.
    encoder_params = [
        (f"encoder_{i}.{name}", param)
        for i, encoder in enumerate(model.encoder)
        for name, param in encoder.named_parameters()
        if param.requires_grad  # frozen params need no optimizer entry
    ]

    # Parameters are compared by identity, not by value.
    opt_param_ids = {
        id(p) for group in model.opt.param_groups for p in group["params"]
    }

    missing = [
        (name, p.shape) for name, p in encoder_params if id(p) not in opt_param_ids
    ]

    if missing:
        print("⚠️ The following parameters are missing from the optimizer:")
        for name, shape in missing:
            print(f"  - {name}: {shape}")
        raise RuntimeError("Some encoder parameters are not registered in the optimizer.")

    print("✅ All encoder parameters are correctly registered in the optimizer.")

# DATA LOADING

In [7]:
# Build the dataset selected by `args` and read its split information.
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()

# The backbone must come as a tuple of exactly 21 encoders
# (presumably one per BDDOIA concept — TODO confirm) plus a decoder.
encoder, decoder = dataset.get_backbone()
assert isinstance(encoder, tuple) and len(encoder) == 21, "encoder must be a tuple of 21 elements"

# Build the model, create its optimizer, then verify that every trainable
# encoder parameter was actually handed to that optimizer.
model = get_model(args, encoder, decoder, n_images, c_split)
model.start_optim(args)
check_optimizer_params(model)
loss = model.get_loss(args)

# Summary of the objects that were just built.
print(dataset)
print("Using Dataset: ", dataset)
print("Using backbone: ", encoder)
print("Using Model: ", model)
print("Using Loss: ", loss)

# Train / validation / test loaders, in that order.
unsup_train_loader, unsup_val_loader, unsup_test_loader = dataset.get_data_loaders(args=args)

Available datasets: ['mnmath', 'xor', 'clipboia', 'shortmnist', 'restrictedmnist', 'minikandinsky', 'presddoia', 'prekandinsky', 'sddoia', 'clipkandinsky', 'addmnist', 'clipshortmnist', 'boia_original', 'boia_original_embedded', 'clipsddoia', 'boia', 'kandinsky', 'halfmnist']
Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 'probddoiadpl', 'mnistpcbmsl', 'mnistpcbmltn', 'kanddplsingledisj', 'm

# TRAINING

## Main Loop

In [10]:
def train(
        model: MnistDPL, 
        _loss: ADDMNIST_DPL,
        save_path: str,
        train_loader: DataLoader,
        val_loader: DataLoader,
        args: Namespace,
        eval_concepts: list = None,
        seed: int = 0,
    ) -> float:
    """Train `model` and checkpoint it whenever validation concept accuracy improves.

    Each epoch runs one pass over `train_loader`, evaluates on `val_loader`
    with `evaluate_metrics`, steps the learning-rate schedule and applies
    early stopping on the concept accuracy (element 1 of the metrics).

    Args:
        model: model exposing `opt`, `device`, `encoder` and a forward call
            that returns an output dict.
        _loss: callable `(out_dict, args) -> (loss, losses)`.
        save_path: file path the best `state_dict` is written to.
        train_loader: loader yielding dict batches with the keys read below.
        val_loader: loader passed to `evaluate_metrics` after every epoch.
        args: run configuration; reads `exp_decay`, `warmup_steps`,
            `n_epochs`, `tuning` and `patience`.
        eval_concepts: forwarded unchanged to `evaluate_metrics`.
        seed: seed for the Python, NumPy and PyTorch random generators.

    Returns:
        The best validation concept accuracy observed. It is tracked even when
        `args.tuning` is set; in that mode no checkpoint is written.
    """

    # for full reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False  # NOTE: process-wide setting

    # early stopping
    best_cacc = 0.0
    epochs_no_improve = 0

    # scheduler & warmup (only created when args.warmup_steps > 0) for main model
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")
    model.to(model.device)
    for encoder in model.encoder:
        encoder.train()
        encoder.to(model.device)

    # Presumably primes the optimizer so that later scheduler steps do not
    # complain about being called before `optimizer.step()` — TODO confirm.
    model.opt.zero_grad()
    model.opt.step()

    # & Training start
    for epoch in range(args.n_epochs):
        print(f"Epoch {epoch+1}/{args.n_epochs}")

        model.train()

        # * Unsupervised Training
        print("Unsupervised training phase")
        for i, batch in enumerate(train_loader):

            # ------------------ original embeddings
            images_embeddings = torch.stack(batch['embeddings']).to(model.device)
            attr_labels = torch.stack(batch['attr_labels']).to(model.device)
            # the last label column is dropped before it reaches the loss
            class_labels = torch.stack(batch['class_labels'])[:,:-1].to(model.device)
            # ------------------ my extracted features
            images_embeddings_raw = torch.stack(batch['embeddings_raw']).to(model.device)
            detected_rois = batch['rois']
            detected_rois_feats = batch['roi_feats']
            detection_labels = batch['detection_labels']
            detection_scores = batch['detection_scores']
            assert_inputs(images_embeddings, attr_labels, class_labels,
                   detected_rois_feats, detected_rois, detection_labels,
                   detection_scores, images_embeddings_raw)

            out_dict = model(images_embeddings_raw)
            out_dict.update({"LABELS": class_labels, "CONCEPTS": attr_labels})

            model.opt.zero_grad()
            loss, _ = _loss(out_dict, args)

            loss.backward()
            model.opt.step()

            # Per-batch outputs are deliberately NOT accumulated here: nothing
            # consumes them, and keeping them would hold every batch's autograd
            # graph alive until the end of the epoch.

            if i % 10 == 0:
                progress_bar(i, len(train_loader) - 9, epoch, loss.item())

        # ^ Evaluation phase
        model.eval()

        my_metrics = evaluate_metrics(
                model=model, 
                loader=val_loader, 
                args=args,
                eval_concepts=eval_concepts)
        cacc = my_metrics[1]
        yacc = my_metrics[2]
        f1_y = my_metrics[3]

        # update at end of the epoch
        if epoch < args.warmup_steps:
            # only reachable when warmup_steps > 0, so w_scheduler is set
            w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1_y)

        if cacc > best_cacc:
            # Track the best score regardless of tuning mode so that the return
            # value and the early-stopping counter are meaningful in both modes.
            best_cacc = cacc
            epochs_no_improve = 0

            if not args.tuning:
                print("Saving...")
                # Save the best model
                torch.save(model.state_dict(), save_path)
                print(f"Saved best model with CACC score: {best_cacc}")
        else:
            # no improvement (a NaN score also lands here)
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


    fprint("\n--- End of Training ---\n")
    return best_cacc

## Run Training

In [11]:
print(f"*** Training model with seed {args.seed}")
print("Chosen device:", model.device)

# Make sure the output directory exists before anything is written to it.
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)

# NOTE: despite its name, `save_folder` holds the checkpoint *file* path from here on.
save_folder = os.path.join(save_path, f"{save_model_name}_{args.seed}.pth")
print("Saving model in folder: ", save_folder)

# The 21 concept names, in column order.
eval_concepts = [
    # columns 0-2
    'green_lights', 'follow_traffic', 'road_clear',
    # columns 3-8
    'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
    # columns 9-11
    'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
    # columns 12-14
    'on_right_turn_lane', 'traffic_light_right', 'front_car_right',
    # columns 15-17
    'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
    # columns 18-20
    'on_left_turn_lane', 'traffic_light_left', 'front_car_left',
]

best_cacc = train(
    model=model,
    _loss=loss,
    save_path=save_folder,
    train_loader=unsup_train_loader,
    val_loader=unsup_val_loader,
    args=args,
    eval_concepts=eval_concepts,
    seed=args.seed,
)
save_model(model, args, args.seed)  # save the model parameters
print(f"*** Finished training model with seed {args.seed} and best CACC score {best_cacc}")

print("Training finished.")

*** Training model with seed 0
Chosen device: cuda
Saving model in folder:  ../NEW-outputs/bddoia/baseline/dpl/baseline-disj/dpl_0.pth

--- Start of Training ---

Epoch 1/40
Unsupervised training phase


[ 09-04 | 16:31 ] epoch 0: |██████████████████████████████████████████████████| loss: 4.73461485

  ACC C 58.644857340388825   ACC Y 74.6004971590909 F1 Y 49.453690077617956
Saving...
Saved best model with CACC score: 58.644857340388825
Epoch 2/40
Unsupervised training phase


[ 09-04 | 16:31 ] epoch 1: |██████████████████████████████████████████████████| loss: 4.74198675

  ACC C 58.044170671039154   ACC Y 75.79407354797979 F1 Y 53.85242131449618
Epoch 3/40
Unsupervised training phase


[ 09-04 | 16:32 ] epoch 2: |██████████████████████████████████████████████████| loss: 4.68915987

  ACC C 56.68440494272444   ACC Y 73.42369002525253 F1 Y 59.27764812560896
Epoch 4/40
Unsupervised training phase


[ 09-04 | 16:32 ] epoch 3: |██████████████████████████████████████████████████| loss: 4.66030788

  ACC C 60.39543516106076   ACC Y 75.89863478535354 F1 Y 58.19490986107035
Saving...
Saved best model with CACC score: 60.39543516106076
Epoch 5/40
Unsupervised training phase


[ 09-04 | 16:33 ] epoch 4: |██████████████████████████████████████████████████| loss: 4.67436743

  ACC C 67.630735039711   ACC Y 75.84734059343435 F1 Y 56.24353409812582
Saving...
Saved best model with CACC score: 67.630735039711
Epoch 6/40
Unsupervised training phase


[ 09-04 | 16:34 ] epoch 5: |██████████████████████████████████████████████████| loss: 4.68581676

  ACC C 67.7808599339591   ACC Y 76.43031881313132 F1 Y 60.242873132086984
Saving...
Saved best model with CACC score: 67.7808599339591
Epoch 7/40
Unsupervised training phase


[ 09-04 | 16:34 ] epoch 6: |██████████████████████████████████████████████████| loss: 4.72399959

  ACC C 66.82224074999492   ACC Y 76.12354008838383 F1 Y 59.285132949997546
Epoch 8/40
Unsupervised training phase


[ 09-04 | 16:35 ] epoch 7: |██████████████████████████████████████████████████| loss: 4.64077759

  ACC C 66.67211618688371   ACC Y 75.6747159090909 F1 Y 55.4908746055696
Epoch 9/40
Unsupervised training phase


[ 09-04 | 16:35 ] epoch 8: |██████████████████████████████████████████████████| loss: 4.67003012

  ACC C 67.292344239023   ACC Y 75.83057133838383 F1 Y 64.03060156270601
Epoch 10/40
Unsupervised training phase


[ 09-04 | 16:36 ] epoch 9: |██████████████████████████████████████████████████| loss: 4.65805292

  ACC C 67.64539149072435   ACC Y 76.42735953282828 F1 Y 62.885549925869945
Epoch 11/40
Unsupervised training phase


[ 09-04 | 16:37 ] epoch 10: |██████████████████████████████████████████████████| loss: 4.63746262

  ACC C 69.52411234378815   ACC Y 75.9716303661616 F1 Y 59.14573604488898
Saving...
Saved best model with CACC score: 69.52411234378815
Epoch 12/40
Unsupervised training phase


[ 09-04 | 16:37 ] epoch 11: |██████████████████████████████████████████████████| loss: 4.66167879

  ACC C 68.86855993005965   ACC Y 76.33562184343435 F1 Y 64.39736529599861
Epoch 13/40
Unsupervised training phase


[ 09-04 | 16:38 ] epoch 12: |██████████████████████████████████████████████████| loss: 4.66995478

  ACC C 65.54533541202545   ACC Y 76.11762152777777 F1 Y 63.9394179692482
Epoch 14/40
Unsupervised training phase


[ 09-04 | 16:38 ] epoch 13: |██████████████████████████████████████████████████| loss: 4.62524366

  ACC C 66.32339093420241   ACC Y 76.49739583333333 F1 Y 63.821181860758834
Epoch 15/40
Unsupervised training phase


[ 09-04 | 16:39 ] epoch 14: |██████████████████████████████████████████████████| loss: 4.63554621

  ACC C 63.63298296928406   ACC Y 75.96373895202021 F1 Y 61.951133101421455
Epoch 16/40
Unsupervised training phase


[ 09-04 | 16:40 ] epoch 15: |██████████████████████████████████████████████████| loss: 4.67619324

  ACC C 65.91453982724084   ACC Y 76.18173926767676 F1 Y 60.0940706833506
Epoch 17/40
Unsupervised training phase


[ 09-04 | 16:40 ] epoch 16: |██████████████████████████████████████████████████| loss: 4.70419645

  ACC C 65.78170160452525   ACC Y 75.89172979797979 F1 Y 55.524477213492766
Epoch 18/40
Unsupervised training phase


[ 09-04 | 16:41 ] epoch 17: |██████████████████████████████████████████████████| loss: 4.68155861

  ACC C 69.24227509233687   ACC Y 76.25473484848484 F1 Y 61.29554026277832
Epoch 19/40
Unsupervised training phase


[ 09-04 | 16:41 ] epoch 18: |██████████████████████████████████████████████████| loss: 4.60089827

  ACC C 65.84351824389563   ACC Y 75.91441761363636 F1 Y 58.94216893664639
Epoch 20/40
Unsupervised training phase


[ 09-04 | 16:42 ] epoch 19: |██████████████████████████████████████████████████| loss: 4.65861328

  ACC C 65.51715201801724   ACC Y 75.90849905303031 F1 Y 63.68523085584052
Epoch 21/40
Unsupervised training phase


[ 09-04 | 16:43 ] epoch 20: |██████████████████████████████████████████████████| loss: 4.62103367

  ACC C 66.50940279165904   ACC Y 76.50824652777777 F1 Y 58.92007348393876
Epoch 22/40
Unsupervised training phase


[ 09-04 | 16:43 ] epoch 21: |██████████████████████████████████████████████████| loss: 4.68025208

  ACC C 64.51005670759413   ACC Y 76.26755839646465 F1 Y 60.32646457004161
Epoch 23/40
Unsupervised training phase


[ 09-04 | 16:44 ] epoch 22: |██████████████████████████████████████████████████| loss: 4.70934772

  ACC C 64.23705187108781   ACC Y 76.49640940656565 F1 Y 61.56927581036402
Epoch 24/40
Unsupervised training phase


[ 09-04 | 16:44 ] epoch 23: |██████████████████████████████████████████████████| loss: 4.64116716

  ACC C 69.14814180798001   ACC Y 76.21527777777777 F1 Y 64.36664940377224
Epoch 25/40
Unsupervised training phase


[ 09-04 | 16:45 ] epoch 24: |██████████████████████████████████████████████████| loss: 4.60145092

  ACC C 67.43852264351315   ACC Y 76.19653566919192 F1 Y 62.628680392970494
Epoch 26/40
Unsupervised training phase


[ 09-04 | 16:46 ] epoch 25: |██████████████████████████████████████████████████| loss: 4.64889336

  ACC C 65.4250853591495   ACC Y 76.21626420454545 F1 Y 60.77633240278068
Early stopping triggered after 26 epochs.

--- End of Training ---

*** Finished training model with seed 0 and best CACC score 69.52411234378815
Training finished.


# TESTING

## Evaluation Routine

In [12]:
def evaluate_my_model(model: MnistDPL, 
        save_path: str, 
        test_loader: DataLoader,
        eval_concepts,
    ):
    """Evaluate `model` on `test_loader`, log the metrics and save confusion plots.

    `evaluate_metrics` is called twice: once for the flat list of scalar
    metrics and once with `last=True` for the raw targets and predictions.
    The run configuration is read from the module-level `args`.

    Files written (all derived from `save_path`):
      * `save_path` with ".pth" replaced by "_metrics.log" — appended to;
      * `{save_path}_labels_total.png` — per-action confusion matrices;
      * `{save_path}_total.png` — per-concept confusion matrices;
      * `{save_path}_{scenario}.png` — one plot per action scenario.

    Args:
        model: model to evaluate.
        save_path: checkpoint path used as the prefix of every output file.
        test_loader: loader passed to `evaluate_metrics`.
        eval_concepts: forwarded unchanged to `evaluate_metrics`.

    Raises:
        AssertionError: if the number of returned metrics is not
            7 + (number of concepts) * (number of aggregated metrics).
    """

    my_metrics = evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,)

    all_concepts = [ 'Green Traffic Light', 'Follow Traffic', 'Road Is Clear',
        'Red Traffic Light', 'Traffic Sign', 'Obstacle Car', 'Obstacle Pedestrian', 'Obstacle Rider', 'Obstacle Others',
        'No Lane On The Left',  'Obstacle On The Left Lane',  'Solid Left Line',
                'On The Right Turn Lane', 'Traffic Light Allows Right', 'Front Car Turning Right', 
        'No Lane On The Right', 'Obstacle On The Right Lane', 'Solid Right Line',
                'On The Left Turn Lane',  'Traffic Light Allows Left',  'Front Car Turning Left' 
    ]
    aggregated_metrics = [
            'F1 - Binary', 'F1 - Macro', 'F1 - Micro', 'F1 - Weighted',
            'Precision - Binary', 'Precision - Macro', 'Precision - Micro', 'Precision - Weighted',
            'Recall - Binary', 'Recall - Macro', 'Recall - Micro', 'Recall - Weighted',
            'Balanced Accuracy'
    ]

    # Validate the layout BEFORE indexing into it or touching the log file, so
    # a mismatch fails cleanly instead of leaving a half-written log behind.
    n_global_metrics = 7  # loss, ACC C, ACC Y and four F1-Y variants
    expected_len = n_global_metrics + len(all_concepts) * len(aggregated_metrics)
    assert len(my_metrics) == expected_len, \
        f"Expected {expected_len} metrics, but got {len(my_metrics)}"

    cacc = my_metrics[1]
    yacc = my_metrics[2]
    f1_y = my_metrics[3]
    f1_micro = my_metrics[4]
    f1_weight = my_metrics[5]
    f1_bin = my_metrics[6]

    metrics_log_path = save_path.replace(".pth", "_metrics.log")

    sums = [0.0] * len(aggregated_metrics)
    num_concepts = len(all_concepts)
    with open(metrics_log_path, "a") as log_file:
        log_file.write(f"ACC C: {cacc}, ACC Y: {yacc}\n\n")
        log_file.write(f"F1 Y - Macro: {f1_y}, F1 Y - Micro: {f1_micro}, F1 Y - Weighted: {f1_weight}, F1 Y - Binary: {f1_bin} \n\n")

        # Per-concept metrics are stored back to back after the global ones.
        offset = n_global_metrics
        for concept in all_concepts:
            print(f"Metrics for {concept}:")
            log_file.write(f"{concept.upper()}\n")
            for idx, metric_name in enumerate(aggregated_metrics):
                value = my_metrics[offset + idx]
                sums[idx] += value
                log_file.write(f"  {metric_name:<18} {value:.4f}\n")
            log_file.write("\n")
            offset += len(aggregated_metrics)

        log_file.write("MEAN ACROSS ALL CONCEPTS\n")
        for idx, metric_name in enumerate(aggregated_metrics):
            mean_value = sums[idx] / num_concepts
            log_file.write(f"  {metric_name:<18} {mean_value:.4f}\n")
        log_file.write("\n")

    # Second pass: raw targets / predictions (the last four values are unused here).
    y_true, c_true, y_pred, c_pred, _, _, _, _ = (
        evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,
                    last=True
            )
    )
    # NOTE(review): these display names are used only as subplot titles. Their
    # order differs from `all_concepts` above (and "letf_solid_line" is
    # misspelled) — confirm they match the actual column order.
    y_labels = ["stop", "forward", "left", "right"]
    concept_labels = [
        "green_light",      
        "follow",           
        "road_clear",       
        "red_light",        
        "traffic_sign",     
        "car",              
        "person",           
        "rider",            
        "other_obstacle",   
        "left_lane",
        "left_green_light",
        "left_follow",
        "no_left_lane",
        "left_obstacle",
        "letf_solid_line",
        "right_lane",
        "right_green_light",
        "right_follow",
        "no_right_lane",
        "right_obstacle",
        "right_solid_line",
    ]

    # The labels plot gets its own prefix; with the same prefix both calls would
    # write "<save_path>_total.png" and the concepts plot would overwrite it.
    plot_multilabel_confusion_matrix(y_true, y_pred, y_labels, "Labels", save_path=f"{save_path}_labels")
    cfs = plot_actions_confusion_matrix(c_true, c_pred, "Concepts", save_path=save_path)
    cf = plot_multilabel_confusion_matrix(c_true, c_pred, concept_labels, "Concepts", save_path=save_path)
    print("Concept collapse", 1 - compute_coverage(cf))

    with open(metrics_log_path, "a") as log_file:
        for key, value in cfs.items():
            log_file.write(f"Concept collapse: {key}, {1 - compute_coverage(value):.4f}\n")
            log_file.write("\n")

    fprint("\n--- End of Evaluation ---\n")

## Run Evaluation

In [13]:
# Initialize the model object
model = get_model(args, encoder, decoder, n_images, c_split)

# Load the model state dictionary into the model object
model_state_dict = torch.load(save_folder)
model.load_state_dict(model_state_dict)

evaluate_my_model(model, save_folder, unsup_test_loader, eval_concepts=eval_concepts)

Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 'probddoiadpl', 'mnistpcbmsl', 'mnistpcbmltn', 'kanddplsingledisj', 'mnistsl', 'kandslsinglejoint', 'mnistdplrec', 'cvae', 'cext', 'mnmathnn', 'promnistsl']
Metrics for Green Traffic Light:
Metrics for Follow Traffic:
Metrics for Road Is Clear:
Metrics for Red Traffic Light:
Metrics for Traffic Sign:
Metrics for Obstacle Car:
Metrics for Obstacl