# IMPORTS

In [1]:
import sys
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '2'
sys.path.append(os.path.abspath(".."))       # for 'protonet_STOP_bddoia_modules' folder
sys.path.append(os.path.abspath("../.."))    # for 'data' folder
sys.path.append(os.path.abspath("../../..")) # for 'models' and 'datasets' folders

print(sys.path)

['/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia/notebooks', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python38.zip', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/lib-dynload', '', '/users-1/eleonora/.local/lib/python3.8/site-packages', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/site-packages', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/bddoia', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation', '/users-1/eleonora/reasoning-shortcuts/IXShort']


In [2]:
import csv
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import datetime
import numpy as np
import setproctitle, socket, uuid

from models import get_model
from models.mnistdpl import MnistDPL
from datasets import get_dataset

from argparse import Namespace
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix

from utils import fprint
from utils.status import progress_bar
from utils.metrics import evaluate_metrics
from utils.dpl_loss import ADDMNIST_DPL
from utils.checkpoint import save_model
from torch.utils.data import Dataset, DataLoader

from warmup_scheduler import GradualWarmupScheduler
from baseline_modules.arguments import args_dpl 
from backbones.bddoia_protonet import ProtoNetConv1D, PrototypicalLoss
from protonet_STOP_bddoia_modules.proto_modules.proto_helpers import assert_inputs

# SETUP

In [3]:
SEED = 0
UNS_PERCENTAGE = 1.0

In [4]:
args = args_dpl
args.seed = SEED

# logging
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# set job name
setproctitle.setproctitle(
    "{}_{}_{}".format(
        args.model,
        args.buffer_size if "buffer_size" in args else 0,
        args.dataset,
    )
)

# saving
save_folder = "bddoia" 
save_model_name = 'dpl'
save_paths = []
save_path = os.path.join("..",
    "NEW-outputs", 
    save_folder, 
    "baseline", 
    save_model_name,
    f"baseline-disj-PRE"
)
save_paths.append(save_path)

print("Seed: " + str(args.seed))
print(f"Save paths: {str(save_paths)}")

Seed: 0
Save paths: ['../NEW-outputs/bddoia/baseline/dpl/baseline-disj-PRE']


# UTILS

## Test Set Evaluation (alternative to full fledged notebook eval)

In [5]:

# * helper function for 'plot_multilabel_confusion_matrix'
def convert_to_categories(elements):
    # Convert vector of 0s and 1s to a single binary representation along the first dimension
    binary_rep = np.apply_along_axis(
        lambda x: "".join(map(str, x)), axis=1, arr=elements
    )
    return np.array([int(x, 2) for x in binary_rep])


# * BBDOIA custom confusion matrix for concepts
def plot_multilabel_confusion_matrix(
    y_true, y_pred, class_names, title, save_path=None
):
    y_true_categories = convert_to_categories(y_true.astype(int))
    y_pred_categories = convert_to_categories(y_pred.astype(int))

    to_rtn_cm = confusion_matrix(y_true_categories, y_pred_categories)

    cm = multilabel_confusion_matrix(y_true, y_pred)
    num_classes = len(class_names)
    num_rows = (num_classes + 4) // 5  # Calculate the number of rows needed

    plt.figure(figsize=(20, 4 * num_rows))  # Adjust the figure size

    for i in range(num_classes):
        plt.subplot(num_rows, 5, i + 1)  # Set the subplot position
        plt.imshow(cm[i], interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"Class: {class_names[i]}")
        plt.colorbar()
        tick_marks = np.arange(2)
        plt.xticks(tick_marks, ["0", "1"])
        plt.yticks(tick_marks, ["0", "1"])

        fmt = ".0f"
        thresh = cm[i].max() / 2.0
        for j in range(cm[i].shape[0]):
            for k in range(cm[i].shape[1]):
                plt.text(
                    k,
                    j,
                    format(cm[i][j, k], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i][j, k] > thresh else "black",
                )

        plt.ylabel("True label")
        plt.xlabel("Predicted label")

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.suptitle(title)

    if save_path:
        plt.savefig(f"{save_path}_total.png")
    else:
        plt.show()

    plt.close()

    return to_rtn_cm


# * Concept collapse (Soft)
def compute_coverage(confusion_matrix):
    """Compute the coverage of a confusion matrix.

    Essentially this metric is
    """

    max_values = np.max(confusion_matrix, axis=0)
    clipped_values = np.clip(max_values, 0, 1)

    # Redefinition of soft coverage
    coverage = np.sum(clipped_values) / len(clipped_values)

    return coverage


# * BDDOIA custom confusion matrix for actions
def plot_actions_confusion_matrix(c_true, c_pred, title, save_path=None):

    # Define scenarios and corresponding labels
    my_scenarios = {
        "forward": [slice(0, 3), slice(0, 3)],  
        "stop": [slice(3, 9), slice(3, 9)],
        "left": [slice(9, 11), slice(18,20)],
        "right": [slice(12, 17), slice(12,17)],
    }

    to_rtn = {}

    # Plot confusion matrix for each scenario
    for scenario, indices in my_scenarios.items():

        g_true = convert_to_categories(c_true[:, indices[0]].astype(int))
        c_pred_scenario = convert_to_categories(c_pred[:, indices[1]].astype(int))

        # Compute confusion matrix
        cm = confusion_matrix(g_true, c_pred_scenario)

        # Plot confusion matrix
        plt.figure()
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"{title} - {scenario}")
        plt.colorbar()

        n_classes = c_true[:, indices[0]].shape[1]

        tick_marks = np.arange(2**n_classes)
        plt.xticks(tick_marks, ["" for _ in range(len(tick_marks))])
        plt.yticks(tick_marks, ["" for _ in range(len(tick_marks))])

        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.tight_layout()

        # Save or show plot
        if save_path:
            plt.savefig(f"{save_path}_{scenario}.png")
        else:
            plt.show()

        to_rtn.update({scenario: cm})

        plt.close()

    return to_rtn

