# IMPORTS

In [None]:
import sys
import os

# Expose only CUDA device 1 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Make the project packages importable by appending three ancestor
# directories (resolved against the current working directory):
#   ".."        -> 'protonet_STOP_bddoia_modules' folder
#   "../.."     -> 'data' folder
#   "../../.."  -> 'models' and 'datasets' folders
sys.path.extend(
    os.path.abspath(relative_dir) for relative_dir in ("..", "../..", "../../..")
)

print(sys.path)

In [None]:
import csv
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import datetime
import numpy as np
import setproctitle, socket, uuid

from models import get_model
from models.mnistdpl import MnistDPL
from datasets import get_dataset

from argparse import Namespace
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import multilabel_confusion_matrix, confusion_matrix

from utils import fprint
from utils.status import progress_bar
from utils.metrics import evaluate_metrics
from utils.dpl_loss import ADDMNIST_DPL
from utils.checkpoint import save_model
from torch.utils.data import Dataset, DataLoader

from warmup_scheduler import GradualWarmupScheduler
from baseline_modules.arguments import args_dpl 
from backbones.bddoia_protonet import ProtoNetConv1D, PrototypicalLoss
from protonet_STOP_bddoia_modules.proto_modules.proto_helpers import assert_inputs

# SETUP

In [None]:
args = args_dpl
args.seed = 2048

# logging metadata: unique job id, start timestamp and host name
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# set job name: "<model>_<buffer size, or 0 when args has none>_<dataset>"
setproctitle.setproctitle(
    f"{args.model}_{args.buffer_size if 'buffer_size' in args else 0}_{args.dataset}"
)

# saving: ../NEW-outputs/<save_folder>/baseline/<save_model_name>/PRE-baseline
save_folder = "bddoia"
save_model_name = 'dpl'
save_path = os.path.join(
    "..", "NEW-outputs", save_folder, "baseline", save_model_name, "PRE-baseline"
)
save_paths = [save_path]

print(f"Seed: {args.seed}")
print(f"Save paths: {save_paths}")

# UTILS

## Test Set Evaluation (alternative to the full-fledged notebook eval)

In [None]:

# * helper function for 'plot_multilabel_confusion_matrix'
def convert_to_categories(elements):
    """Encode each row of a 0/1 matrix as the integer its bits spell out.

    Column 0 is the most significant bit, so the row ``[1, 0, 1]`` becomes 5.

    Args:
        elements: 2-D array-like of 0/1 values, one sample per row.

    Returns:
        1-D integer ``np.ndarray`` holding one category id per row. An input
        with zero rows yields an empty array instead of raising.

    Raises:
        ValueError: if ``elements`` is not 2-D or holds values other than 0/1.
    """
    bits = np.asarray(elements)
    if bits.ndim != 2:
        raise ValueError(f"expected a 2-D array, got {bits.ndim} dimension(s)")
    if bits.shape[0] == 0:
        return np.zeros(0, dtype=np.int64)
    if not np.isin(bits, (0, 1)).all():
        raise ValueError("elements must contain only 0s and 1s")

    n_bits = bits.shape[1]
    if n_bits > 62:
        # The bit weights would overflow int64; use exact Python integers.
        return np.array(
            [int("".join(map(str, row)), 2) for row in bits.astype(int)]
        )

    # Weighted sum with powers of two, most significant bit first.
    weights = 1 << np.arange(n_bits - 1, -1, -1, dtype=np.int64)
    return bits.astype(np.int64) @ weights


