In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score, roc_auc_score, mean_squared_error, mean_absolute_error, r2_score

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
df = pd.read_csv("/kaggle/input/mimic-ards/mlhc-ards-cohort-data.csv")

X_cont = df[['max_norepinephrine_equiv', 'avg_norepinephrine_equiv', 'sofa_cardiovascular_avg_meanbp','sofa_cardiovascular_avg_rate_norepinephrine', 'sofa_respiration_avg_pao2fio2ratio', 'sofa_renal_avg_urineoutput', 'sofa_renal_avg_creatinine', 'sofa_cns_avg_gcs', 'first24hr_cardiovascular_rate_norepinephrine', 'first24hr_cardiovascular_meanbp', 'first24hr_respiration_pao2fio2ratio', 'first24hr_renal_urineoutput','first24hr_renal_creatinine', 'first24hr_cns_gcs', 'sofa_cardiovascular_worst_meanbp', 'sofa_cardiovascular_worst_rate_norepinephrine', 'sofa_respiration_worst_pao2fio2ratio', 'sofa_renal_worst_urineoutput', 'sofa_renal_worst_creatinine', 'sofa_cns_worst_gcs','mech_vent_duration_minutes']]

X_bin = df[['other_respiratory_diseases', 'lung_diseases_due_to_external_agents', 'chronic_lower_respiratory_diseases', 'acute_lower_respiratory_infections', 'influenza_pneumonia', 'upper_respiratory_infections']]

C_cont = df[['c_sofa_avg_cardiovascular', 'c_sofa_avg_respiration', 'c_sofa_avg_renal', 'c_sofa_avg_cns', 'c_first24hr_sofa_max_cardiovascular', 'c_first24hr_sofa_max_respiration', 'c_first24hr_sofa_max_renal', 'c_first24hr_sofa_max_cns', 'c_sofa_max_cardiovascular', 'c_sofa_max_respiration', 'c_sofa_max_renal', 'c_sofa_max_cns']]

C_bin = df[['c_svr_resp_comorbidity', 'c_mod_resp_comorbidity']]

LLM_C = df[['ards_detected','aspiration_detected','bilateral_infiltrates_detected', 'cardiac_arrest_detected', 'cardiac_failure_detected', 'pancreatitis_detected','pneumonia_detected','trali_detected']]

Y = df['ARDS_DIAGNOSIS']

In [3]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

class MIMICDataProcessor:
    def __init__(self, file_path, batch_size=64):
        self.file_path = file_path
        self.batch_size = batch_size
        self.x_scaler = MinMaxScaler()
        self.c_scaler = MinMaxScaler()

        # Load and clean data
        self.df = pd.read_csv(file_path)

        self.X_cont = self.df[['max_norepinephrine_equiv', 'avg_norepinephrine_equiv', 'sofa_cardiovascular_avg_meanbp',
                          'sofa_cardiovascular_avg_rate_norepinephrine', 'sofa_respiration_avg_pao2fio2ratio',
                          'sofa_renal_avg_urineoutput', 'sofa_renal_avg_creatinine', 'sofa_cns_avg_gcs', 'first24hr_cardiovascular_meanbp',
                          'sofa_cardiovascular_worst_meanbp',
                          'sofa_cardiovascular_worst_rate_norepinephrine', 'sofa_respiration_worst_pao2fio2ratio',
                          'sofa_renal_worst_urineoutput', 'sofa_cns_worst_gcs',
                          'mech_vent_duration_minutes']]

        self.X_bin = self.df[['other_respiratory_diseases', 'lung_diseases_due_to_external_agents',
                         'chronic_lower_respiratory_diseases', 'acute_lower_respiratory_infections',
                         'influenza_pneumonia', 'upper_respiratory_infections']]

        self.C_cont = self.df[['c_sofa_avg_cardiovascular', 'c_sofa_avg_respiration', 'c_sofa_avg_renal', 'c_sofa_avg_cns',
                          'c_first24hr_sofa_max_cardiovascular', 'c_first24hr_sofa_max_respiration',
                          'c_first24hr_sofa_max_renal', 'c_first24hr_sofa_max_cns', 'c_sofa_max_cardiovascular',
                          'c_sofa_max_respiration', 'c_sofa_max_renal', 'c_sofa_max_cns']]

        self.C_bin = self.df[['c_svr_resp_comorbidity', 'c_mod_resp_comorbidity']]

        self.LLM_C = self.df[['ards_detected', 'aspiration_detected', 'bilateral_infiltrates_detected',
                         'cardiac_arrest_detected', 'cardiac_failure_detected', 'pancreatitis_detected',
                         'pneumonia_detected', 'trali_detected']]

        self.Y = self.df['ARDS_DIAGNOSIS']
        hospital = self.df[['hadm_id']]

        # Split raw data into train/val/test sets
        X_cont_temp, X_cont_test, X_bin_temp, X_bin_test, C_cont_temp, C_cont_test, C_bin_temp, C_bin_test, \
        LLM_C_temp, LLM_C_test, Y_temp, Y_test, hospital_train, hospital_test = train_test_split(
            self.X_cont, self.X_bin, self.C_cont, self.C_bin, self.LLM_C, self.Y, hospital, test_size=0.20, random_state=42)

        X_cont_train, X_cont_val, X_bin_train, X_bin_val, C_cont_train, C_cont_val, C_bin_train, C_bin_val, \
        LLM_C_train, LLM_C_val, Y_train, Y_val = train_test_split(
            X_cont_temp, X_bin_temp, C_cont_temp, C_bin_temp, LLM_C_temp, Y_temp, test_size=0.25, random_state=42)

        # Impute using median from training data only
        cont_cols_x = self.X_cont.columns.tolist()
        cont_cols_c = self.C_cont.columns.tolist()
        
        # Compute medians from training data
        x_medians = X_cont_train.median()
        c_medians = C_cont_train.median()
        
        # Apply to all splits
        X_cont_train = X_cont_train.fillna(x_medians)
        X_cont_val = X_cont_val.fillna(x_medians)
        X_cont_test = X_cont_test.fillna(x_medians)
        
        C_cont_train = C_cont_train.fillna(c_medians)
        C_cont_val = C_cont_val.fillna(c_medians)
        C_cont_test = C_cont_test.fillna(c_medians)

        # Fit scalers on train data only
        self.x_scaler.fit(X_cont_train)
        self.c_scaler.fit(C_cont_train)

        # Transform all splits
        X_train_scaled = self.x_scaler.transform(X_cont_train)
        X_val_scaled = self.x_scaler.transform(X_cont_val)
        X_test_scaled = self.x_scaler.transform(X_cont_test)

        X_train_full = np.concatenate([X_train_scaled, X_bin_train.values.astype(float)], axis=1)
        X_val_full = np.concatenate([X_val_scaled, X_bin_val.values.astype(float)], axis=1)
        X_test_full = np.concatenate([X_test_scaled, X_bin_test.values.astype(float)], axis=1)

        C_train_scaled = self.c_scaler.transform(C_cont_train)
        C_val_scaled = self.c_scaler.transform(C_cont_val)
        C_test_scaled = self.c_scaler.transform(C_cont_test)

        C_train_full = np.concatenate([C_train_scaled, C_bin_train.values.astype(float)], axis=1)
        C_val_full = np.concatenate([C_val_scaled, C_bin_val.values.astype(float)], axis=1)
        C_test_full = np.concatenate([C_test_scaled, C_bin_test.values.astype(float)], axis=1)

        # Convert to tensors
        self.X_tensor_scaled_train = torch.tensor(X_train_full, dtype=torch.float32)
        self.X_tensor_scaled_val = torch.tensor(X_val_full, dtype=torch.float32)
        self.X_tensor_scaled_test = torch.tensor(X_test_full, dtype=torch.float32)

        self.C_tensor_train = torch.tensor(C_train_full, dtype=torch.float32)
        self.C_tensor_val = torch.tensor(C_val_full, dtype=torch.float32)
        self.C_tensor_test = torch.tensor(C_test_full, dtype=torch.float32)

        self.LLM_C_tensor_train = torch.tensor(LLM_C_train.values, dtype=torch.float32)
        self.LLM_C_tensor_val = torch.tensor(LLM_C_val.values, dtype=torch.float32)
        self.LLM_C_tensor_test = torch.tensor(LLM_C_test.values, dtype=torch.float32)

        self.Y_tensor_train = torch.tensor(Y_train.values, dtype=torch.float32)
        self.Y_tensor_val = torch.tensor(Y_val.values, dtype=torch.float32)
        self.Y_tensor_test = torch.tensor(Y_test.values, dtype=torch.float32)

        self.hospital_test = hospital_test

    def create_dataloaders(self):
        train_dataset = self.MIMICDataset(self.X_tensor_scaled_train, self.C_tensor_train,
                                          self.LLM_C_tensor_train, self.Y_tensor_train)
        val_dataset = self.MIMICDataset(self.X_tensor_scaled_val, self.C_tensor_val,
                                        self.LLM_C_tensor_val, self.Y_tensor_val)
        test_dataset = self.MIMICDataset(self.X_tensor_scaled_test, self.C_tensor_test,
                                         self.LLM_C_tensor_test, self.Y_tensor_test)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader, self.hospital_test

    def get_features(self):
        return self.df[['max_norepinephrine_equiv', 'avg_norepinephrine_equiv', 'sofa_cardiovascular_avg_meanbp',
                        'sofa_cardiovascular_avg_rate_norepinephrine', 'sofa_respiration_avg_pao2fio2ratio',
                        'sofa_renal_avg_urineoutput', 'sofa_renal_avg_creatinine', 'sofa_cns_avg_gcs',
                        'first24hr_cardiovascular_rate_norepinephrine', 'first24hr_cardiovascular_meanbp',
                        'first24hr_respiration_pao2fio2ratio', 'first24hr_renal_urineoutput',
                        'first24hr_renal_creatinine', 'first24hr_cns_gcs', 'sofa_cardiovascular_worst_meanbp',
                        'sofa_cardiovascular_worst_rate_norepinephrine', 'sofa_respiration_worst_pao2fio2ratio',
                        'sofa_renal_worst_urineoutput', 'sofa_renal_worst_creatinine', 'sofa_cns_worst_gcs',
                        'mech_vent_duration_minutes']].columns.tolist()

    def get_vanilla_concepts(self):
        return self.df[['c_sofa_avg_cardiovascular', 'c_sofa_avg_respiration', 'c_sofa_avg_renal', 'c_sofa_avg_cns',
                        'c_first24hr_sofa_max_cardiovascular', 'c_first24hr_sofa_max_respiration',
                        'c_first24hr_sofa_max_renal', 'c_first24hr_sofa_max_cns', 'c_sofa_max_cardiovascular',
                        'c_sofa_max_respiration', 'c_sofa_max_renal', 'c_sofa_max_cns',
                        'c_svr_resp_comorbidity', 'c_mod_resp_comorbidity']].columns.tolist()

    def get_llm_concepts(self):
        return self.df[['ards_detected', 'aspiration_detected', 'bilateral_infiltrates_detected',
                        'cardiac_arrest_detected', 'cardiac_failure_detected', 'pancreatitis_detected',
                        'pneumonia_detected', 'trali_detected']].columns.tolist()

    class MIMICDataset(Dataset):
        def __init__(self, x, c, llm_c, y):
            self.x = x
            self.c = c
            self.llm_c = llm_c
            self.y = y

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            return self.x[idx], self.c[idx], self.llm_c[idx], self.y[idx]