# Other Utils

In [6]:

# * method used to check if all encoder parameters are registered in the model's optimizer
def check_optimizer_params(model):
    """Check that all encoder parameters are registered in the optimizer."""
    # Get all encoder parameters
    encoder_params = []
    for i in range(21):
        encoder = model.encoder[i]
        for name, param in encoder.named_parameters():
            if not param.requires_grad:
                continue  # skip frozen params
            encoder_params.append((f"encoder_{i}.{name}", param))

    # Get all parameters in the optimizer
    opt_param_ids = set(id(p) for group in model.opt.param_groups for p in group['params'])

    # Check each encoder param is in the optimizer
    missing = [(name, p.shape) for name, p in encoder_params if id(p) not in opt_param_ids]

    if missing:
        print("⚠️ The following parameters are missing from the optimizer:")
        for name, shape in missing:
            print(f"  - {name}: {shape}")
        raise RuntimeError("Some encoder parameters are not registered in the optimizer.")
    else:
        print("✅ All encoder parameters are correctly registered in the optimizer.")

In [7]:
args.model

'bddoiadpldisj'

# DATA LOADING

In [8]:
args.model

'bddoiadpldisj'

In [9]:
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()

encoder, decoder = dataset.get_backbone()
assert isinstance(encoder, tuple) and len(encoder) == 21, "encoder must be a tuple of 21 elements"

model = get_model(args, encoder, decoder, n_images, c_split)
model.start_optim(args)
check_optimizer_params(model)
loss = model.get_loss(args)

print(dataset)
print("Using Dataset: ", dataset)
print("Using backbone: ", encoder)
print("Using Model: ", model)
print("Using Loss: ", loss)

unsup_train_loader, unsup_val_loader, unsup_test_loader = dataset.get_data_loaders(args=args)

Available datasets: ['mnmath', 'xor', 'clipboia', 'shortmnist', 'restrictedmnist', 'minikandinsky', 'presddoia', 'prekandinsky', 'sddoia', 'clipkandinsky', 'addmnist', 'clipshortmnist', 'boia_original', 'boia_original_embedded', 'clipsddoia', 'boia', 'kandinsky', 'halfmnist']
Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 'probddoiadpl', 'mnistpcbmsl', 'mnistpcbmltn', 'kanddplsingledisj', 'm

# PROTOTYPES CONSTRUCTION

## DATASET & BATCH SAMPLER

In [10]:
class ProtoDataset(Dataset):
    def __init__(self, embeddings, labels):
        assert embeddings.shape[0] == labels.shape[0]
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

## Build positive annotation set for each class

In [11]:
pos_examples = {cls_idx: [] for cls_idx in range(21)}
target_per_class = 6
debug = True