# * BDDOIA custom confusion matrix for concepts
def plot_multilabel_confusion_matrix(
    y_true, y_pred, class_names, title, save_path=None
):
    """Draw one 2x2 confusion matrix per class, five panels per row.

    Args:
        y_true: 2-D array of ground-truth 0/1 indicators (samples x classes).
        y_pred: 2-D array of predicted 0/1 indicators, same layout as y_true.
        class_names: one panel title per class.
        title: figure-level title.
        save_path: when truthy the figure is written to
            ``f"{save_path}_total.png"``; otherwise it is shown.

    Returns:
        The confusion matrix computed over whole rows, where every row of
        ``y_true`` / ``y_pred`` is first collapsed into a single category id.
    """
    # Row-level matrix: each bit pattern counts as one category.
    joint_cm = confusion_matrix(
        convert_to_categories(y_true.astype(int)),
        convert_to_categories(y_pred.astype(int)),
    )

    per_class_cm = multilabel_confusion_matrix(y_true, y_pred)
    n_rows = (len(class_names) + 4) // 5  # ceil(len / 5) rows of five panels

    plt.figure(figsize=(20, 4 * n_rows))
    binary_ticks = np.arange(2)

    for panel, name in enumerate(class_names):
        plt.subplot(n_rows, 5, panel + 1)
        matrix = per_class_cm[panel]
        plt.imshow(matrix, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"Class: {name}")
        plt.colorbar()
        plt.xticks(binary_ticks, ["0", "1"])
        plt.yticks(binary_ticks, ["0", "1"])

        # Annotate every cell; switch to white text on the darker cells.
        half_max = matrix.max() / 2.0
        for (row, col), count in np.ndenumerate(matrix):
            plt.text(
                col,
                row,
                format(count, ".0f"),
                ha="center",
                va="center",
                color="white" if count > half_max else "black",
            )

        plt.ylabel("True label")
        plt.xlabel("Predicted label")

    plt.tight_layout()  # keep the panels from overlapping
    plt.suptitle(title)

    if save_path:
        plt.savefig(f"{save_path}_total.png")
    else:
        plt.show()

    plt.close()

    return joint_cm


# * Concept collapse (Soft)
def compute_coverage(confusion_matrix):
    """Return the mean, over columns, of the column maximum clipped to [0, 1].

    For a matrix of non-negative integer counts this is the fraction of
    columns that contain at least one non-zero entry.
    """
    # Largest entry of every column, saturated at 1.
    column_hits = np.clip(np.max(confusion_matrix, axis=0), 0, 1)

    # Soft coverage: average saturated value across the columns.
    return np.sum(column_hits) / len(column_hits)


# * BDDOIA custom confusion matrix for actions
def plot_actions_confusion_matrix(c_true, c_pred, title, save_path=None):
    """Plot one confusion matrix per action scenario and return them all.

    For every scenario a column range of ``c_true`` and one of ``c_pred`` are
    collapsed row-wise into category ids, which are then compared.

    Args:
        c_true: 2-D array of ground-truth 0/1 concept indicators.
        c_pred: 2-D array of predicted 0/1 concept indicators.
        title: prefix of each figure title.
        save_path: when truthy each figure is written to
            ``f"{save_path}_{scenario}.png"``; otherwise it is shown.

    Returns:
        Dict mapping scenario name to its confusion matrix.
    """
    # scenario -> (columns read from c_true, columns read from c_pred)
    # NOTE(review): "left" reads columns 9:11 of c_true but 18:20 of c_pred,
    # unlike the other scenarios - confirm this is intended.
    scenario_columns = {
        "forward": (slice(0, 3), slice(0, 3)),
        "stop": (slice(3, 9), slice(3, 9)),
        "left": (slice(9, 11), slice(18, 20)),
        "right": (slice(12, 17), slice(12, 17)),
    }

    matrices = {}

    for scenario, (true_cols, pred_cols) in scenario_columns.items():
        true_codes = convert_to_categories(c_true[:, true_cols].astype(int))
        pred_codes = convert_to_categories(c_pred[:, pred_cols].astype(int))
        cm = confusion_matrix(true_codes, pred_codes)

        plt.figure()
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(f"{title} - {scenario}")
        plt.colorbar()

        # One unlabeled tick per possible bit pattern of the true columns.
        n_concepts = c_true[:, true_cols].shape[1]
        tick_marks = np.arange(2**n_concepts)
        plt.xticks(tick_marks, [""] * len(tick_marks))
        plt.yticks(tick_marks, [""] * len(tick_marks))

        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}_{scenario}.png")
        else:
            plt.show()

        matrices[scenario] = cm

        plt.close()

    return matrices