In [4]:
class MultiLabelNN1(nn.Module):
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_labels):
        super(MultiLabelNN1, self).__init__()
        
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        
    def forward(self, x):
        binary_c = self.layer1(x)
        continuous_c = self.layer2(x)
        
        binary_c = torch.sigmoid(binary_c)
        
        y_pred = torch.sigmoid(self.layer3(binary_c)+self.layer4(continuous_c))
        return y_pred, binary_c, continuous_c

class MultiLabelNN2(nn.Module):
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels):
        super(MultiLabelNN2, self).__init__()
        
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        self.layer5 = nn.Linear(num_llm_concepts, num_labels, bias=False)

    def forward(self, x, llm_c):
        binary_c = self.layer1(x)
        continuous_c = self.layer2(x)
        
        binary_c = torch.sigmoid(binary_c)
        
        y_pred = torch.sigmoid(self.layer3(binary_c)+self.layer4(continuous_c)+self.layer5(llm_c))
        return y_pred, binary_c, continuous_c

In [5]:
from copy import deepcopy

class EarlyStopper:
    def __init__(self, patience=25):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.best_model = None

    def should_stop(self, score, model):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.best_model = deepcopy(model)
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

criterion = nn.BCELoss()

def concept_loss(binary_c_pred, continuous_c_pred, vanilla_c, binary_concept_idx):

    bce_loss = nn.BCELoss()
    mse_loss = nn.MSELoss()

    # print(binary_c_pred.shape,vanilla_c.shape,vanilla_c[:, binary_concept_idx].shape)
    binary_loss = bce_loss(binary_c_pred,vanilla_c[:, binary_concept_idx]) if binary_concept_idx else 0

    continuous_idx = [i for i in range(vanilla_c.shape[1]) if i not in binary_concept_idx]
    continuous_loss = mse_loss(continuous_c_pred,vanilla_c[:, continuous_idx]) if continuous_idx else 0

    return binary_loss + continuous_loss

def train_combined_model(model, x_size, vanilla_c_size, llm_c_size, y_size, learning_rate, epochs, train_loader, val_loader, binary_concept_idx, weight_decay=0.01):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    #scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
    stopper = EarlyStopper(30)
    
    epochs_count = []
    binary_c_predictions, continuous_c_predictions, label_predictions = [], [], []
    binary_c_val_predictions, continuous_c_val_predictions, label_val_predictions = [], [], []
    
    ground_truth_val_c, ground_truth_val_y = [], []

    for epoch in range(epochs):
        #print(f"Epoch {epoch+1}/{epochs}")  
        epochs_count.append(epoch)

        model.train()
        running_loss = 0.0

        for i, batch in enumerate(train_loader):
            x, vanilla_c, llm_c, y = batch
            x, vanilla_c, llm_c, y = x.to(device), vanilla_c.to(device), llm_c.to(device), y.to(device)

            optimizer.zero_grad()
            y_pred, binary_c_pred, continuous_c_pred = model(x, llm_c)

            binary_c_predictions.append(binary_c_pred.detach().cpu().numpy())
            continuous_c_predictions.append(continuous_c_pred.detach().cpu().numpy())
            label_predictions.append(y_pred.detach().cpu().numpy())
            
            c_loss = concept_loss(binary_c_pred, continuous_c_pred, vanilla_c, binary_concept_idx)
            y_loss = criterion(y_pred, y.unsqueeze(1).float())

            loss = y_loss + 0.5*c_loss
            
            loss.backward()
            optimizer.step()

        model.eval()
        
        val_loss = 0.0

        with torch.no_grad():
            for x, vanilla_c, llm_c, y in val_loader:
                x, vanilla_c, llm_c, y = x.to(device), vanilla_c.to(device), llm_c.to(device), y.to(device)

                ground_truth_val_c.append(vanilla_c.cpu())
                ground_truth_val_y.append(y.cpu())

                y_pred, binary_c_pred, continuous_c_pred = model(x, llm_c)

                binary_c_val_predictions.append(binary_c_pred.detach().cpu().numpy())
                continuous_c_val_predictions.append(continuous_c_pred.detach().cpu().numpy())
                label_val_predictions.append(y_pred.detach().cpu().numpy())

                c_loss = concept_loss(binary_c_pred, continuous_c_pred, vanilla_c, binary_concept_idx)
                y_loss = criterion(y_pred, y.unsqueeze(1).float())

                val_loss += y_loss + 0.5*c_loss

        if stopper.should_stop(val_loss,model):
            print("Done")
            # break
        #scheduler.step()

    return model, binary_c_predictions, continuous_c_predictions, label_predictions, binary_c_val_predictions, continuous_c_val_predictions, label_val_predictions, ground_truth_val_c, ground_truth_val_y

def train(model, x_size, c_size, y_size, learning_rate, weight_decay, epochs, train_loader, val_loader, binary_concept_idx):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    #scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
    stopper = EarlyStopper(30)
    
    epochs_count = []
    binary_c_predictions, continuous_c_predictions, label_predictions = [], [], []
    binary_c_val_predictions, continuous_c_val_predictions, label_val_predictions = [], [], []
    
    ground_truth_val_c, ground_truth_val_y = [], []

    for epoch in range(epochs):
        #print(f"Epoch {epoch+1}/{epochs}")  # Print current epoch
        epochs_count.append(epoch)

        # Training Loop
        model.train()

        for i, batch in enumerate(train_loader):
            x, c,_, y = batch
            x, c, y = x.to(device), c.to(device), y.to(device)

            optimizer.zero_grad()
            
            y_pred, binary_c_pred, continuous_c_pred = model(x)

            binary_c_predictions.append(binary_c_pred.detach().cpu().numpy())
            continuous_c_predictions.append(continuous_c_pred.detach().cpu().numpy())
            label_predictions.append(y_pred.detach().cpu().numpy())

            c_loss = concept_loss(binary_c_pred, continuous_c_pred, c, binary_concept_idx)
            y_loss = criterion(y_pred, y.unsqueeze(1).float())
                
            loss = y_loss + 0.5*c_loss 

            loss.backward()
            optimizer.step()
        
        model.eval()

        with torch.no_grad():
            val_loss = 0.0
            for x, c,_, y in val_loader:
                
                x, c, y = x.to(device), c.to(device), y.to(device)

                ground_truth_val_c.append(c.cpu())
                ground_truth_val_y.append(y.cpu())

                y_pred, binary_c_pred, continuous_c_pred = model(x)

                binary_c_val_predictions.append(binary_c_pred.detach().cpu().numpy())
                continuous_c_val_predictions.append(continuous_c_pred.detach().cpu().numpy())
                label_val_predictions.append(y_pred.detach().cpu().numpy())
    
                c_loss = concept_loss(binary_c_pred, continuous_c_pred, c, binary_concept_idx)
                y_loss = criterion(y_pred, y.unsqueeze(1).float())
                    
                val_loss += y_loss + 0.5*c_loss 

        if stopper.should_stop(val_loss,model):
            print("Done")
            # break
        #scheduler.step()

    return model, binary_c_predictions, continuous_c_predictions, label_predictions, binary_c_val_predictions, continuous_c_val_predictions, label_val_predictions, ground_truth_val_c, ground_truth_val_y

In [6]:
def evaluate_concept_predictor(ground_truth_c, binary_predictions, continuous_predictions, concept_labels, binary_concept_idx):
    results = []

    total_concepts = len(concept_labels)
    continuous_idx = [i for i in range(total_concepts) if i not in binary_concept_idx]

    for i, label in enumerate(concept_labels):

        true_values = np.concatenate([c[:, i] if isinstance(c, np.ndarray) else c[:, i].numpy() for c in ground_truth_c])
        
        if i in binary_concept_idx:

            predicted_values = np.concatenate([c[:, i - 12] for c in binary_predictions])
            
            predicted_classes = (predicted_values > 0.5).astype(int)

            precision = round(precision_score(true_values, predicted_classes, zero_division=0), 3)
            recall = round(recall_score(true_values, predicted_classes, zero_division=0), 3)
            f1 = round(f1_score(true_values, predicted_classes, zero_division=0), 3)
            accuracy = round(accuracy_score(true_values, predicted_classes), 3)

            results.append({
                "Label": label,
                "Precision": precision,
                "Recall": recall,
                "F1 Score": f1,
                "Accuracy": accuracy,
            })

        else:

            predicted_values = np.concatenate([c[:, i] for c in continuous_predictions])
            
            mse = round(mean_squared_error(true_values, predicted_values), 3)
            mae = round(mean_absolute_error(true_values, predicted_values), 3)
            rmse = round(mean_squared_error(true_values, predicted_values, squared=False), 3)
            r2 = round(r2_score(true_values, predicted_values), 3)

            results.append({
                "Label": label,
                "MSE": mse,
                "MAE": mae,
                "RMSE": rmse,
                "R2": r2
            })

    for label, result in zip(concept_labels, results):
        print(result)
            
    return results

# Label Predictor Evaluation
def evaluate_label_predictor(ground_truth_y, predicted_y):
    true_values = np.concatenate(ground_truth_y)
    predicted_values = np.concatenate(predicted_y).squeeze()

    predicted_classes = (predicted_values > 0.5).astype(int)

    precision = precision_score(true_values, predicted_classes)
    recall = recall_score(true_values, predicted_classes)
    f1 = f1_score(true_values, predicted_classes)
    auc = roc_auc_score(true_values, predicted_classes)
    accuracy = accuracy_score(true_values, predicted_classes)

    results = {
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1,
        "AUC": auc,
        "Accuracy": accuracy
    }
    return pd.DataFrame(results, index=["Metrics"])