# Loop over dataset until we collect target_per_class for each class
for batch_idx, batch in enumerate(unsup_train_loader):
    raw_embs = torch.stack(batch['embeddings_raw']).to(model.device)
    attrs = torch.stack(batch['attr_labels']).to(model.device)  # shape [B,21]
    batch_size = attrs.size(0)

    for b in range(batch_size):
        attr_vector = attrs[b].clone().cpu()  # clone to avoid in-place issues
        for cls in torch.nonzero(attr_vector).flatten().tolist():
            if len(pos_examples[cls]) >= target_per_class:
                continue
            example = {
                'source_id': (batch_idx, b),
                'images_embeddings_raw': raw_embs[b].detach().cpu().clone(),
                'attr_labels': attr_vector,
                'is_positive': True
            }
            if debug:
                for key, value in example.items():
                    if torch.is_tensor(value):
                        print(f"{key}: {value.shape}")
                    elif isinstance(value, list) and len(value) and torch.is_tensor(value[0]):
                        print(f"{key}: list of {len(value)} tensors, first shape: {value[0].shape}")
                    else:
                        print(f"{key}: {type(value)}")
                debug = False
            pos_examples[cls].append(example)

    # Check if all classes reached target
    if all(len(pos_examples[c]) >= target_per_class for c in range(21)):
        break

source_id: <class 'tuple'>
images_embeddings_raw: torch.Size([2048])
attr_labels: torch.Size([21])
is_positive: <class 'bool'>


## Augment positive sets while building negative ones

In [12]:
neg_examples = {cls_idx: [] for cls_idx in range(21)}

for cls in range(21):
    seen_ids = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in range(21):
        if other_cls == cls:
            continue
        for ex in pos_examples[other_cls]:
            if ex['attr_labels'][cls] == 1 and ex['source_id'] not in seen_ids:
                new_ex = ex.copy()
                new_ex['is_positive'] = True
                pos_examples[cls].append(new_ex)
                seen_ids.add(ex['source_id'])

for cls in range(21):
    seen_ids_pos = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in range(21):
        if other_cls == cls:
            continue
        for ex in pos_examples[other_cls]:
            if ex['attr_labels'][cls] == 0 and ex['source_id'] not in seen_ids_pos:
                neg_ex = ex.copy()
                neg_ex['is_positive'] = False
                neg_examples[cls].append(neg_ex)

for cls in range(21):
    assert not set(ex['source_id'] for ex in neg_examples[cls]) & set(ex['source_id'] for ex in pos_examples[cls]), \
        f"Overlap in pos/neg for class {cls}"

## Construct embeddings and labels for training

In [13]:
dataset_per_class = {}
for cls in range(21):
    examples = pos_examples[cls] + neg_examples[cls]
    emb_list, label_list = [], []
    for ex in examples:
        emb_list.append(ex['images_embeddings_raw'].unsqueeze(0))
        label_list.append(ex['attr_labels'])
    embeddings_tensor = torch.stack(emb_list).to(model.device)  # [N,1,2048]
    labels_tensor = torch.stack(label_list)
    dataset_per_class[cls] = {'embeddings': embeddings_tensor.squeeze(1), 'labels': labels_tensor}
        
for cls in range(21):
    print(f"Class {cls}: embeddings shape = {dataset_per_class[cls]['embeddings'].shape}, labels shape = {dataset_per_class[cls]['labels'].shape}")

Class 0: embeddings shape = torch.Size([184, 2048]), labels shape = torch.Size([184, 21])
Class 1: embeddings shape = torch.Size([218, 2048]), labels shape = torch.Size([218, 21])
Class 2: embeddings shape = torch.Size([188, 2048]), labels shape = torch.Size([188, 21])
Class 3: embeddings shape = torch.Size([204, 2048]), labels shape = torch.Size([204, 21])
Class 4: embeddings shape = torch.Size([220, 2048]), labels shape = torch.Size([220, 21])
Class 5: embeddings shape = torch.Size([226, 2048]), labels shape = torch.Size([226, 21])
Class 6: embeddings shape = torch.Size([234, 2048]), labels shape = torch.Size([234, 21])
Class 7: embeddings shape = torch.Size([209, 2048]), labels shape = torch.Size([209, 21])
Class 8: embeddings shape = torch.Size([226, 2048]), labels shape = torch.Size([226, 21])
Class 9: embeddings shape = torch.Size([219, 2048]), labels shape = torch.Size([219, 21])
Class 10: embeddings shape = torch.Size([207, 2048]), labels shape = torch.Size([207, 21])
Class 11:

In [14]:
supervised_dataloaders = {}
for cls in range(21):
    supervised_data = dataset_per_class[cls]['embeddings']
    supervised_labels = dataset_per_class[cls]['labels']
    supervised_dataset = ProtoDataset(supervised_data, supervised_labels)
    supervised_dataloaders[cls] = DataLoader(
        supervised_dataset, 
        batch_size=args.batch_size,
        shuffle=True,       
    )

# TRAINING

## Main Loop