# DATA LOADING

In [None]:
# Build the project dataset object selected by `args` and read its split info.
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()

# The dataset provides the encoder/decoder pair that the model is built from.
encoder, decoder = dataset.get_backbone()
model = get_model(args, encoder, decoder, n_images, c_split)
model.start_optim(args)  # presumably creates `model.opt`, which train() steps - TODO confirm
loss = model.get_loss(args)

print(dataset)
print("Using Dataset: ", dataset)
print("Using backbone: ", encoder)
print("Using Model: ", model)
print("Using Loss: ", loss)

# Train / validation / test loaders, in that order.
unsup_train_loader, unsup_val_loader, unsup_test_loader = dataset.get_data_loaders(args=args)

# PROTOTYPES CONSTRUCTION

## DATASET & BATCH SAMPLER

In [None]:
class ProtoDataset(Dataset):
    """Map-style dataset that pairs each embedding with its label entry."""

    def __init__(self, embeddings, labels):
        # Both containers must hold the same number of samples.
        n_embeddings, n_labels = embeddings.shape[0], labels.shape[0]
        assert n_embeddings == n_labels
        self.embeddings, self.labels = embeddings, labels

    def __len__(self):
        """Return the number of samples (taken from the labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

## Build positive annotation set for each class

In [None]:
# For each of the 21 attribute classes, collect up to `target_per_class`
# training samples whose attribute vector is non-zero at that class.
pos_examples = {cls_idx: [] for cls_idx in range(21)}
target_per_class = 6
debug = True  # print the field layout of the first collected example only

# Loop over dataset until we collect target_per_class for each class
for batch_idx, batch in enumerate(unsup_train_loader):
    raw_embs = torch.stack(batch['embeddings_raw']).to(model.device)
    attrs = torch.stack(batch['attr_labels']).to(model.device)  # shape [B,21]
    batch_size = attrs.size(0)

    for b in range(batch_size):
        attr_vector = attrs[b].clone().cpu()  # clone to avoid in-place issues
        # A sample is filed under every class it is positive for, so the same
        # sample can appear in several per-class lists.
        for cls in torch.nonzero(attr_vector).flatten().tolist():
            if len(pos_examples[cls]) >= target_per_class:
                continue
            # 'source_id' is the (batch index, position in batch) of this pass
            # over the loader; it is what later cells use to de-duplicate.
            # 'attr_labels' is the same tensor object for every example built
            # from this sample.
            example = {
                'source_id': (batch_idx, b),
                'images_embeddings_raw': raw_embs[b].detach().cpu().clone(),
                'attr_labels': attr_vector,
                'is_positive': True
            }
            if debug:
                for key, value in example.items():
                    if torch.is_tensor(value):
                        print(f"{key}: {value.shape}")
                    elif isinstance(value, list) and len(value) and torch.is_tensor(value[0]):
                        print(f"{key}: list of {len(value)} tensors, first shape: {value[0].shape}")
                    else:
                        print(f"{key}: {type(value)}")
                debug = False
            pos_examples[cls].append(example)

    # Check if all classes reached target
    # (checked once per batch; if the loader runs out first, some classes
    # simply end up with fewer than `target_per_class` examples)
    if all(len(pos_examples[c]) >= target_per_class for c in range(21)):
        break

## Augment positive sets while building negative ones

In [None]:
neg_examples = {cls_idx: [] for cls_idx in range(21)}

# Pass 1 - enlarge every positive set with samples that were collected for
# another class but are also labelled 1 for this one (each source id once).
for cls in range(21):
    seen_ids = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in (c for c in range(21) if c != cls):
        for candidate in pos_examples[other_cls]:
            if candidate['attr_labels'][cls] != 1 or candidate['source_id'] in seen_ids:
                continue
            pos_examples[cls].append({**candidate, 'is_positive': True})
            seen_ids.add(candidate['source_id'])

# Pass 2 - negatives of a class are the other classes' positives that are
# labelled 0 for it and are not already among its positives.
# NOTE(review): unlike pass 1 there is no de-duplication here, so a sample
# that is positive for several other classes is added once per such class.
for cls in range(21):
    seen_ids_pos = {ex['source_id'] for ex in pos_examples[cls]}
    for other_cls in (c for c in range(21) if c != cls):
        for candidate in pos_examples[other_cls]:
            if candidate['attr_labels'][cls] != 0 or candidate['source_id'] in seen_ids_pos:
                continue
            neg_examples[cls].append({**candidate, 'is_positive': False})

# Sanity check: no sample may be both positive and negative for a class.
for cls in range(21):
    neg_ids = {ex['source_id'] for ex in neg_examples[cls]}
    pos_ids = {ex['source_id'] for ex in pos_examples[cls]}
    assert neg_ids.isdisjoint(pos_ids), f"Overlap in pos/neg for class {cls}"

## Construct embeddings and labels for training

In [None]:
# For every class, stack the embeddings and label vectors of its positive
# and negative examples into two tensors with one row per example.
dataset_per_class = {}
for cls in range(21):
    examples = pos_examples[cls] + neg_examples[cls]
    # torch.stack already adds the leading per-example dimension, so no
    # unsqueeze/squeeze round trip is needed to obtain [N, ...] tensors.
    embeddings_tensor = torch.stack(
        [ex['images_embeddings_raw'] for ex in examples]
    ).to(model.device)
    labels_tensor = torch.stack([ex['attr_labels'] for ex in examples])
    dataset_per_class[cls] = {'embeddings': embeddings_tensor, 'labels': labels_tensor}

for cls in range(21):
    print(f"Class {cls}: embeddings shape = {dataset_per_class[cls]['embeddings'].shape}, labels shape = {dataset_per_class[cls]['labels'].shape}")

In [None]:
# Pool the per-class tensors into one flat set of (embedding, label) rows and
# keep them on the CPU (presumably so the loader's worker processes and
# pinned-memory transfer can be used - confirm).
all_embeddings = torch.cat(
    [entry['embeddings'] for entry in dataset_per_class.values()], dim=0
).cpu()
all_labels = torch.cat(
    [entry['labels'] for entry in dataset_per_class.values()], dim=0
).cpu()

# NOTE(review): this rebinds `dataset`, which until now referred to the object
# returned by get_dataset(args) - confirm no later cell needs the original.
dataset = ProtoDataset(all_embeddings, all_labels)
nn_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# TRAINING

## Main Loop

In [None]:
def pre_train(model, train_loader, args, seed: int = 0):
    """Fit ``model.encoder`` on (embedding, label) batches with a BCE loss.

    Args:
        model: object exposing ``encoder`` and ``device``.
        train_loader: iterable yielding ``(embeddings, labels)`` pairs.
        args: provides ``lr``, ``weight_decay`` and ``n_epochs``.
        seed: seed applied to NumPy and PyTorch before training.
    """
    # Reproducibility: seed NumPy and PyTorch (plus CUDA when present) and
    # switch cuDNN off.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    # A dedicated Adam optimizer that only touches the encoder parameters.
    optimizer = torch.optim.Adam(
        model.encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    fprint("\n--- Start of Training ---\n")
    model.encoder.to(model.device)

    for epoch in range(args.n_epochs):
        for step, (embeds, targets) in enumerate(train_loader):
            embeds = embeds.to(model.device)
            targets = targets.to(model.device)

            optimizer.zero_grad()
            preds = model.encoder(embeds)
            # The encoder must emit one value per concept (21) for each sample.
            expected_shape = (embeds.shape[0], 21)
            assert preds.shape == expected_shape, f"Expected shape ({embeds.shape[0]}, 21), got {preds.shape}"
            # binary_cross_entropy expects probabilities, so `preds` is
            # presumably already in [0, 1] - TODO confirm.
            loss = F.binary_cross_entropy(preds, targets.float())

            loss.backward()
            optimizer.step()

            progress_bar(step, len(train_loader), epoch, loss.item())


# NOTE(review): runs with the default seed (0), not args.seed - confirm.
pre_train(model, nn_loader, args)

In [None]:
def train(
        model: MnistDPL, 
        _loss: ADDMNIST_DPL,
        save_path: str,
        train_loader: DataLoader,
        val_loader: DataLoader,
        args: Namespace,
        eval_concepts: list = None,
        seed: int = 0,
    ) -> float:
    """Train ``model`` with early stopping on validation concept accuracy.

    Each epoch runs one pass over ``train_loader`` and then evaluates on
    ``val_loader`` with ``evaluate_metrics``. Whenever the concept accuracy
    (``my_metrics[1]``, logged as "ACC C") improves, it becomes the new best;
    the weights are written to ``save_path`` unless ``args.tuning`` is set.
    Training stops after ``args.patience`` epochs without improvement.

    Args:
        model: model exposing ``opt`` (its optimizer) and ``device``.
        _loss: callable ``_loss(out_dict, args)`` returning ``(loss, losses)``.
        save_path: file the best ``state_dict`` is saved to.
        train_loader: loader yielding dict batches with the keys read below.
        val_loader: loader passed to ``evaluate_metrics``.
        args: needs ``exp_decay``, ``warmup_steps``, ``n_epochs``, ``tuning``
            and ``patience``.
        eval_concepts: forwarded to ``evaluate_metrics``.
        seed: seed applied to NumPy and PyTorch before training.

    Returns:
        The best validation concept accuracy seen (0.0 if none was above 0).
    """
    # for full reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    # early stopping
    best_cacc = 0.0
    epochs_no_improve = 0

    # scheduler & optional warmup for the main model
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")
    model.to(model.device)
    # One no-op optimizer step before any scheduler step (presumably to
    # satisfy PyTorch's optimizer-before-scheduler ordering - confirm).
    model.opt.zero_grad()
    model.opt.step()

    # & Training start
    for epoch in range(args.n_epochs):
        print(f"Epoch {epoch+1}/{args.n_epochs}")

        model.train()

        # * Unsupervised Training
        print("Unsupervised training phase")
        for i, batch in enumerate(train_loader):

            # ------------------ original embeddings
            images_embeddings = torch.stack(batch['embeddings']).to(model.device)
            attr_labels = torch.stack(batch['attr_labels']).to(model.device)
            # the last label column is dropped
            class_labels = torch.stack(batch['class_labels'])[:, :-1].to(model.device)
            # ------------------ my extracted features
            images_embeddings_raw = torch.stack(batch['embeddings_raw']).to(model.device)
            detected_rois = batch['rois']
            detected_rois_feats = batch['roi_feats']
            detection_labels = batch['detection_labels']
            detection_scores = batch['detection_scores']
            assert_inputs(images_embeddings, attr_labels, class_labels,
                   detected_rois_feats, detected_rois, detection_labels,
                   detection_scores, images_embeddings_raw)

            # Only the raw embeddings are fed to the model; the targets are
            # attached to its output dict for the loss.
            out_dict = model(images_embeddings_raw)
            out_dict.update({"LABELS": class_labels, "CONCEPTS": attr_labels})

            model.opt.zero_grad()
            loss, _ = _loss(out_dict, args)

            loss.backward()
            model.opt.step()

            if i % 10 == 0:
                progress_bar(i, len(train_loader) - 9, epoch, loss.item())

        # ^ Evaluation phase
        model.eval()

        my_metrics = evaluate_metrics(
                model=model, 
                loader=val_loader, 
                args=args,
                eval_concepts=eval_concepts)
        cacc = my_metrics[1]
        yacc = my_metrics[2]
        f1_y = my_metrics[3]

        # update at end of the epoch
        if w_scheduler is not None and epoch < args.warmup_steps:
            w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1_y)

        # Track the best score in every mode so that the return value and the
        # patience counter stay meaningful; only the checkpoint write is
        # skipped while tuning.
        if cacc > best_cacc:
            best_cacc = cacc
            epochs_no_improve = 0

            if not args.tuning:
                print("Saving...")
                torch.save(model.state_dict(), save_path)
                print(f"Saved best model with CACC score: {best_cacc}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


    fprint("\n--- End of Training ---\n")
    return best_cacc

## Run Training

In [None]:
print(f"*** Training model with seed {args.seed}")
print("Chosen device:", model.device)

# Make sure the output directory exists before training writes into it.
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)

# NOTE: despite its name, `save_folder` now holds the checkpoint *file* path.
save_folder = os.path.join(save_path, f"{save_model_name}_{args.seed}.pth")
print("Saving model in folder: ", save_folder)

# Concept names forwarded to evaluate_metrics (21 entries).
eval_concepts = [
    'green_lights', 'follow_traffic', 'road_clear',
    'traffic_lights', 'traffic_signs', 'cars', 'pedestrians', 'riders', 'others',
    'no_lane_left', 'obstacle_left_lane', 'solid_left_line',
    'on_right_turn_lane', 'traffic_light_right', 'front_car_right',
    'no_lane_right', 'obstacle_right_lane', 'solid_right_line',
    'on_left_turn_lane', 'traffic_light_left', 'front_car_left',
]

best_cacc = train(
    model=model,
    _loss=loss,
    save_path=save_folder,
    train_loader=unsup_train_loader,
    val_loader=unsup_val_loader,
    args=args,
    eval_concepts=eval_concepts,
    seed=args.seed,
)
save_model(model, args, args.seed)  # save the model parameters
print(f"*** Finished training model with seed {args.seed} and best CACC score {best_cacc}")

print("Training finished.")

# TESTING

## Evaluation Routine

In [None]:
def evaluate_my_model(model: MnistDPL, 
        save_path: str, 
        test_loader: DataLoader,
        eval_concepts,
    ):
    """Evaluate ``model`` on ``test_loader`` and write a log plus plots.

    ``evaluate_metrics`` is called twice: once for the flat metric list and
    once with ``last=True`` for the raw targets and predictions. Reads the
    module-level ``args``.

    Outputs, all derived from ``save_path``:
        * ``save_path`` with ".pth" replaced by "_metrics.log" (appended to):
          global scores, 13 metrics per concept, their mean over concepts,
          and the concept collapse of every action scenario.
        * ``{save_path}_labels_total.png``: per-label confusion matrices.
        * ``{save_path}_total.png``: per-concept confusion matrices.
        * ``{save_path}_{scenario}.png``: one matrix per action scenario.

    Args:
        model: model handed to ``evaluate_metrics``.
        save_path: checkpoint path used as prefix for every output file.
        test_loader: loader handed to ``evaluate_metrics``.
        eval_concepts: forwarded to ``evaluate_metrics``.

    Raises:
        AssertionError: if the metric list does not have 7 global entries
            followed by 13 entries for each of the 21 concepts.
    """
    my_metrics = evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,)

    all_concepts = [ 'Green Traffic Light', 'Follow Traffic', 'Road Is Clear',
        'Red Traffic Light', 'Traffic Sign', 'Obstacle Car', 'Obstacle Pedestrian', 'Obstacle Rider', 'Obstacle Others',
        'No Lane On The Left',  'Obstacle On The Left Lane',  'Solid Left Line',
                'On The Right Turn Lane', 'Traffic Light Allows Right', 'Front Car Turning Right', 
        'No Lane On The Right', 'Obstacle On The Right Lane', 'Solid Right Line',
                'On The Left Turn Lane',  'Traffic Light Allows Left',  'Front Car Turning Left' 
    ]
    aggregated_metrics = [
            'F1 - Binary', 'F1 - Macro', 'F1 - Micro', 'F1 - Weighted',
            'Precision - Binary', 'Precision - Macro', 'Precision - Micro', 'Precision - Weighted',
            'Recall - Binary', 'Recall - Macro', 'Recall - Micro', 'Recall - Weighted',
            'Balanced Accuracy'
    ]

    # Check the layout before anything is indexed or logged, so a mismatch
    # cannot leave a half-written log behind.
    assert len(my_metrics) == 7 + len(all_concepts) * len(aggregated_metrics), \
        f"Expected {7 + len(all_concepts) * len(aggregated_metrics)} metrics, but got {len(my_metrics)}"

    # Global scores occupy the first 7 slots (slot 0, the loss, is not logged).
    cacc = my_metrics[1]
    yacc = my_metrics[2]
    f1_y = my_metrics[3]
    f1_micro = my_metrics[4]
    f1_weight = my_metrics[5]
    f1_bin = my_metrics[6]

    metrics_log_path = save_path.replace(".pth", "_metrics.log")

    sums = [0.0] * len(aggregated_metrics)
    num_concepts = len(all_concepts)
    with open(metrics_log_path, "a") as log_file:
        log_file.write(f"ACC C: {cacc}, ACC Y: {yacc}\n\n")
        log_file.write(f"F1 Y - Macro: {f1_y}, F1 Y - Micro: {f1_micro}, F1 Y - Weighted: {f1_weight}, F1 Y - Binary: {f1_bin} \n\n")

        # Per-concept metrics follow the global ones in consecutive groups.
        for concept_idx, concept in enumerate(all_concepts):
            offset = 7 + concept_idx * len(aggregated_metrics)
            print(f"Metrics for {concept}:")
            log_file.write(f"{concept.upper()}\n")
            for idx, metric_name in enumerate(aggregated_metrics):
                value = my_metrics[offset + idx]
                sums[idx] += value
                log_file.write(f"  {metric_name:<18} {value:.4f}\n")
            log_file.write("\n")

        log_file.write("MEAN ACROSS ALL CONCEPTS\n")
        for idx, metric_name in enumerate(aggregated_metrics):
            mean_value = sums[idx] / num_concepts
            log_file.write(f"  {metric_name:<18} {mean_value:.4f}\n")
        log_file.write("\n")

    # Second pass: raw targets/predictions for the confusion matrices
    # (only the first four returned values are used here).
    y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = (
        evaluate_metrics(model, test_loader, args, 
                    eval_concepts=eval_concepts,
                    last=True
            )
    )
    y_labels = ["stop", "forward", "left", "right"]
    # NOTE(review): from index 9 on, these panel titles are ordered
    # differently from `all_concepts` above - confirm which order matches
    # the dataset's concept columns.
    concept_labels = [
        "green_light",
        "follow",
        "road_clear",
        "red_light",
        "traffic_sign",
        "car",
        "person",
        "rider",
        "other_obstacle",
        "left_lane",
        "left_green_light",
        "left_follow",
        "no_left_lane",
        "left_obstacle",
        "left_solid_line",
        "right_lane",
        "right_green_light",
        "right_follow",
        "no_right_lane",
        "right_obstacle",
        "right_solid_line",
    ]

    # The label plot gets its own prefix: both multilabel plots append
    # "_total.png", so sharing `save_path` would let the concept plot
    # overwrite the label plot.
    plot_multilabel_confusion_matrix(y_true, y_pred, y_labels, "Labels", save_path=f"{save_path}_labels")
    cfs = plot_actions_confusion_matrix(c_true, c_pred, "Concepts", save_path=save_path)
    cf = plot_multilabel_confusion_matrix(c_true, c_pred, concept_labels, "Concepts", save_path=save_path)
    print("Concept collapse", 1 - compute_coverage(cf))

    with open(metrics_log_path, "a") as log_file:
        for key, value in cfs.items():
            log_file.write(f"Concept collapse: {key}, {1 - compute_coverage(value):.4f}\n")
            log_file.write("\n")

    fprint("\n--- End of Evaluation ---\n")

## Run Evaluation

In [None]:
# Initialize the model object
# (a fresh instance, so evaluation uses the checkpoint written by train()
# rather than the weights left in memory after the last epoch)
model = get_model(args, encoder, decoder, n_images, c_split)

# Load the model state dictionary into the model object
# (`save_folder` holds the path of the .pth file here, not a directory)
model_state_dict = torch.load(save_folder)
model.load_state_dict(model_state_dict)

# NOTE(review): unlike train(), nothing here calls model.to(model.device);
# confirm that get_model or evaluate_metrics places the model on the device.
evaluate_my_model(model, save_folder, unsup_test_loader, eval_concepts=eval_concepts)