def test_model(model, test_loader, binary_concept_idx):
    model.eval()  
    criterion = nn.BCELoss()  

    ground_truth_test_c, ground_truth_test_y = [], []
    binary_c_test_predictions, continuous_c_test_predictions, label_test_predictions = [], [], []

    model.eval()
            
    with torch.no_grad():
        for x, c,_, y in test_loader:
            
            x, c, y = x.to(device), c.to(device), y.to(device)

            ground_truth_test_c.append(c.cpu())
            ground_truth_test_y.append(y.cpu())

            y_pred, binary_c_pred, continuous_c_pred = model(x)

            binary_c_test_predictions.append(binary_c_pred.detach().cpu().numpy())
            continuous_c_test_predictions.append(continuous_c_pred.detach().cpu().numpy())
            label_test_predictions.append(y_pred.detach().cpu().numpy())

            c_loss = concept_loss(binary_c_pred, continuous_c_pred, c, binary_concept_idx)
            y_loss = criterion(y_pred, y.unsqueeze(1).float())
                
            loss = y_loss + 0.5*c_loss 
            
    return ground_truth_test_c, ground_truth_test_y, binary_c_test_predictions, continuous_c_test_predictions, label_test_predictions

def test_combined_model(model, test_loader, binary_concept_idx):
    model.eval()
    criterion = nn.BCELoss()

    ground_truth_test_c, ground_truth_test_y = [], []
    binary_c_test_predictions, continuous_c_test_predictions, label_test_predictions = [], [], []

    with torch.no_grad():
        for x, c, llm_c, y in test_loader:
            x, c, llm_c, y = x.to(device), c.to(device), llm_c.to(device), y.to(device)

            ground_truth_test_c.append(c.cpu())
            ground_truth_test_y.append(y.cpu())
            
            y_pred, binary_c_pred, continuous_c_pred = model(x, llm_c)

            binary_c_test_predictions.append(binary_c_pred.detach().cpu().numpy())
            continuous_c_test_predictions.append(continuous_c_pred.detach().cpu().numpy())
            label_test_predictions.append(y_pred.detach().cpu().numpy())

            c_loss = concept_loss(binary_c_pred, continuous_c_pred, c, binary_concept_idx)
            y_loss = criterion(y_pred, y.unsqueeze(1).float())
                
            loss = y_loss + 0.5*c_loss
    
    return ground_truth_test_c, ground_truth_test_y, binary_c_test_predictions, continuous_c_test_predictions, label_test_predictions

In [7]:
# Load Data
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '/kaggle/input/mimic-ards/mlhc-ards-cohort-data.csv'
processor = MIMICDataProcessor(file_path, batch_size=64)
train_loader, val_loader, test_loader, hospital_test = processor.create_dataloaders()

In [8]:
x_size = processor.X_tensor_scaled_train.shape[1]
c_size = processor.C_tensor_train.shape[1]
y_size = 1
x_to_y_learning_rate = 0.3
weight_decay = 0.0001
epochs = 100
binary_concept_idx = list(range(processor.C_cont.shape[1], processor.C_cont.shape[1] + processor.C_bin.shape[1]))

torch.manual_seed(25)
model = MultiLabelNN1(21,2,12,1).to(device)
model, binary_c_predictions, continuous_c_predictions, label_predictions, binary_c_val_predictions, continuous_c_val_predictions, label_val_predictions, ground_truth_val_c, ground_truth_val_y = train(model, x_size, c_size, y_size, x_to_y_learning_rate, weight_decay, epochs, train_loader, val_loader, binary_concept_idx)

Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done


In [9]:
x_size = processor.X_tensor_scaled_train.shape[1]
vanilla_c_size = processor.C_tensor_train.shape[1]
llm_c_size = processor.LLM_C_tensor_train.shape[1]

y_size = 1
learning_rate = 0.3
epochs = 100
weight_decay = 0.0001

torch.manual_seed(25)
model2 = MultiLabelNN2(21,2,12,8,1).to(device)

model2, binary_c_predictions_llm, continuous_c_predictions_llm, label_predictions_llm, binary_c_val_predictions_llm, continuous_c_val_predictions_llm, label_val_predictions_llm, ground_truth_val_c_llm, ground_truth_val_y_llm = train_combined_model(model2, x_size, vanilla_c_size, llm_c_size, y_size, learning_rate, epochs, train_loader, val_loader, binary_concept_idx, weight_decay)

Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done
Done


In [10]:
torch.save(model.state_dict(),"cbm.pt")
torch.save(model2.state_dict(),"llm_cbm.pt")

In [11]:
c_true, y_true, binary_c_test_predictions, continuous_c_test_predictions, y_pred = test_model(model, test_loader, binary_concept_idx)

c_true_llm, y_true_llm, binary_c_test_predictions_llm, continuous_c_test_predictions_llm, y_pred_llm = test_combined_model(model2, test_loader, binary_concept_idx)

concept_labels = processor.get_vanilla_concepts()

print("Vanilla")
vanilla_c_results = evaluate_concept_predictor(c_true,binary_c_test_predictions, continuous_c_test_predictions,concept_labels, binary_concept_idx)
print("LLM")
llm_c_results = evaluate_concept_predictor(c_true_llm,binary_c_test_predictions_llm, continuous_c_test_predictions_llm, concept_labels, binary_concept_idx)

print(evaluate_label_predictor(y_true, y_pred))
print(evaluate_label_predictor(y_true_llm, y_pred_llm))

Vanilla
{'Label': 'c_sofa_avg_cardiovascular', 'MSE': 0.271, 'MAE': 0.19, 'RMSE': 0.521, 'R2': -3.288}
{'Label': 'c_sofa_avg_respiration', 'MSE': 0.129, 'MAE': 0.271, 'RMSE': 0.359, 'R2': -3.911}
{'Label': 'c_sofa_avg_renal', 'MSE': 1.864, 'MAE': 0.745, 'RMSE': 1.365, 'R2': -25.232}
{'Label': 'c_sofa_avg_cns', 'MSE': 0.02, 'MAE': 0.104, 'RMSE': 0.14, 'R2': -3.042}
{'Label': 'c_first24hr_sofa_max_cardiovascular', 'MSE': 0.421, 'MAE': 0.336, 'RMSE': 0.649, 'R2': -2.045}
{'Label': 'c_first24hr_sofa_max_respiration', 'MSE': 0.16, 'MAE': 0.324, 'RMSE': 0.4, 'R2': -0.61}
{'Label': 'c_first24hr_sofa_max_renal', 'MSE': 0.112, 'MAE': 0.275, 'RMSE': 0.335, 'R2': -0.031}
{'Label': 'c_first24hr_sofa_max_cns', 'MSE': 0.171, 'MAE': 0.242, 'RMSE': 0.413, 'R2': -0.763}
{'Label': 'c_sofa_max_cardiovascular', 'MSE': 0.653, 'MAE': 0.282, 'RMSE': 0.808, 'R2': -4.729}
{'Label': 'c_sofa_max_respiration', 'MSE': 0.043, 'MAE': 0.151, 'RMSE': 0.206, 'R2': 0.511}
{'Label': 'c_sofa_max_renal', 'MSE': 0.26, 'MAE'

In [12]:
from copy import deepcopy
import torch
import torch.nn as nn

class EarlyStopper:
    def __init__(self, patience=25, verbose=False, delta=0, mode='min'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf') # For 'min' mode
        self.delta = delta # Minimum change to qualify as an improvement
        self.best_model_state = None
        self.mode = mode # 'min' for loss, 'max' for accuracy/AUC

    def should_stop(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
        elif (self.mode == 'min' and score < self.best_score - self.delta) or \
             (self.mode == 'max' and score > self.best_score + self.delta):
            self.best_score = score
            self.save_checkpoint(score, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopper counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        
        return self.early_stop

    def save_checkpoint(self, score, model):
        '''Saves model when validation score improves.'''
        if self.verbose:
            if self.mode == 'min':
                print(f'Validation score decreased ({self.val_loss_min:.6f} --> {score:.6f}).  Saving model ...')
            else:
                print(f'Validation score increased ({self.val_loss_min:.6f} --> {score:.6f}).  Saving model ...')
        self.best_model_state = deepcopy(model.state_dict())
        self.val_loss_min = score # Keep track of the actual best score

    def load_best_model(self, model):
        if self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)
            if self.verbose:
                print("Loaded best model state from early stopping.")
        return model

# Define your existing criterion and concept_loss functions (from your notebook)
criterion = nn.BCELoss() # For final label prediction

def concept_loss(binary_c_pred, continuous_c_pred, vanilla_c, binary_concept_idx):
    bce_loss = nn.BCELoss()
    mse_loss = nn.MSELoss()
    
    num_total_concepts = vanilla_c.shape[1]
    
    # Assuming binary_concept_idx contains the indices of binary concepts in the `vanilla_c` tensor
    # And that binary_c_pred corresponds to these.
    binary_loss_val = 0
    if binary_concept_idx and binary_c_pred.numel() > 0 : # Check if there are binary concepts to predict
        true_binary_concepts = vanilla_c[:, binary_concept_idx]
        binary_loss_val = bce_loss(binary_c_pred, true_binary_concepts)

    # Continuous concepts are those not in binary_concept_idx
    continuous_concept_indices = [i for i in range(num_total_concepts) if i not in binary_concept_idx]
    continuous_loss_val = 0
    if continuous_concept_indices and continuous_c_pred.numel() > 0: # Check if there are continuous concepts
        true_continuous_concepts = vanilla_c[:, continuous_concept_indices]
        continuous_loss_val = mse_loss(continuous_c_pred, true_continuous_concepts)
        
    return binary_loss_val + continuous_loss_val

def train_x_to_c_sequential(model, device, train_loader, val_loader, binary_concept_idx, 
                            learning_rate, weight_decay, epochs, patience=10):
    print("--- Stage 1: Training x -> c (Concepts) ---")
    
    # Set requires_grad for layers
    for param in model.layer1.parameters(): param.requires_grad = True
    for param in model.layer2.parameters(): param.requires_grad = True
    for param in model.layer3.parameters(): param.requires_grad = False
    for param in model.layer4.parameters(): param.requires_grad = False

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=learning_rate, weight_decay=weight_decay)
    
    stopper = EarlyStopper(patience=patience, verbose=True, mode='min')

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0
        for x_batch, c_true_batch, _, _ in train_loader: # Use c_true_batch for concept truth
            x_batch, c_true_batch = x_batch.to(device), c_true_batch.to(device)
            
            optimizer.zero_grad()
            _, binary_c_pred, continuous_c_pred = model(x_batch)
            
            loss = concept_loss(binary_c_pred, continuous_c_pred, c_true_batch, binary_concept_idx)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

        model.eval()
        val_loss_sum = 0
        with torch.no_grad():
            for x_batch, c_true_batch, _, _ in val_loader:
                x_batch, c_true_batch = x_batch.to(device), c_true_batch.to(device)
                _, binary_c_pred, continuous_c_pred = model(x_batch)
                loss = concept_loss(binary_c_pred, continuous_c_pred, c_true_batch, binary_concept_idx)
                val_loss_sum += loss.item()
        
        avg_train_loss = train_loss_sum / len(train_loader)
        avg_val_loss = val_loss_sum / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} [x->c] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        if stopper.should_stop(avg_val_loss, model):
            print(f"Early stopping at epoch {epoch+1} for x->c training.")
            break
            
    model = stopper.load_best_model(model) # Load the best model state for concept layers
    return model

def train_c_to_y_sequential(model, device, train_loader, val_loader, 
                            learning_rate, weight_decay, epochs, patience=10):
    print("--- Stage 2: Training c -> y (Labels from Predicted Concepts) ---")

    # Set requires_grad for layers
    for param in model.layer1.parameters(): param.requires_grad = False # Freeze concept layers
    for param in model.layer2.parameters(): param.requires_grad = False
    for param in model.layer3.parameters(): param.requires_grad = True
    for param in model.layer4.parameters(): param.requires_grad = True

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=learning_rate, weight_decay=weight_decay)
    
    stopper = EarlyStopper(patience=patience, verbose=True, mode='min') # Loss is minimized

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0
        for x_batch, _, _, y_true_batch in train_loader: # Need x to generate concepts
            x_batch, y_true_batch = x_batch.to(device), y_true_batch.to(device)

            optimizer.zero_grad()
            
            # Get predicted concepts from the frozen x->c part
            # The model's forward pass gives y_pred, binary_c, continuous_c
            # We need to ensure that the y_pred used for loss here is ONLY from the c->y part
            with torch.no_grad(): # Ensure concept layers are not updated during this intermediate step
                 _, binary_c_intermediate, continuous_c_intermediate = model(x_batch)
            
            # Now use these *detached* predicted concepts to predict y through the trainable layers
            # Manually compute y_pred from the specific layers to ensure correct grad flow
            y_pred = torch.sigmoid(model.layer3(binary_c_intermediate.detach()) + \
                                   model.layer4(continuous_c_intermediate.detach()))

            loss = criterion(y_pred, y_true_batch.unsqueeze(1).float())
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

        model.eval()
        val_loss_sum = 0
        with torch.no_grad():
            for x_batch, _, _, y_true_batch in val_loader:
                x_batch, y_true_batch = x_batch.to(device), y_true_batch.to(device)
                
                _, binary_c_intermediate, continuous_c_intermediate = model(x_batch)
                y_pred = torch.sigmoid(model.layer3(binary_c_intermediate) + \
                                       model.layer4(continuous_c_intermediate))
                
                loss = criterion(y_pred, y_true_batch.unsqueeze(1).float())
                val_loss_sum += loss.item()

        avg_train_loss = train_loss_sum / len(train_loader)
        avg_val_loss = val_loss_sum / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} [c->y] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        
        if stopper.should_stop(avg_val_loss, model):
            print(f"Early stopping at epoch {epoch+1} for c->y training.")
            break
            
    model = stopper.load_best_model(model) # Load the best model state for label layers
    return model