In [15]:
def pre_train(
        model, 
        dataloaders, 
        args, 
        eval_concepts: list = None,
        seed: int = 0
    ):

    # ^ for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    # ^ create optimizers and schedulers for each encoder
    enc_opts, enc_schs = {}, {}
    for c, name in enumerate(eval_concepts):
        opt = torch.optim.Adam(model.encoder[c].parameters())
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
        enc_opts[c], enc_schs[c] = opt, sch

    # ^ move each encoder to device
    for b in range(len(model.encoder)):
        model.encoder[b].train()
        model.encoder[b].to(model.device)
    
    fprint("\n--- Start of PreTraining ---\n")
    for epoch in range(args.proto_epochs):
        
        # ^ for each concept extractor, train its encoder
        for k, name in enumerate(eval_concepts):
            dl = dataloaders[k]
            opt, sch = enc_opts[k], enc_schs[k]
            fprint(f"\n--- Pretraining of {name} ---\n")
            
            # training
            for i, batch in enumerate(dl):
                batch_embeds, batch_labels = batch
                batch_embeds = batch_embeds.to(model.device)
                batch_labels = batch_labels.to(model.device)

                opt.zero_grad()
                preds = model.encoder[k](batch_embeds)
                assert preds.shape == (batch_embeds.shape[0], 1),\
                    f"Expected shape ({batch_embeds.shape[0]}, 1), got {preds.shape}"
                
                loss = F.binary_cross_entropy(
                        preds.squeeze(1), 
                        batch_labels[:, k].float()
                    )
                loss.backward()
                opt.step()

                progress_bar(i, len(dl), epoch, loss.item())
            
            sch.step()

            # ^ evaluation
            model.encoder[k].eval()
            eval_loss, correct, total = 0.0, 0, 0
            with torch.no_grad():
                for batch in dl:  # if you have a validation loader, replace with that
                    batch_embeds, batch_labels = batch
                    batch_embeds = batch_embeds.to(model.device)
                    batch_labels = batch_labels.to(model.device)

                    preds = model.encoder[k](batch_embeds).squeeze(1)
                    eval_loss += F.binary_cross_entropy(
                        preds, batch_labels[:, k].float(), reduction="sum"
                    ).item()

                    pred_labels = (preds > 0.5).long()
                    correct += (pred_labels == batch_labels[:, k]).sum().item()
                    total += batch_embeds.size(0)
            
            avg_loss = eval_loss / total
            accuracy = correct / total
            fprint(f"[Epoch {epoch+1}] {name} Eval Loss: {avg_loss:.4f}, "
                   f"Accuracy: {accuracy:.4f}")
            
            model.encoder[k].train()

In [16]:
eval_concepts = ['green_lights', 'follow_traffic', 'road_clear',
        'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
        'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
                'on_right_turn_lane', 'traffic_light_right', 'front_car_right', 
        'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
                'on_left_turn_lane', 'traffic_light_left', 'front_car_left']

pre_train(model, supervised_dataloaders, args, 
    eval_concepts=eval_concepts, seed=args.seed)


--- Start of PreTraining ---


--- Pretraining of green_lights ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.22626159

[Epoch 1] green_lights Eval Loss: 1.3723, Accuracy: 0.8967

--- Pretraining of follow_traffic ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.56628591

[Epoch 1] follow_traffic Eval Loss: 0.4797, Accuracy: 0.9495

--- Pretraining of road_clear ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 1.20516419

[Epoch 1] road_clear Eval Loss: 0.9896, Accuracy: 0.9043

--- Pretraining of traffic_lights ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.35032409

[Epoch 1] traffic_lights Eval Loss: 0.7351, Accuracy: 0.9314

--- Pretraining of traffic_signs ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.51919723

[Epoch 1] traffic_signs Eval Loss: 0.6463, Accuracy: 0.9500

--- Pretraining of cars ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.38778484

[Epoch 1] cars Eval Loss: 0.4130, Accuracy: 0.9735

--- Pretraining of pedestrians ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.11471567

[Epoch 1] pedestrians Eval Loss: 0.2503, Accuracy: 0.9744

--- Pretraining of riders ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.66693449

[Epoch 1] riders Eval Loss: 0.9948, Accuracy: 0.9234

--- Pretraining of others ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.26868469

[Epoch 1] others Eval Loss: 0.2457, Accuracy: 0.9735

--- Pretraining of no_lane_left ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.00167267

[Epoch 1] no_lane_left Eval Loss: 0.4604, Accuracy: 0.9726

--- Pretraining of obstacle_left_lane ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.25661656

[Epoch 1] obstacle_left_lane Eval Loss: 0.6085, Accuracy: 0.9469

--- Pretraining of solid_left_line ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.59500794

[Epoch 1] solid_left_line Eval Loss: 0.2787, Accuracy: 0.9733

--- Pretraining of on_right_turn_lane ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.00160225

[Epoch 1] on_right_turn_lane Eval Loss: 0.3944, Accuracy: 0.9732

--- Pretraining of traffic_light_right ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.30523866

[Epoch 1] traffic_light_right Eval Loss: 0.5475, Accuracy: 0.9469

--- Pretraining of front_car_right ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.20850615

[Epoch 1] front_car_right Eval Loss: 0.2594, Accuracy: 0.9732

--- Pretraining of no_lane_right ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.63265216

[Epoch 1] no_lane_right Eval Loss: 1.0157, Accuracy: 0.9300

--- Pretraining of obstacle_right_lane ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 1.30599546

[Epoch 1] obstacle_right_lane Eval Loss: 0.9761, Accuracy: 0.9072

--- Pretraining of solid_right_line ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.16879353

[Epoch 1] solid_right_line Eval Loss: 0.6286, Accuracy: 0.9526

--- Pretraining of on_left_turn_lane ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 1.40134144

[Epoch 1] on_left_turn_lane Eval Loss: 0.9723, Accuracy: 0.9137

--- Pretraining of traffic_light_left ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.49830231

[Epoch 1] traffic_light_left Eval Loss: 0.8396, Accuracy: 0.9058

--- Pretraining of front_car_left ---



[ 09-04 | 15:32 ] epoch 0: |██████████████████████████████████████████████████| loss: 0.25576407

[Epoch 1] front_car_left Eval Loss: 0.3976, Accuracy: 0.9633


In [17]:
def train(
        model: MnistDPL, 
        _loss: ADDMNIST_DPL,
        save_path: str,
        train_loader: DataLoader,
        val_loader: DataLoader,
        args: Namespace,
        eval_concepts: list = None,
        seed: int = 0,
    ) -> float:

    # for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False
    
    # early stopping
    best_cacc = 0.0
    epochs_no_improve = 0
    
    # scheduler & warmup (not used) for main model
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")
    model.to(model.device)
    model.opt.zero_grad()
    model.opt.step()

    # & Training start
    for epoch in range(args.n_epochs):
        print(f"Epoch {epoch+1}/{args.n_epochs}")
        
        model.train()

        # * Unsupervised Training
        print("Unsupervised training phase")
        ys, y_true, cs, cs_true, batch = None, None, None, None, 0
        for i, batch in enumerate(train_loader):
            
            # ------------------ original embneddings
            images_embeddings = torch.stack(batch['embeddings']).to(model.device)
            attr_labels = torch.stack(batch['attr_labels']).to(model.device)
            class_labels = torch.stack(batch['class_labels'])[:,:-1].to(model.device)
            # ------------------ my extracted features
            images_embeddings_raw = torch.stack(batch['embeddings_raw']).to(model.device)
            detected_rois = batch['rois']
            detected_rois_feats = batch['roi_feats']
            detection_labels = batch['detection_labels']
            detection_scores = batch['detection_scores']
            assert_inputs(images_embeddings, attr_labels, class_labels,
                   detected_rois_feats, detected_rois, detection_labels,
                   detection_scores, images_embeddings_raw)

            out_dict = model(images_embeddings_raw)
            out_dict.update({"LABELS": class_labels, "CONCEPTS": attr_labels})
            
            model.opt.zero_grad()
            loss, losses = _loss(out_dict, args)

            loss.backward()
            model.opt.step()

            if ys is None:
                ys = out_dict["YS"]
                y_true = out_dict["LABELS"]
                cs = out_dict["pCS"]
                cs_true = out_dict["CONCEPTS"]
            else:
                ys = torch.concatenate((ys, out_dict["YS"]), dim=0)
                y_true = torch.concatenate((y_true, out_dict["LABELS"]), dim=0)
                cs = torch.concatenate((cs, out_dict["pCS"]), dim=0)
                cs_true = torch.concatenate((cs_true, out_dict["CONCEPTS"]), dim=0)

            if i % 10 == 0:
                progress_bar(i, len(train_loader) - 9, epoch, loss.item())
            
        # ^ Evaluation phase
        y_pred = torch.argmax(ys, dim=-1)
        #print("Argmax predictions have shape: ", y_pred.shape)

        model.eval()
        
        my_metrics = evaluate_metrics(
                model=model, 
                loader=val_loader, 
                args=args,
                eval_concepts=eval_concepts)
        loss = my_metrics[0]
        cacc = my_metrics[1]
        yacc = my_metrics[2]
        f1_y = my_metrics[3]
       
        # update at end of the epoch
        if epoch < args.warmup_steps:   w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1_y)
        
        if not args.tuning and cacc > best_cacc:
            print("Saving...")
            # Update best F1 score
            best_cacc = cacc
            epochs_no_improve = 0

            # Save the best model
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with CACC score: {best_cacc}")

        elif cacc <= best_cacc:
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


    fprint("\n--- End of Training ---\n")
    return best_cacc