def train_x_to_c_sequential_combined(model, device, train_loader, val_loader, binary_concept_idx, 
                                     learning_rate, weight_decay, epochs, patience=10):
    print("--- Stage 1 (Combined Model): Training x -> c (Concepts) ---")
    
    # Set requires_grad for layers
    for param in model.layer1.parameters(): param.requires_grad = True
    for param in model.layer2.parameters(): param.requires_grad = True
    for param in model.layer3.parameters(): param.requires_grad = False
    for param in model.layer4.parameters(): param.requires_grad = False
    for param in model.layer5.parameters(): param.requires_grad = False # LLM concept layer also frozen

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=learning_rate, weight_decay=weight_decay)
    
    stopper = EarlyStopper(patience=patience, verbose=True, mode='min')

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0
        for x_batch, c_true_batch, llm_c_batch, _ in train_loader:
            x_batch, c_true_batch, llm_c_batch = x_batch.to(device), c_true_batch.to(device), llm_c_batch.to(device)
            
            optimizer.zero_grad()
            # model's forward for MultiLabelNN2 needs x and llm_c
            _, binary_c_pred, continuous_c_pred = model(x_batch, llm_c_batch) 
            
            loss = concept_loss(binary_c_pred, continuous_c_pred, c_true_batch, binary_concept_idx)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

        model.eval()
        val_loss_sum = 0
        with torch.no_grad():
            for x_batch, c_true_batch, llm_c_batch, _ in val_loader:
                x_batch, c_true_batch, llm_c_batch = x_batch.to(device), c_true_batch.to(device), llm_c_batch.to(device)
                _, binary_c_pred, continuous_c_pred = model(x_batch, llm_c_batch)
                loss = concept_loss(binary_c_pred, continuous_c_pred, c_true_batch, binary_concept_idx)
                val_loss_sum += loss.item()
        
        avg_train_loss = train_loss_sum / len(train_loader)
        avg_val_loss = val_loss_sum / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} [x->c Combined] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        if stopper.should_stop(avg_val_loss, model):
            print(f"Early stopping at epoch {epoch+1} for x->c combined training.")
            break
            
    model = stopper.load_best_model(model)
    return model

def train_c_to_y_sequential_combined(model, device, train_loader, val_loader, 
                                     learning_rate, weight_decay, epochs, patience=10):
    print("--- Stage 2 (Combined Model): Training c -> y (Labels from Predicted & LLM Concepts) ---")

    # Set requires_grad for layers
    for param in model.layer1.parameters(): param.requires_grad = False # Freeze concept layers
    for param in model.layer2.parameters(): param.requires_grad = False
    for param in model.layer3.parameters(): param.requires_grad = True
    for param in model.layer4.parameters(): param.requires_grad = True
    for param in model.layer5.parameters(): param.requires_grad = True # LLM concept layer is trained

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=learning_rate, weight_decay=weight_decay)
        
    stopper = EarlyStopper(patience=patience, verbose=True, mode='min')

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0
        for x_batch, _, llm_c_batch, y_true_batch in train_loader:
            x_batch, llm_c_batch, y_true_batch = x_batch.to(device), llm_c_batch.to(device), y_true_batch.to(device)

            optimizer.zero_grad()
            
            with torch.no_grad():
                 _, binary_c_intermediate, continuous_c_intermediate = model(x_batch, llm_c_batch)
            
            # Manually compute y_pred from the specific layers
            y_pred = torch.sigmoid(model.layer3(binary_c_intermediate.detach()) + \
                                   model.layer4(continuous_c_intermediate.detach()) + \
                                   model.layer5(llm_c_batch)) # llm_c_batch is a direct input here

            loss = criterion(y_pred, y_true_batch.unsqueeze(1).float())
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()

        model.eval()
        val_loss_sum = 0
        with torch.no_grad():
            for x_batch, _, llm_c_batch, y_true_batch in val_loader:
                x_batch, llm_c_batch, y_true_batch = x_batch.to(device), llm_c_batch.to(device), y_true_batch.to(device)
                
                _, binary_c_intermediate, continuous_c_intermediate = model(x_batch, llm_c_batch)
                y_pred = torch.sigmoid(model.layer3(binary_c_intermediate) + \
                                       model.layer4(continuous_c_intermediate) + \
                                       model.layer5(llm_c_batch))
                
                loss = criterion(y_pred, y_true_batch.unsqueeze(1).float())
                val_loss_sum += loss.item()

        avg_train_loss = train_loss_sum / len(train_loader)
        avg_val_loss = val_loss_sum / len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} [c->y Combined] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        
        if stopper.should_stop(avg_val_loss, model):
            print(f"Early stopping at epoch {epoch+1} for c->y combined training.")
            break
            
    model = stopper.load_best_model(model)
    return model

In [13]:
# Assuming 'processor', 'train_loader', 'val_loader', 'binary_concept_idx', 'device' are defined
# And MultiLabelNN1, MultiLabelNN2 are defined.

# --- Training MultiLabelNN1 Sequentially ---
x_size = processor.X_tensor_scaled_train.shape[1]
# c_size = processor.C_tensor_train.shape[1] # Not directly used in sequential setup in this manner
# y_size = 1
# For MultiLabelNN1, C_cont is 12, C_bin is 2.
num_continuous_concepts_model1 = processor.C_cont.shape[1]
num_binary_concepts_model1 = processor.C_bin.shape[1]

torch.manual_seed(25)
model1_sequential = MultiLabelNN1(num_features=x_size, 
                                  num_binary_concepts=num_binary_concepts_model1, 
                                  num_continuous_concepts=num_continuous_concepts_model1, 
                                  num_labels=1).to(device)

# Stage 1 for model1
lr_stage1_m1 = 0.1 # Example
wd_stage1_m1 = 0.001
epochs_stage1_m1 = 50 # Example
patience_stage1_m1 = 15
model1_sequential = train_x_to_c_sequential(model1_sequential, device, train_loader, val_loader, 
                                            binary_concept_idx, lr_stage1_m1, wd_stage1_m1, 
                                            epochs_stage1_m1, patience_stage1_m1)

# Stage 2 for model1
lr_stage2_m1 = 0.1 # Example
wd_stage2_m1 = 0.001
epochs_stage2_m1 = 50 # Example
patience_stage2_m1 = 15
model1_sequential = train_c_to_y_sequential(model1_sequential, device, train_loader, val_loader, 
                                            lr_stage2_m1, wd_stage2_m1, epochs_stage2_m1, patience_stage2_m1)

print("MultiLabelNN1 (Standard CBM) trained sequentially.")
# You can now evaluate model1_sequential using your test_model function


# --- Training MultiLabelNN2 Sequentially ---
llm_c_size = processor.LLM_C_tensor_train.shape[1] # Should be 8
num_continuous_concepts_model2 = processor.C_cont.shape[1] # 12
num_binary_concepts_model2 = processor.C_bin.shape[1] # 2


torch.manual_seed(25)
model2_sequential = MultiLabelNN2(num_features=x_size, 
                                  num_binary_concepts=num_binary_concepts_model2, 
                                  num_continuous_concepts=num_continuous_concepts_model2, 
                                  num_llm_concepts=llm_c_size, 
                                  num_labels=1).to(device)

# Stage 1 for model2
lr_stage1_m2 = 0.1
wd_stage1_m2 = 0.001
epochs_stage1_m2 = 50
patience_stage1_m2 = 15
model2_sequential = train_x_to_c_sequential_combined(model2_sequential, device, train_loader, val_loader, 
                                                     binary_concept_idx, lr_stage1_m2, wd_stage1_m2, 
                                                     epochs_stage1_m2, patience_stage1_m2)

# Stage 2 for model2
lr_stage2_m2 = 0.1
wd_stage2_m2 = 0.001
epochs_stage2_m2 = 50
patience_stage2_m2 = 15
model2_sequential = train_c_to_y_sequential_combined(model2_sequential, device, train_loader, val_loader, 
                                                     lr_stage2_m2, wd_stage2_m2, epochs_stage2_m2, patience_stage2_m2)