## Run Training

In [18]:
print(f"*** Training model with seed {args.seed}")
print("Chosen device:", model.device)
if not os.path.exists(save_path): os.makedirs(save_path, exist_ok=True)
save_folder = os.path.join(save_path, f"{save_model_name}_{args.seed}.pth")
print("Saving model in folder: ", save_folder)

eval_concepts = ['green_lights', 'follow_traffic', 'road_clear',
        'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
        'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
                'on_right_turn_lane', 'traffic_light_right', 'front_car_right', 
        'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
                'on_left_turn_lane', 'traffic_light_left', 'front_car_left']

best_cacc = train(
        model=model,
        train_loader=unsup_train_loader,
        val_loader=unsup_val_loader,
        save_path=save_folder,
        _loss=loss,
        args=args,
        eval_concepts=eval_concepts,
        seed=args.seed,
)
save_model(model, args, args.seed)  # save the model parameters
print(f"*** Finished training model with seed {args.seed} and best CACC score {best_cacc}")

print("Training finished.")

*** Training model with seed 0
Chosen device: cuda
Saving model in folder:  ../NEW-outputs/bddoia/baseline/dpl/baseline-disj-PRE/dpl_0.pth

--- Start of Training ---

Epoch 1/40
Unsupervised training phase


[ 09-04 | 15:33 ] epoch 0: |██████████████████████████████████████████████████| loss: 4.75548792

  ACC C 84.26358203093211   ACC Y 70.22569444444444 F1 Y 56.272930960117975
Saving...
Saved best model with CACC score: 84.26358203093211
Epoch 2/40
Unsupervised training phase


[ 09-04 | 15:33 ] epoch 1: |██████████████████████████████████████████████████| loss: 4.77326536

  ACC C 84.58450006114111   ACC Y 63.07705965909091 F1 Y 60.46534366922825
Saving...
Saved best model with CACC score: 84.58450006114111
Epoch 3/40
Unsupervised training phase


[ 09-04 | 15:34 ] epoch 2: |██████████████████████████████████████████████████| loss: 4.73955631

  ACC C 84.31412511401706   ACC Y 61.945628156565654 F1 Y 59.56880897456301
Epoch 4/40
Unsupervised training phase


[ 09-04 | 15:35 ] epoch 3: |██████████████████████████████████████████████████| loss: 4.72808266

  ACC C 84.55387353897095   ACC Y 68.2360716540404 F1 Y 60.59219222652515
Epoch 5/40
Unsupervised training phase


[ 09-04 | 15:35 ] epoch 4: |██████████████████████████████████████████████████| loss: 4.74091482

  ACC C 84.75322524706523   ACC Y 70.09548611111111 F1 Y 60.198208391633166
Saving...
Saved best model with CACC score: 84.75322524706523
Epoch 6/40
Unsupervised training phase


[ 09-04 | 15:36 ] epoch 5: |██████████████████████████████████████████████████| loss: 4.73691702

  ACC C 84.62301790714264   ACC Y 67.42128314393939 F1 Y 61.309797683939095
Epoch 7/40
Unsupervised training phase


[ 09-04 | 15:36 ] epoch 6: |██████████████████████████████████████████████████| loss: 4.75249863

  ACC C 84.74477032820384   ACC Y 70.63012941919192 F1 Y 60.562428300076554
Epoch 8/40
Unsupervised training phase


[ 09-04 | 15:37 ] epoch 7: |██████████████████████████████████████████████████| loss: 4.71096706

  ACC C 84.52700542079077   ACC Y 68.01215277777777 F1 Y 60.71332717960325
Epoch 9/40
Unsupervised training phase


[ 09-04 | 15:38 ] epoch 8: |██████████████████████████████████████████████████| loss: 4.73797703

  ACC C 84.69742238521576   ACC Y 74.47620738636364 F1 Y 52.39877291409552
Epoch 10/40
Unsupervised training phase


[ 09-04 | 15:38 ] epoch 9: |██████████████████████████████████████████████████| loss: 4.72500801

  ACC C 84.67487461037106   ACC Y 73.57855902777777 F1 Y 57.120519359803566
Epoch 11/40
Unsupervised training phase


[ 09-04 | 15:39 ] epoch 10: |██████████████████████████████████████████████████| loss: 4.72138739

  ACC C 84.54748491446178   ACC Y 74.89938446969697 F1 Y 45.741515916244566
Epoch 12/40
Unsupervised training phase


[ 09-04 | 15:39 ] epoch 11: |██████████████████████████████████████████████████| loss: 4.72024727

  ACC C 82.03087581528558   ACC Y 72.60396938131312 F1 Y 50.504837374827076
Epoch 13/40
Unsupervised training phase


[ 09-04 | 15:40 ] epoch 12: |██████████████████████████████████████████████████| loss: 4.71095896

  ACC C 83.11744762791528   ACC Y 72.54872948232324 F1 Y 50.52846288802723
Epoch 14/40
Unsupervised training phase


[ 09-04 | 15:41 ] epoch 13: |██████████████████████████████████████████████████| loss: 4.67393446

  ACC C 82.67458875974019   ACC Y 70.48413825757576 F1 Y 51.06568470190591
Epoch 15/40
Unsupervised training phase


[ 09-04 | 15:41 ] epoch 14: |██████████████████████████████████████████████████| loss: 4.66131021

  ACC C 82.36795167128246   ACC Y 71.12334280303031 F1 Y 50.1032015445642
Epoch 16/40
Unsupervised training phase


[ 09-04 | 15:42 ] epoch 15: |██████████████████████████████████████████████████| loss: 4.72431707

  ACC C 81.96079267395868   ACC Y 70.53740530303031 F1 Y 50.70693699543647
Epoch 17/40
Unsupervised training phase


[ 09-04 | 15:42 ] epoch 16: |██████████████████████████████████████████████████| loss: 4.75232416

  ACC C 81.46870599852667   ACC Y 72.17388731060606 F1 Y 49.9049892069551
Epoch 18/40
Unsupervised training phase


[ 09-04 | 15:43 ] epoch 17: |██████████████████████████████████████████████████| loss: 4.72701745

  ACC C 82.10133446587457   ACC Y 71.70434816919192 F1 Y 51.194174209344744
Epoch 19/40
Unsupervised training phase


[ 09-04 | 15:44 ] epoch 18: |██████████████████████████████████████████████████| loss: 4.65083456

  ACC C 82.26874536938138   ACC Y 70.55220170454545 F1 Y 50.44866711026231
Epoch 20/40
Unsupervised training phase


[ 09-04 | 15:44 ] epoch 19: |██████████████████████████████████████████████████| loss: 4.71089554

  ACC C 82.72118634647794   ACC Y 72.69077493686868 F1 Y 50.335723320339916
Early stopping triggered after 20 epochs.

--- End of Training ---

*** Finished training model with seed 0 and best CACC score 84.75322524706523
Training finished.


# TESTING

## Evaluation Routine