print("MultiLabelNN2 (Context-Aware CBM) trained sequentially.")
# You can now evaluate model2_sequential using your test_combined_model function
torch.save(model1_sequential.state_dict(), "cbm_sequential.pt")
torch.save(model2_sequential.state_dict(), "llm_cbm_sequential.pt")

--- Stage 1: Training x -> c (Concepts) ---
Epoch 1/50 [x->c] - Train Loss: 0.6065, Val Loss: 0.4235
Validation score decreased (inf --> 0.423480).  Saving model ...
Epoch 2/50 [x->c] - Train Loss: 0.3593, Val Loss: 0.3135
Validation score decreased (0.423480 --> 0.313530).  Saving model ...
Epoch 3/50 [x->c] - Train Loss: 0.2949, Val Loss: 0.2807
Validation score decreased (0.313530 --> 0.280692).  Saving model ...
Epoch 4/50 [x->c] - Train Loss: 0.2722, Val Loss: 0.2715
Validation score decreased (0.280692 --> 0.271479).  Saving model ...
Epoch 5/50 [x->c] - Train Loss: 0.2602, Val Loss: 0.2561
Validation score decreased (0.271479 --> 0.256062).  Saving model ...
Epoch 6/50 [x->c] - Train Loss: 0.2544, Val Loss: 0.2557
Validation score decreased (0.256062 --> 0.255698).  Saving model ...
Epoch 7/50 [x->c] - Train Loss: 0.2534, Val Loss: 0.2500
Validation score decreased (0.255698 --> 0.249955).  Saving model ...
Epoch 8/50 [x->c] - Train Loss: 0.2486, Val Loss: 0.2473
Validation scor

In [14]:
print("--- Evaluating Sequentially Trained MultiLabelNN1 (Standard CBM) ---")
# Use the existing test_model function (defined in Cell 6 of your notebook)
c_true_seq, y_true_seq, binary_c_test_pred_seq, continuous_c_test_pred_seq, y_pred_seq = test_model(
    model1_sequential, 
    test_loader, 
    binary_concept_idx
)

print("\n--- Evaluating Sequentially Trained MultiLabelNN2 (Context-Aware CBM) ---")
# Use the existing test_combined_model function (defined in Cell 6 of your notebook)
c_true_llm_seq, y_true_llm_seq, binary_c_test_pred_llm_seq, continuous_c_test_pred_llm_seq, y_pred_llm_seq = test_combined_model(
    model2_sequential, 
    test_loader, 
    binary_concept_idx
)

# Get concept labels (same as before)
concept_labels = processor.get_vanilla_concepts()

print("\n--- Concept Evaluation for Sequentially Trained MultiLabelNN1 (Standard CBM) ---")
# Use the existing evaluate_concept_predictor function (defined in Cell 6 of your notebook)
vanilla_c_results_seq = evaluate_concept_predictor(
    ground_truth_c=c_true_seq, 
    binary_predictions=binary_c_test_pred_seq, 
    continuous_predictions=continuous_c_test_pred_seq, 
    concept_labels=concept_labels, 
    binary_concept_idx=binary_concept_idx
)

print("\n--- Concept Evaluation for Sequentially Trained MultiLabelNN2 (Context-Aware CBM) ---")
# Use the existing evaluate_concept_predictor function (defined in Cell 6 of your notebook)
llm_c_results_seq = evaluate_concept_predictor(
    ground_truth_c=c_true_llm_seq, 
    binary_predictions=binary_c_test_pred_llm_seq, 
    continuous_predictions=continuous_c_test_pred_llm_seq, 
    concept_labels=concept_labels, 
    binary_concept_idx=binary_concept_idx
)

print("\n--- Label Prediction Evaluation for Sequentially Trained MultiLabelNN1 (Standard CBM) ---")
# Use the existing evaluate_label_predictor function (defined in Cell 6 of your notebook)
label_results_model1_seq_df = evaluate_label_predictor(y_true_seq, y_pred_seq)
print(label_results_model1_seq_df)

print("\n--- Label Prediction Evaluation for Sequentially Trained MultiLabelNN2 (Context-Aware CBM) ---")
# Use the existing evaluate_label_predictor function (defined in Cell 6 of your notebook)
label_results_model2_seq_df = evaluate_label_predictor(y_true_llm_seq, y_pred_llm_seq)
print(label_results_model2_seq_df)

--- Evaluating Sequentially Trained MultiLabelNN1 (Standard CBM) ---

--- Evaluating Sequentially Trained MultiLabelNN2 (Context-Aware CBM) ---

--- Concept Evaluation for Sequentially Trained MultiLabelNN1 (Standard CBM) ---
{'Label': 'c_sofa_avg_cardiovascular', 'MSE': 0.098, 'MAE': 0.165, 'RMSE': 0.314, 'R2': -0.556}
{'Label': 'c_sofa_avg_respiration', 'MSE': 0.007, 'MAE': 0.064, 'RMSE': 0.085, 'R2': 0.724}
{'Label': 'c_sofa_avg_renal', 'MSE': 0.031, 'MAE': 0.121, 'RMSE': 0.175, 'R2': 0.567}
{'Label': 'c_sofa_avg_cns', 'MSE': 0.005, 'MAE': 0.046, 'RMSE': 0.072, 'R2': -0.07}
{'Label': 'c_first24hr_sofa_max_cardiovascular', 'MSE': 0.185, 'MAE': 0.324, 'RMSE': 0.43, 'R2': -0.342}
{'Label': 'c_first24hr_sofa_max_respiration', 'MSE': 0.115, 'MAE': 0.305, 'RMSE': 0.339, 'R2': -0.161}
{'Label': 'c_first24hr_sofa_max_renal', 'MSE': 0.055, 'MAE': 0.182, 'RMSE': 0.235, 'R2': 0.493}
{'Label': 'c_first24hr_sofa_max_cns', 'MSE': 0.083, 'MAE': 0.221, 'RMSE': 0.287, 'R2': 0.148}
{'Label': 'c_sofa_

In [15]:
Vanilla
{'Label': 'c_sofa_avg_cardiovascular', 'MSE': 0.271, 'MAE': 0.19, 'RMSE': 0.521, 'R2': -3.288}
{'Label': 'c_sofa_avg_respiration', 'MSE': 0.129, 'MAE': 0.271, 'RMSE': 0.359, 'R2': -3.911}
{'Label': 'c_sofa_avg_renal', 'MSE': 1.864, 'MAE': 0.745, 'RMSE': 1.365, 'R2': -25.232}
{'Label': 'c_sofa_avg_cns', 'MSE': 0.02, 'MAE': 0.104, 'RMSE': 0.14, 'R2': -3.042}
{'Label': 'c_first24hr_sofa_max_cardiovascular', 'MSE': 0.421, 'MAE': 0.336, 'RMSE': 0.649, 'R2': -2.045}
{'Label': 'c_first24hr_sofa_max_respiration', 'MSE': 0.16, 'MAE': 0.324, 'RMSE': 0.4, 'R2': -0.61}
{'Label': 'c_first24hr_sofa_max_renal', 'MSE': 0.112, 'MAE': 0.275, 'RMSE': 0.335, 'R2': -0.031}
{'Label': 'c_first24hr_sofa_max_cns', 'MSE': 0.171, 'MAE': 0.242, 'RMSE': 0.413, 'R2': -0.763}
{'Label': 'c_sofa_max_cardiovascular', 'MSE': 0.653, 'MAE': 0.282, 'RMSE': 0.808, 'R2': -4.729}
{'Label': 'c_sofa_max_respiration', 'MSE': 0.043, 'MAE': 0.151, 'RMSE': 0.206, 'R2': 0.511}
{'Label': 'c_sofa_max_renal', 'MSE': 0.26, 'MAE': 0.405, 'RMSE': 0.51, 'R2': -0.696}
{'Label': 'c_sofa_max_cns', 'MSE': 0.026, 'MAE': 0.129, 'RMSE': 0.161, 'R2': 0.791}
{'Label': 'c_svr_resp_comorbidity', 'Precision': 0.969, 'Recall': 0.95, 'F1 Score': 0.96, 'Accuracy': 0.98}
{'Label': 'c_mod_resp_comorbidity', 'Precision': 0.771, 'Recall': 0.955, 'F1 Score': 0.853, 'Accuracy': 0.777}
LLM
{'Label': 'c_sofa_avg_cardiovascular', 'MSE': 0.261, 'MAE': 0.174, 'RMSE': 0.511, 'R2': -3.129}
{'Label': 'c_sofa_avg_respiration', 'MSE': 0.009, 'MAE': 0.073, 'RMSE': 0.094, 'R2': 0.665}
{'Label': 'c_sofa_avg_renal', 'MSE': 0.052, 'MAE': 0.159, 'RMSE': 0.228, 'R2': 0.268}
{'Label': 'c_sofa_avg_cns', 'MSE': 0.005, 'MAE': 0.051, 'RMSE': 0.067, 'R2': 0.069}
{'Label': 'c_first24hr_sofa_max_cardiovascular', 'MSE': 0.434, 'MAE': 0.347, 'RMSE': 0.659, 'R2': -2.145}
{'Label': 'c_first24hr_sofa_max_respiration', 'MSE': 0.11, 'MAE': 0.269, 'RMSE': 0.332, 'R2': -0.111}
{'Label': 'c_first24hr_sofa_max_renal', 'MSE': 0.648, 'MAE': 0.393, 'RMSE': 0.805, 'R2': -4.941}
{'Label': 'c_first24hr_sofa_max_cns', 'MSE': 0.096, 'MAE': 0.256, 'RMSE': 0.31, 'R2': 0.007}
{'Label': 'c_sofa_max_cardiovascular', 'MSE': 0.55, 'MAE': 0.345, 'RMSE': 0.742, 'R2': -3.827}
{'Label': 'c_sofa_max_respiration', 'MSE': 0.068, 'MAE': 0.141, 'RMSE': 0.261, 'R2': 0.216}
{'Label': 'c_sofa_max_renal', 'MSE': 0.171, 'MAE': 0.215, 'RMSE': 0.414, 'R2': -0.117}
{'Label': 'c_sofa_max_cns', 'MSE': 0.038, 'MAE': 0.139, 'RMSE': 0.195, 'R2': 0.693}
{'Label': 'c_svr_resp_comorbidity', 'Precision': 0.97, 'Recall': 0.96, 'F1 Score': 0.965, 'Accuracy': 0.982}
{'Label': 'c_mod_resp_comorbidity', 'Precision': 0.894, 'Recall': 0.955, 'F1 Score': 0.923, 'Accuracy': 0.893}
         Precision    Recall  F1 Score       AUC  Accuracy
Metrics   0.740113  0.629808  0.680519  0.689221  0.685422
         Precision    Recall  F1 Score       AUC  Accuracy
Metrics   0.771784  0.894231  0.828508  0.796842  0.803069

IndentationError: unexpected indent (3480086513.py, line 31)

In [16]:
# --- Get predictions from sequentially trained model1_sequential ---
print("Getting predictions from sequentially trained model1_sequential...")
c_true_m1_seq, y_true_m1_seq, binary_c_test_pred_m1_seq, cont_c_test_pred_m1_seq, y_pred_m1_seq_raw = test_model(
    model1_sequential, 
    test_loader, 
    binary_concept_idx
)

# Concatenate batch predictions for model1_sequential
binary_c_pred_m1_seq_full = np.concatenate(binary_c_test_pred_m1_seq, axis=0)
continuous_c_pred_m1_seq_full = np.concatenate(cont_c_test_pred_m1_seq, axis=0)
c_pred_m1_seq = np.concatenate([continuous_c_pred_m1_seq_full, binary_c_pred_m1_seq_full], axis=1) # Order: continuous then binary for c_pred
# Note: Ensure the order matches how test_concepts is structured (continuous then binary).
# Based on your MIMICDataProcessor: self.C_tensor_test was np.concatenate([C_test_scaled, C_bin_test.values.astype(float)], axis=1)
# And binary_concept_idx marks the start of binary concepts.
# The test_model returns binary_c_pred and continuous_c_pred separately.
# If test_concepts is [C_cont_test, C_bin_test], then:
# c_pred_m1_seq needs to be [continuous_c_pred_m1_seq_full, binary_c_pred_m1_seq_full]
# However, your original c_pred (Cell 25) was [binary_c_pred, continuous_c_pred].
# Let's stick to your original notebook's convention for c_pred: [binary, continuous]
# The evaluate_concept_predictor expects this based on its indexing: `c[:, i - 12]` for binary.
# This means binary concepts (2) come first, then continuous (12).
# So, when using in evaluate_concept_predictor or interventions, ensure indices are correct.
# For consistency with original intervention code (e.g. b = b[-2:] for binary), we'll assume c_pred is [continuous, binary]
# And binary_concept_idx will be [12, 13] relative to the 14 vanilla concepts.

# Let's re-verify the order of c_pred in your original notebook (Cell 25):
# c_pred = np.concatenate([binary_c_pred, continuous_c_pred], axis=1)
# This means binary concepts are indeed first in c_pred.
# test_model returns binary_c_pred, continuous_c_pred.
# So c_pred_m1_seq should be:
c_pred_m1_seq = np.concatenate([binary_c_pred_m1_seq_full, continuous_c_pred_m1_seq_full], axis=1)


y_pred_final_m1_seq = np.concatenate(y_pred_m1_seq_raw, axis=0).reshape(-1)
y_true_final_m1_seq = np.concatenate(y_true_m1_seq, axis=0).reshape(-1)

# --- Get predictions from sequentially trained model2_sequential ---
print("Getting predictions from sequentially trained model2_sequential...")
c_true_m2_seq, y_true_m2_seq, binary_c_test_pred_m2_seq, cont_c_test_pred_m2_seq, y_pred_m2_seq_raw = test_combined_model(
    model2_sequential, 
    test_loader, 
    binary_concept_idx
)

# Concatenate batch predictions for model2_sequential
binary_c_pred_m2_seq_full = np.concatenate(binary_c_test_pred_m2_seq, axis=0)
continuous_c_pred_m2_seq_full = np.concatenate(cont_c_test_pred_m2_seq, axis=0)
# Following original notebook's c_pred_llm structure:
c_pred_m2_seq = np.concatenate([binary_c_pred_m2_seq_full, continuous_c_pred_m2_seq_full], axis=1)

y_pred_final_m2_seq = np.concatenate(y_pred_m2_seq_raw, axis=0).reshape(-1)
y_true_final_m2_seq = np.concatenate(y_true_m2_seq, axis=0).reshape(-1)


# --- Define FP/FN indices for sequentially trained models ---
threshold = 0.5

# For model1_sequential
predicted_labels_m1_seq = (y_pred_final_m1_seq >= threshold).astype(int)
fp_idx_m1_seq = np.where((y_true_final_m1_seq == 0) & (predicted_labels_m1_seq == 1))[0]
fn_idx_m1_seq = np.where((y_true_final_m1_seq == 1) & (predicted_labels_m1_seq == 0))[0]

# For model2_sequential
predicted_labels_m2_seq = (y_pred_final_m2_seq >= threshold).astype(int)
fp_idx_m2_seq = np.where((y_true_final_m2_seq == 0) & (predicted_labels_m2_seq == 1))[0]
fn_idx_m2_seq = np.where((y_true_final_m2_seq == 1) & (predicted_labels_m2_seq == 0))[0]

# --- Prepare test_features, test_concepts, test_llm_concepts (as in Cell 26 of original) ---
# This part is independent of model training and should be the same
_test_features_batches = []
_test_concepts_batches = [] # Ground truth vanilla concepts
_test_llm_concepts_batches = [] # Ground truth LLM concepts

for data_batch in test_loader:
    _test_features_batches.append(data_batch[0].cpu().numpy())
    _test_concepts_batches.append(data_batch[1].cpu().numpy()) # This is C_tensor_test
    _test_llm_concepts_batches.append(data_batch[2].cpu().numpy()) # This is LLM_C_tensor_test

test_features = np.concatenate(_test_features_batches, axis=0)
test_concepts = np.concatenate(_test_concepts_batches, axis=0) # Ground truth vanilla concepts C_tensor_test
test_llm_concepts = np.concatenate(_test_llm_concepts_batches, axis=0) # Ground truth LLM concepts

# Confirm structure of test_concepts. From MIMICDataProcessor:
# C_tensor_test = torch.tensor(C_test_full, dtype=torch.float32)
# C_test_full = np.concatenate([C_test_scaled, C_bin_test.values.astype(float)], axis=1)
# This means continuous concepts (12) are first, then binary concepts (2) in test_concepts.
# binary_concept_idx correctly identifies indices [12, 13] in test_concepts.

# IMPORTANT: The c_pred_m1_seq and c_pred_m2_seq were created as [binary_pred, continuous_pred].
# The intervention code in the original notebook that uses c_pred (e.g. `b = c_pred[a].copy(); b = b[-2:]`)
# implies that for those specific loops, it expects binary concepts to be at the end of the slice.
# Let's adjust c_pred_m1_seq and c_pred_m2_seq to be [continuous_pred, binary_pred]
# to match the structure of test_concepts (ground truth) for easier indexing in interventions.
# This also means the evaluate_concept_predictor indexing `i - 12` for binary predictions will be correct
# if binary_predictions (the argument) contains only binary concepts and binary_concept_idx is [12,13].

# Re-evaluate c_pred structure for intervention consistency:
# Original Int1 intervention code: `b = c_pred[a].copy(); b = b[-2:]` implies binary is at the end.
# test_model and test_combined_model return: binary_c_pred, continuous_c_pred.
# If c_pred is [binary_c_pred_full, continuous_c_pred_full], then b[-2:] is wrong.
# If c_pred is [continuous_c_pred_full, binary_c_pred_full], then b[-2:] is correct for binary intervention.

# Let's define c_pred_m1_seq and c_pred_m2_seq to match the structure of test_concepts:
# [continuous_pred_concepts (12), binary_pred_concepts (2)]
c_pred_m1_seq = np.concatenate([continuous_c_pred_m1_seq_full, binary_c_pred_m1_seq_full], axis=1)
c_pred_m2_seq = np.concatenate([continuous_c_pred_m2_seq_full, binary_c_pred_m2_seq_full], axis=1)

# Now, when intervening on c_pred_m1_seq or c_pred_m2_seq:
# - Continuous concepts are at indices 0-11.
# - Binary concepts are at indices 12-13.
# The original intervention code (e.g., Cell 28) `b = b[-2:]` for binary intervention on `c_pred` and
# `b = b[:-2]` for continuous intervention on `c_pred` will now work directly with this structure.
# The `intervention_concept_idxs` like `[12,13]` or `[0,1,...,11]` will correctly map.

print("Test predictions and FP/FN indices for sequentially trained models are ready.")

Getting predictions from sequentially trained model1_sequential...
Getting predictions from sequentially trained model2_sequential...
Test predictions and FP/FN indices for sequentially trained models are ready.


In [17]:
from copy import deepcopy # Already imported in your notebook

# --- Intervention Models Definition (Adapted from Cells 27, 31, 34, 47) ---
# Ensure these are defined with num_features = x_size (e.g., 21)

class Int1(nn.Module): # For base CBM, intervening on binary concepts
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_labels):
        super(Int1, self).__init__()
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False) # Will be frozen or unused
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        
    def forward(self, x, binary_int): # x is for layer2 to get continuous_c
        # binary_c = self.layer1(x) # Original concept prediction is bypassed
        continuous_c = self.layer2(x)
        # binary_int is already sigmoid-activated if it's a prediction, or 0/1 if true
        y_pred = torch.sigmoid(self.layer3(binary_int) + self.layer4(continuous_c))
        return y_pred

class Int2(nn.Module): # For base CBM, intervening on continuous concepts
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_labels):
        super(Int2, self).__init__()
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False) # Will be frozen or unused
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        
    def forward(self, x, continuous_int): # x is for layer1 to get binary_c
        binary_c = self.layer1(x)
        # continuous_c = self.layer2(x) # Original concept prediction is bypassed
        binary_c = torch.sigmoid(binary_c)
        y_pred = torch.sigmoid(self.layer3(binary_c) + self.layer4(continuous_int))
        return y_pred

class Int3(nn.Module): # For enhanced CBM, intervening on binary concepts
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels):
        super(Int3, self).__init__()
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        self.layer5 = nn.Linear(num_llm_concepts, num_labels, bias=False)

    def forward(self, x, binary_int, llm_c):
        continuous_c = self.layer2(x)
        y_pred = torch.sigmoid(self.layer3(binary_int) + self.layer4(continuous_c) + self.layer5(llm_c))
        return y_pred

class Int4(nn.Module): # For enhanced CBM, intervening on continuous concepts
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels):
        super(Int4, self).__init__()
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        self.layer5 = nn.Linear(num_llm_concepts, num_labels, bias=False)

    def forward(self, x, continuous_int, llm_c):
        binary_c = torch.sigmoid(self.layer1(x))
        y_pred = torch.sigmoid(self.layer3(binary_c) + self.layer4(continuous_int) + self.layer5(llm_c))
        return y_pred