In [19]:
def evaluate_my_model(model: MnistDPL, 
        save_path: str, 
        test_loader: DataLoader,
        eval_concepts,
    ):
    
    my_metrics = evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,)
    
    loss = my_metrics[0]
    cacc = my_metrics[1]
    yacc = my_metrics[2]
    f1_y = my_metrics[3]
    f1_micro = my_metrics[4]
    f1_weight = my_metrics[5]
    f1_bin = my_metrics[6]

    metrics_log_path = save_path.replace(".pth", "_metrics.log")
    
    all_concepts = [ 'Green Traffic Light', 'Follow Traffic', 'Road Is Clear',
        'Red Traffic Light', 'Traffic Sign', 'Obstacle Car', 'Obstacle Pedestrian', 'Obstacle Rider', 'Obstacle Others',
        'No Lane On The Left',  'Obstacle On The Left Lane',  'Solid Left Line',
                'On The Right Turn Lane', 'Traffic Light Allows Right', 'Front Car Turning Right', 
        'No Lane On The Right', 'Obstacle On The Right Lane', 'Solid Right Line',
                'On The Left Turn Lane',  'Traffic Light Allows Left',  'Front Car Turning Left' 
    ]
    aggregated_metrics = [
            'F1 - Binary', 'F1 - Macro', 'F1 - Micro', 'F1 - Weighted',
            'Precision - Binary', 'Precision - Macro', 'Precision - Micro', 'Precision - Weighted',
            'Recall - Binary', 'Recall - Macro', 'Recall - Micro', 'Recall - Weighted',
            'Balanced Accuracy'
    ]

    sums = [0.0] * len(aggregated_metrics)
    num_concepts = len(all_concepts)
    with open(metrics_log_path, "a") as log_file:
        log_file.write(f"ACC C: {cacc}, ACC Y: {yacc}\n\n")
        log_file.write(f"F1 Y - Macro: {f1_y}, F1 Y - Micro: {f1_micro}, F1 Y - Weighted: {f1_weight}, F1 Y - Binary: {f1_bin} \n\n")

        def write_metrics(class_name, offset):
            print(f"Metrics for {class_name}:")
            log_file.write(f"{class_name.upper()}\n")
            for idx, metric_name in enumerate(aggregated_metrics):
                value = my_metrics[offset + idx]
                sums[idx] += value
                log_file.write(f"  {metric_name:<18} {value:.4f}\n")
            log_file.write("\n")

        i = 7
        for concept in all_concepts:
            write_metrics(concept, i)
            i += len(aggregated_metrics)

        log_file.write("MEAN ACROSS ALL CONCEPTS\n")
        for idx, metric_name in enumerate(aggregated_metrics):
            mean_value = sums[idx] / num_concepts
            log_file.write(f"  {metric_name:<18} {mean_value:.4f}\n")
        log_file.write("\n")


    assert len(my_metrics) == 7 + len(all_concepts) * len(aggregated_metrics), \
        f"Expected {7 + len(all_concepts) * len(aggregated_metrics)} metrics, but got {len(my_metrics)}"
    y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = (
        evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,
                    last=True
            )
    )
    y_labels = ["stop", "forward", "left", "right"]
    concept_labels = [
        "green_light",      
        "follow",           
        "road_clear",       
        "red_light",        
        "traffic_sign",     
        "car",              
        "person",           
        "rider",            
        "other_obstacle",   
        "left_lane",
        "left_green_light",
        "left_follow",
        "no_left_lane",
        "left_obstacle",
        "letf_solid_line",
        "right_lane",
        "right_green_light",
        "right_follow",
        "no_right_lane",
        "right_obstacle",
        "right_solid_line",
    ]

    plot_multilabel_confusion_matrix(y_true, y_pred, y_labels, "Labels", save_path=save_path)
    cfs = plot_actions_confusion_matrix(c_true, c_pred, "Concepts", save_path=save_path)
    cf = plot_multilabel_confusion_matrix(c_true, c_pred, concept_labels, "Concepts", save_path=save_path)
    print("Concept collapse", 1 - compute_coverage(cf))

    with open(metrics_log_path, "a") as log_file:
        for key, value in cfs.items():
            log_file.write(f"Concept collapse: {key}, {1 - compute_coverage(value):.4f}\n")
            log_file.write("\n")

    fprint("\n--- End of Evaluation ---\n")

## Run Evaluation

In [20]:
# Initialize the model object
model = get_model(args, encoder, decoder, n_images, c_split)

# Load the model state dictionary into the model object
model_state_dict = torch.load(save_folder)
model.load_state_dict(model_state_dict)

evaluate_my_model(model, save_folder, unsup_test_loader, eval_concepts=eval_concepts)

Available models: ['promnistltn', 'promnmathcbm', 'sddoiann', 'kandnn', 'sddoiadpl', 'sddoialtn', 'kandslsingledisj', 'presddoiadpl', 'boiann', 'mnistclip', 'prokanddpl', 'promnistdpl', 'kandltnsinglejoint', 'xornn', 'mnistnn', 'mnistslrec', 'kandpreprocess', 'kandsl', 'kandsloneembedding', 'prokandltn', 'kandcbm', 'prokandsl', 'boiacbm', 'kanddpl', 'kandltn', 'xorcbm', 'sddoiaclip', 'kanddplsinglejoint', 'xordpl', 'promnmathdpl', 'bddoiadpldisj', 'sddoiacbm', 'mnistltnrec', 'mnmathcbm', 'mnmathdpl', 'kandclip', 'minikanddpl', 'mnistdpl', 'mnistltn', 'boiadpl', 'boialtn', 'kandltnsingledisj', 'prokandsloneembedding', 'mnistpcbmdpl', 'mnistcbm', 'probddoiadpl', 'mnistpcbmsl', 'mnistpcbmltn', 'kanddplsingledisj', 'mnistsl', 'kandslsinglejoint', 'mnistdplrec', 'cvae', 'cext', 'mnmathnn', 'promnistsl']
Metrics for Green Traffic Light:
Metrics for Follow Traffic:
Metrics for Road Is Clear:
Metrics for Red Traffic Light:
Metrics for Traffic Sign:
Metrics for Obstacle Car:
Metrics for Obstacl