class Int5(nn.Module): # For enhanced CBM, intervening on LLM concepts
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels):
        super(Int5, self).__init__()
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False)
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        self.layer5 = nn.Linear(num_llm_concepts, num_labels, bias=False)

    def forward(self, x, llm_int):
        binary_c = torch.sigmoid(self.layer1(x))
        continuous_c = self.layer2(x)
        y_pred = torch.sigmoid(self.layer3(binary_c) + self.layer4(continuous_c) + self.layer5(llm_int))
        return y_pred
        
class Int6(nn.Module): # For enhanced CBM, intervening on all concept types
    def __init__(self, num_features, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels):
        super(Int6, self).__init__()
        # These layers will effectively be from the c->y part of the model
        self.layer1 = nn.Linear(num_features, num_binary_concepts, bias=False)    # Not directly used for x->c_bin if binary_int provided
        self.layer2 = nn.Linear(num_features, num_continuous_concepts, bias=False) # Not directly used for x->c_cont if continuous_int provided
        self.layer3 = nn.Linear(num_binary_concepts, num_labels, bias=False)
        self.layer4 = nn.Linear(num_continuous_concepts, num_labels, bias=False)
        self.layer5 = nn.Linear(num_llm_concepts, num_labels, bias=False)

    def forward(self, x, binary_int, continuous_int, llm_int): # x is technically not needed if all concepts are intervened
        # binary_c = self.layer1(x) # Bypassed
        # continuous_c = self.layer2(x) # Bypassed
        y_pred = torch.sigmoid(self.layer3(binary_int) + self.layer4(continuous_int) + self.layer5(llm_int))
        return y_pred

# --- Instantiate and Load Weights for Intervention Models (Sequential) ---
x_size = processor.X_tensor_scaled_train.shape[1] # Should be 21
num_binary_concepts = processor.C_bin.shape[1]    # Should be 2
num_continuous_concepts = processor.C_cont.shape[1] # Should be 12
num_llm_concepts = processor.LLM_C.shape[1]       # Should be 8
num_labels = 1

# For model1_sequential (base CBM)
model_int1_seq = Int1(x_size, num_binary_concepts, num_continuous_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int1_seq.parameters(), model1_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int1_seq.eval()

model_int2_seq = Int2(x_size, num_binary_concepts, num_continuous_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int2_seq.parameters(), model1_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int2_seq.eval()

# For model2_sequential (enhanced CBM)
model_int3_seq = Int3(x_size, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int3_seq.parameters(), model2_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int3_seq.eval()

model_int4_seq = Int4(x_size, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int4_seq.parameters(), model2_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int4_seq.eval()

model_int5_seq = Int5(x_size, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int5_seq.parameters(), model2_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int5_seq.eval()

model_int6_seq = Int6(x_size, num_binary_concepts, num_continuous_concepts, num_llm_concepts, num_labels).to(device)
for p_new, p_orig in zip(model_int6_seq.parameters(), model2_sequential.parameters()):
    p_new.data = deepcopy(p_orig.data)
model_int6_seq.eval()

print("Intervention models initialized and weights copied from sequentially trained models.")

Intervention models initialized and weights copied from sequentially trained models.


In [20]:
import numpy as np
import torch

# Ensure binary_concept_idx refers to indices within the 14 vanilla concepts.
# If C_cont has 12 features, binary_concept_idx = [12, 13]
# num_continuous_concepts = processor.C_cont.shape[1] # Should be 12
# binary_intervention_indices_in_vanilla = list(range(num_continuous_concepts, num_continuous_concepts + num_binary_concepts))
# continuous_intervention_indices_in_vanilla = list(range(num_continuous_concepts))

# For clarity, let's define these based on your processor
num_total_vanilla_concepts = processor.C_tensor_train.shape[1] # Should be 14
continuous_intervention_indices_in_vanilla = list(range(processor.C_cont.shape[1])) # Indices 0-11
binary_intervention_indices_in_vanilla = list(range(processor.C_cont.shape[1], num_total_vanilla_concepts)) # Indices 12-13

# --- Interventions on Base CBM (model1_sequential) ---

print("\n###########################################################################")
print("--- Interventions on Base CBM (model1_sequential) ---")
print("###########################################################################")

# --- Binary Concepts Interventions for model1_sequential (Adapted from Cells 28, 29, 30) ---

# Intervention Type: Replacement with True Values
print("\n--- Binary Concepts: Replacement with True Values (model1_sequential) ---")
for int_idx_in_vanilla in binary_intervention_indices_in_vanilla:
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# All binary concepts together with true values
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for bc_idx in binary_intervention_indices_in_vanilla:
        intervened_vanilla_concepts[bc_idx] = test_concepts[a][bc_idx]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: All Binary Concepts (True): Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for bc_idx in binary_intervention_indices_in_vanilla:
        intervened_vanilla_concepts[bc_idx] = test_concepts[a][bc_idx]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() < 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: All Binary Concepts (True): Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Mean Values
print("\n--- Binary Concepts: Replacement with Mean Values (model1_sequential) ---")
mean_true_binary_concepts = np.mean(test_concepts[:, binary_intervention_indices_in_vanilla], axis=0)
for i, int_idx_in_vanilla in enumerate(binary_intervention_indices_in_vanilla):
    mean_val_for_concept = mean_true_binary_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")
    
    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# All binary concepts together with mean values
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for i, bc_idx in enumerate(binary_intervention_indices_in_vanilla):
        intervened_vanilla_concepts[bc_idx] = mean_true_binary_concepts[i]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: All Binary Concepts (Mean): Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for i, bc_idx in enumerate(binary_intervention_indices_in_vanilla):
        intervened_vanilla_concepts[bc_idx] = mean_true_binary_concepts[i]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() < 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: All Binary Concepts (Mean): Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Median Values
print("\n--- Binary Concepts: Replacement with Median Values (model1_sequential) ---")
median_true_binary_concepts = np.median(test_concepts[:, binary_intervention_indices_in_vanilla], axis=0)
for i, int_idx_in_vanilla in enumerate(binary_intervention_indices_in_vanilla):
    median_val_for_concept = median_true_binary_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")
    
# All binary concepts together with median values
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for i, bc_idx in enumerate(binary_intervention_indices_in_vanilla):
        intervened_vanilla_concepts[bc_idx] = median_true_binary_concepts[i]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: All Binary Concepts (Median): Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for i, bc_idx in enumerate(binary_intervention_indices_in_vanilla):
        intervened_vanilla_concepts[bc_idx] = median_true_binary_concepts[i]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int1_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device))
    if p_int.item() < 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: All Binary Concepts (Median): Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# --- Continuous Concepts Interventions for model1_sequential (Adapted from Cells 31, 32, 33) ---
# Using model_int2_seq

# Intervention Type: Replacement with True Values
print("\n--- Continuous Concepts: Replacement with True Values (model1_sequential) ---")
for int_idx_in_vanilla in continuous_intervention_indices_in_vanilla: # Indices 0-11
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# All continuous concepts together with true values (Example for selected indices from your original notebook)
selected_continuous_for_fn = [2, 3, 5, 10] # Example indices from your original cell 31
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for cc_idx in selected_continuous_for_fn:
        if cc_idx in continuous_intervention_indices_in_vanilla:
             intervened_vanilla_concepts[cc_idx] = test_concepts[a][cc_idx]
    continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([continuous_slice_for_int_model]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: Selected Continuous Concepts {selected_continuous_for_fn} (True): Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

selected_continuous_for_fp = [1, 2, 11] # Example indices
count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m1_seq:
    intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
    for cc_idx in selected_continuous_for_fp:
        if cc_idx in continuous_intervention_indices_in_vanilla:
            intervened_vanilla_concepts[cc_idx] = test_concepts[a][cc_idx]
    continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([continuous_slice_for_int_model]).float().to(device))
    if p_int.item() < 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: Selected Continuous Concepts {selected_continuous_for_fp} (True): Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Mean Values
print("\n--- Continuous Concepts: Replacement with Mean Values (model1_sequential) ---")
mean_true_continuous_concepts = np.mean(test_concepts[:, continuous_intervention_indices_in_vanilla], axis=0)
for i, int_idx_in_vanilla in enumerate(continuous_intervention_indices_in_vanilla):
    mean_val_for_concept = mean_true_continuous_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# Intervention Type: Replacement with Median Values
print("\n--- Continuous Concepts: Replacement with Median Values (model1_sequential) ---")
median_true_continuous_concepts = np.median(test_concepts[:, continuous_intervention_indices_in_vanilla], axis=0)
for i, int_idx_in_vanilla in enumerate(continuous_intervention_indices_in_vanilla):
    median_val_for_concept = median_true_continuous_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m1_seq:
        intervened_vanilla_concepts = c_pred_m1_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int2_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device))
        if p_int.item() < 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


print("\n###########################################################################")
print("--- Interventions on Enhanced CBM (model2_sequential) ---")
print("###########################################################################")

# --- Binary Concepts Interventions for model2_sequential (Adapted from Cells 35, 36, 37) ---
# Using model_int3_seq
# FN/FP indices are fn_idx_m2_seq, fp_idx_m2_seq
# Predicted concepts c_pred_m2_seq
# True concepts test_concepts, test_llm_concepts

# Intervention Type: Replacement with True Values
print("\n--- Binary Concepts: Replacement with True Values (model2_sequential) ---")
for int_idx_in_vanilla in binary_intervention_indices_in_vanilla:
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")
    
    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1 # Note: original notebook had <=0.5 for FP correction here
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# All binary concepts together with true values
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m2_seq:
    intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
    for bc_idx in binary_intervention_indices_in_vanilla:
        intervened_vanilla_concepts[bc_idx] = test_concepts[a][bc_idx]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device),
                               torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: All Binary Concepts (True): Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m2_seq:
    intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
    for bc_idx in binary_intervention_indices_in_vanilla:
        intervened_vanilla_concepts[bc_idx] = test_concepts[a][bc_idx]
    binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
    with torch.no_grad():
        p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([binary_slice_for_int_model]).float().to(device),
                               torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
    if p_int.item() <= 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: All Binary Concepts (True): Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Mean Values
print("\n--- Binary Concepts: Replacement with Mean Values (model2_sequential) ---")
# mean_true_binary_concepts is already calculated
for i, int_idx_in_vanilla in enumerate(binary_intervention_indices_in_vanilla):
    mean_val_for_concept = mean_true_binary_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Median Values
print("\n--- Binary Concepts: Replacement with Median Values (model2_sequential) ---")
# median_true_binary_concepts is already calculated
for i, int_idx_in_vanilla in enumerate(binary_intervention_indices_in_vanilla):
    median_val_for_concept = median_true_binary_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        binary_slice_for_int_model = intervened_vanilla_concepts[binary_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int3_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([binary_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# --- Continuous Concepts Interventions for model2_sequential (Adapted from Cells 38, 39, 40) ---
# Using model_int4_seq

# Intervention Type: Replacement with True Values
print("\n--- Continuous Concepts: Replacement with True Values (model2_sequential) ---")
for int_idx_in_vanilla in continuous_intervention_indices_in_vanilla: # Indices 0-11
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = test_concepts[a][int_idx_in_vanilla]
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla}: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# Intervention Type: Replacement with Mean Values
print("\n--- Continuous Concepts: Replacement with Mean Values (model2_sequential) ---")
# mean_true_continuous_concepts is already calculated
for i, int_idx_in_vanilla in enumerate(continuous_intervention_indices_in_vanilla):
    mean_val_for_concept = mean_true_continuous_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = mean_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Mean: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# Intervention Type: Replacement with Median Values
print("\n--- Continuous Concepts: Replacement with Median Values (model2_sequential) ---")
# median_true_continuous_concepts is already calculated
for i, int_idx_in_vanilla in enumerate(continuous_intervention_indices_in_vanilla):
    median_val_for_concept = median_true_continuous_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_vanilla_concepts = c_pred_m2_seq[a].copy()
        intervened_vanilla_concepts[int_idx_in_vanilla] = median_val_for_concept
        continuous_slice_for_int_model = intervened_vanilla_concepts[continuous_intervention_indices_in_vanilla]
        with torch.no_grad():
            p_int = model_int4_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([continuous_slice_for_int_model]).float().to(device),
                                   torch.tensor(test_llm_concepts[a:a+1]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on vanilla_concept_idx {int_idx_in_vanilla} with Median: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# --- LLM Concepts Interventions for model2_sequential (Adapted from Cells 43, 45, 46) ---
# Using model_int5_seq
llm_concept_indices = list(range(num_llm_concepts)) # Indices 0-7 for LLM concepts

# Intervention Type: Manual Changes (1.0 for FN, 0.0 for FP)
print("\n--- LLM Concepts: Manual Changes (1.0 for FN, 0.0 for FP) (model2_sequential) ---")
for llm_idx in llm_concept_indices:
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy() # Start with true LLM concepts
        intervened_llm[llm_idx] = 1.0 # Intervene one LLM concept to 1.0
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on LLM_concept_idx {llm_idx} to 1.0: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy()
        intervened_llm[llm_idx] = 0.0 # Intervene one LLM concept to 0.0
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on LLM_concept_idx {llm_idx} to 0.0: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# Example: All specified LLM concepts for FNs changed to 1.0 (from your original Cell 43)
fn_llm_indices_to_one = [0, 1, 5] # Example from your original notebook
count_fn_corrected, total_fn = 0, 0
for a in fn_idx_m2_seq:
    intervened_llm = test_llm_concepts[a].copy()
    for idx_to_one in fn_llm_indices_to_one:
        intervened_llm[idx_to_one] = 1.0
    with torch.no_grad():
        p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([intervened_llm]).float().to(device))
    if p_int.item() >= 0.5: count_fn_corrected += 1
    total_fn += 1
print(f"FNs: LLM concepts {fn_llm_indices_to_one} to 1.0: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

# Example: All specified LLM concepts for FPs changed to 0.0 (from your original Cell 43)
fp_llm_indices_to_zero = [0, 2, 6] # Example from your original notebook
count_fp_corrected, total_fp = 0, 0
for a in fp_idx_m2_seq:
    intervened_llm = test_llm_concepts[a].copy()
    for idx_to_zero in fp_llm_indices_to_zero:
        intervened_llm[idx_to_zero] = 0.0
    with torch.no_grad():
        p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                               torch.tensor([intervened_llm]).float().to(device))
    if p_int.item() <= 0.5: count_fp_corrected += 1
    total_fp += 1
print(f"FPs: LLM concepts {fp_llm_indices_to_zero} to 0.0: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Mean Values
print("\n--- LLM Concepts: Replacement with Mean Values (model2_sequential) ---")
mean_true_llm_concepts = np.mean(test_llm_concepts, axis=0)
for i, llm_idx in enumerate(llm_concept_indices):
    mean_val_for_llm_concept = mean_true_llm_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy() # Start with true LLM concepts
        intervened_llm[llm_idx] = mean_val_for_llm_concept
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on LLM_concept_idx {llm_idx} with Mean: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy()
        intervened_llm[llm_idx] = mean_val_for_llm_concept
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on LLM_concept_idx {llm_idx} with Mean: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")


# Intervention Type: Replacement with Median Values
print("\n--- LLM Concepts: Replacement with Median Values (model2_sequential) ---")
median_true_llm_concepts = np.median(test_llm_concepts, axis=0)
for i, llm_idx in enumerate(llm_concept_indices):
    median_val_for_llm_concept = median_true_llm_concepts[i]
    count_fn_corrected, total_fn = 0, 0
    for a in fn_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy()
        intervened_llm[llm_idx] = median_val_for_llm_concept
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() >= 0.5: count_fn_corrected += 1
        total_fn += 1
    print(f"FNs: Intervened on LLM_concept_idx {llm_idx} with Median: Corrections: {count_fn_corrected/total_fn if total_fn > 0 else 'N/A'} ({count_fn_corrected}/{total_fn})")

    count_fp_corrected, total_fp = 0, 0
    for a in fp_idx_m2_seq:
        intervened_llm = test_llm_concepts[a].copy()
        intervened_llm[llm_idx] = median_val_for_llm_concept
        with torch.no_grad():
            p_int = model_int5_seq(torch.tensor(test_features[a:a+1]).float().to(device),
                                   torch.tensor([intervened_llm]).float().to(device))
        if p_int.item() <= 0.5: count_fp_corrected += 1
        total_fp += 1
    print(f"FPs: Intervened on LLM_concept_idx {llm_idx} with Median: Corrections: {count_fp_corrected/total_fp if total_fp > 0 else 'N/A'} ({count_fp_corrected}/{total_fp})")

# --- Interventions on all three kinds of concepts simultaneously (Adapted from Cell 48, 49, 50) ---
# Using model_int6_seq

# Example: Best of ground truth replacements (FNs)
print("\n--- All Concepts (GT): Best of Ground Truth Replacements for FNs (model2_sequential) ---")
count_fn_corr_all_gt, total_fn_all_gt = 0, 0
for a in fn_idx_m2_seq:
    # Vanilla concepts part - from c_pred_m2_seq, intervene select with true
    intervened_vanilla = c_pred_m2_seq[a].copy()
    intervened_vanilla[12] = test_concepts[a][12] # Example: binary concept 0 (idx 12)
    intervened_vanilla[7] = test_concepts[a][7]   # Example: continuous concept 7 (idx 7)
    
    binary_slice = intervened_vanilla[binary_intervention_indices_in_vanilla]
    continuous_slice = intervened_vanilla[continuous_intervention_indices_in_vanilla]

    # LLM concepts part - from test_llm_concepts, intervene select manually
    intervened_llm = test_llm_concepts[a].copy()
    intervened_llm[0] = 1.0 # Example: LLM concept 0 to 1.0
    intervened_llm[1] = 1.0 # Example: LLM concept 1 to 1.0
    intervened_llm[5] = 1.0 # Example: LLM concept 5 to 1.0

    with torch.no_grad():
        p_int = model_int6_seq(
            torch.tensor(test_features[a:a+1]).float().to(device),
            torch.tensor([binary_slice]).float().to(device),
            torch.tensor([continuous_slice]).float().to(device),
            torch.tensor([intervened_llm]).float().to(device)
        )
    if p_int.item() >= 0.5: count_fn_corr_all_gt += 1
    total_fn_all_gt += 1
print(f"FNs: All Concepts (Selected GT & Manual LLM): Corrections: {count_fn_corr_all_gt/total_fn_all_gt if total_fn_all_gt > 0 else 'N/A'} ({count_fn_corr_all_gt}/{total_fn_all_gt})")

# Example: Best of ground truth replacements (FPs)
print("\n--- All Concepts (GT): Best of Ground Truth Replacements for FPs (model2_sequential) ---")
count_fp_corr_all_gt, total_fp_all_gt = 0, 0
for a in fp_idx_m2_seq:
    intervened_vanilla = c_pred_m2_seq[a].copy()
    intervened_vanilla[12] = test_concepts[a][12] # Example
    intervened_vanilla[5] = test_concepts[a][5]   # Example
    intervened_vanilla[6] = test_concepts[a][6]   # Example
    intervened_vanilla[9] = test_concepts[a][9]   # Example
    
    binary_slice = intervened_vanilla[binary_intervention_indices_in_vanilla]
    continuous_slice = intervened_vanilla[continuous_intervention_indices_in_vanilla]

    intervened_llm = test_llm_concepts[a].copy()
    intervened_llm[0] = 0.0 # Example
    intervened_llm[2] = 0.0 # Example
    intervened_llm[6] = 0.0 # Example

    with torch.no_grad():
        p_int = model_int6_seq(
            torch.tensor(test_features[a:a+1]).float().to(device),
            torch.tensor([binary_slice]).float().to(device),
            torch.tensor([continuous_slice]).float().to(device),
            torch.tensor([intervened_llm]).float().to(device)
        )
    if p_int.item() < 0.5: count_fp_corr_all_gt += 1
    total_fp_all_gt += 1
print(f"FPs: All Concepts (Selected GT & Manual LLM): Corrections: {count_fp_corr_all_gt/total_fp_all_gt if total_fp_all_gt > 0 else 'N/A'} ({count_fp_corr_all_gt}/{total_fp_all_gt})")

# ... Continue adapting for mean and median replacements for all concepts, similar to your Cells 49 and 50 ...


###########################################################################
--- Interventions on Base CBM (model1_sequential) ---
###########################################################################

--- Binary Concepts: Replacement with True Values (model1_sequential) ---
FNs: Intervened on vanilla_concept_idx 12: Corrections: 0.0 (0/55)
FPs: Intervened on vanilla_concept_idx 12: Corrections: 0.11428571428571428 (8/70)
FNs: Intervened on vanilla_concept_idx 13: Corrections: 0.01818181818181818 (1/55)
FPs: Intervened on vanilla_concept_idx 13: Corrections: 0.014285714285714285 (1/70)
FNs: All Binary Concepts (True): Corrections: 0.0 (0/55)
FPs: All Binary Concepts (True): Corrections: 0.12857142857142856 (9/70)

--- Binary Concepts: Replacement with Mean Values (model1_sequential) ---
FNs: Intervened on vanilla_concept_idx 12 with Mean: Corrections: 0.38181818181818183 (21/55)
FPs: Intervened on vanilla_concept_idx 12 with Mean: Corrections: 0.07142857142857142 (5/70)
FNs: Inte