## CAPRI-CT: Causal Analysis and Predictive Reasoning for Image Quality Optimization in Computed Tomography

### Ablation Testing:

Ablation testing is like a "what-if" experiment for AI models. It involves removing one part of the model or one input at a time to see how much it affects the results. If the model performs much worse after removing something, that part was probably important. This helps us understand which features or components matter most for the model’s predictions.

So, In this ablation testing , we will creating different variants of our CAPRI-CT Causal Aware model in terms of the input parameters passed to the model and provide the analysis results at the end. 

Below are the variants tested in this notebook: 

- Image(i) + Metadata (v, t, a) (full model)
- Image(i) + Metadata (v, t, a, noise) (Robustness)
- Image(i) + Metadata (v, a) (without Current per time)
- Image(i) + Metadata (t, a) (without voltage)
- Image(i) only (without voltage, current, agent)
- Image(i) + Metadata (v, t) (without agent)

This Ablation testing can also be termed as Causal graph perturbation where we are removing each input node in the causal graph and test our assumptions made in the beginning. 

<img src="..\images\capri-ct-dag.png" alt="Causal Graph" width="600"/>

Lets look into the each variant one by one !

In [None]:
#######################################################################
# Importing the required libraries for Ablation testing
#######################################################################

import os
import pandas as pd
from PIL import Image
import random
import numpy as np
import torch
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, WeightedRandomSampler, Subset, ConcatDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error,mean_squared_error, r2_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau


In [None]:
##########################################################################################
# Below is the Class CapriCTDataset 
# Combining the CT image with the metadata using Dataset package
# for SNR prediction
##########################################################################################

class CapriCTDataset(Dataset):
    """
    A custom PyTorch Dataset for CT scan images combined with metadata and target SNR values.

    Each sample consists of a grayscale CT image, encoded metadata (Voltage, Time, Contrast Agent),
    and a corresponding SNR value.

    Parameters
    ----------
    metadata_csv : str
        Path to the CSV file containing metadata for each image.
        Expected columns: 'Filename', 'Voltage', 'Time', 'Classification', 'SNR'.
    
    img_folder_path : str
        Path to the folder containing the CT scan image files.
    
    transform : callable, optional
        Optional transform to be applied on a sample image (e.g., resizing, normalization).

    Attributes
    ----------
    agent_dict : dict
        Mapping of contrast agent labels to integer indices.
    
    voltage_dict : dict
        Mapping of voltage levels to integer indices.
    
    time_dict : dict
        Mapping of scan time values to integer indices.
    """

    def __init__(self, metadata_csv, img_folder_path, transform=None):
        self.img_data = pd.read_csv(metadata_csv)
        self.img_folder = img_folder_path
        self.transform = transform

        # Mappings
        self.agent_dict = {'Iodine': 0, 'BiNPs 50nm': 1, 'BiNPs 100nm': 2}
        self.voltage_dict = {80: 0, 100: 1, 120: 2, 140: 3}
        self.time_dict = {215: 0, 430: 1}

    def __len__(self):
        return len(self.img_data)

    def __getitem__(self, idx):
        row = self.img_data.iloc[idx]

        # Load and transform image
        img_path = os.path.join(self.img_folder, row['Filename'])
        img = Image.open(img_path).convert('L')  
        image = self.transform(img) if self.transform else img

        # Convert categorical fields to indices
        voltage_idx = torch.tensor(self.voltage_dict[row['Voltage']], dtype=torch.long)
        time_idx = torch.tensor(self.time_dict[row['Time']], dtype=torch.long)
        agent_idx = torch.tensor(self.agent_dict[row['Classification']], dtype=torch.long)

        # Target SNR
        snr = torch.tensor(row['SNR'], dtype=torch.float32)

        return image, voltage_idx, time_idx, agent_idx, snr


In [None]:
###############################################################################
# Below is our CAPRI-CT Causal VAE model
###############################################################################

class CapriCTCausalVAEModel(nn.Module):
    """
    A Causal Variational Autoencoder (VAE) model for predicting SNR from CT images and metadata.

    This model combines image features with categorical metadata embeddings (Voltage, Time, Agent)
    and uses a VAE structure to learn a low-dimensional latent representation for regression tasks.

    Parameters
    ----------
    latent_dim : int, optional
        Dimensionality of the latent space. Default is 64.
    voltage_classes : int, optional
        Number of distinct voltage classes. Default is 4.
    time_classes : int, optional
        Number of distinct time classes. Default is 2.
    agent_classes : int, optional
        Number of distinct contrast agent classes. Default is 3.
    """

    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        self.voltage_embed = nn.Embedding(voltage_classes, 16)
        self.time_embed = nn.Embedding(time_classes, 8)
        self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Fully connected layers for VAE
        conv_output_size = 128 * 3 * 3  # Assuming 9x9 input image
        embed_size = 16 + 8 + 12

        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img, voltage_idx, time_idx, agent_idx):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices of shape (B,).
        time_idx : Tensor
            Time class indices of shape (B,).
        agent_idx : Tensor
            Contrast agent class indices of shape (B,).

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h_combined)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, voltage_idx, time_idx, agent_idx):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z_combined)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.
        """
        mu, logvar = self.encode(img, voltage_idx, time_idx, agent_idx)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z, voltage_idx, time_idx, agent_idx)
        return snr_pred, mu, logvar


In [None]:
###############################################################################
# Below is one of CAPRI-CT Causal VAE model variant without metadata
###############################################################################

class CapriCTWithoutEmbedModel(nn.Module):
    """
    VAE-based model for predicting SNR from CT images without using metadata embeddings.

    The model uses a CNN encoder to extract features from the image, maps them to a 
    latent space, and decodes the latent vector to predict the SNR value.

    Args:
        latent_dim (int): Dimensionality of the latent space (default: 64).
        voltage_classes (int): Number of voltage levels (unused).
        time_classes (int): Number of time levels (unused).
        agent_classes (int): Number of contrast agent types (unused).

    Methods:
        encode(img): Encodes the input image into mean and log variance.
        reparameterize(mu, logvar): Samples from the latent space using the reparameterization trick.
        decode(z): Decodes the latent vector into SNR.
        forward(img, voltage_idx, time_idx, agent_idx): Full forward pass (metadata unused).
    """
    
    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        #self.voltage_embed = nn.Embedding(voltage_classes, 16)
        #self.time_embed = nn.Embedding(time_classes, 8)
        #self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder with more depth and dropout
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Flattened size for input image 9x9 → (128, 3, 3)
        conv_output_size = 128 * 3 * 3
        embed_size = 0

        # Latent space
        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        #v_emb = self.voltage_embed(voltage_idx)
        #t_emb = self.time_embed(time_idx)
        #a_emb = self.agent_embed(agent_idx)
        #emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        #h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        #v_emb = self.voltage_embed(voltage_idx)
        #t_emb = self.time_embed(time_idx)
        #_emb = self.agent_embed(agent_idx)
        #emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        #z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices. Unused
        time_idx : Tensor
            Time class indices. Unused
        agent_idx : Tensor
            Contrast agent class indices. Unused

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.

        Note:
        Even the voltage_idx, time_idx, agent_idx are passed , they are unused
        """
        mu, logvar = self.encode(img)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z)
        return snr_pred, mu, logvar


In [None]:
###############################################################################
# Below is one of CAPRI-CT Causal VAE model variant without Time input
###############################################################################

class CapriCTWithoutTimeModel(nn.Module):
    """
    VAE-based model for predicting SNR from CT images using voltage and agent embeddings,
    excluding time as an input feature.

    The model encodes image features and metadata (voltage, agent) into a latent space 
    and decodes it to predict the SNR value.

    Args:
        latent_dim (int): Size of the latent vector (default: 64).
        voltage_classes (int): Number of voltage categories.
        time_classes (int): Number of time categories (unused).
        agent_classes (int): Number of contrast agent types.

    Methods:
        encode(img, voltage_idx, time_idx, agent_idx): Encodes the image and metadata 
            (excluding time) into latent mean and log variance.
        reparameterize(mu, logvar): Samples latent vector using reparameterization trick.
        decode(z, voltage_idx, time_idx, agent_idx): Decodes latent vector and metadata 
            (excluding time) into predicted SNR.
        forward(img, voltage_idx, time_idx, agent_idx): Executes full VAE pipeline.
    """
    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        self.voltage_embed = nn.Embedding(voltage_classes, 16)
        self.time_embed = nn.Embedding(time_classes, 8)
        self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder with more depth and dropout
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Flattened size for input image 9x9 → (128, 3, 3)
        conv_output_size = 128 * 3 * 3
        embed_size = 16 + 12

        # Latent space
        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img, voltage_idx, time_idx, agent_idx):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices of shape (B,).
        time_idx : Tensor
            Time class indices of shape (B,). Unused
        agent_idx : Tensor
            Contrast agent class indices of shape (B,).

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        v_emb = self.voltage_embed(voltage_idx)
        #t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, a_emb], dim=1)

        h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h_combined)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, voltage_idx, time_idx, agent_idx):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices. Unused
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        v_emb = self.voltage_embed(voltage_idx)
        #t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, a_emb], dim=1)

        z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z_combined)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.

        """

        mu, logvar = self.encode(img, voltage_idx, time_idx, agent_idx)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z, voltage_idx, time_idx, agent_idx)
        return snr_pred, mu, logvar


In [None]:
###############################################################################
# Below is one of CAPRI-CT Causal VAE model variant without Voltage input
###############################################################################

class CapriCTWithoutVoltageModel(nn.Module):
    """
    VAE-based model for predicting SNR from CT images using time and agent embeddings,
    excluding voltage as an input feature.

    The model encodes image features and metadata (time, agent) into a latent space 
    and decodes it to predict the SNR value.

    Args:
        latent_dim (int): Size of the latent vector (default: 64).
        voltage_classes (int): Number of voltage categories. (unused)
        time_classes (int): Number of time categories 
        agent_classes (int): Number of contrast agent types.

    Methods:
        encode(img, voltage_idx, time_idx, agent_idx): Encodes the image and metadata 
            (excluding voltage) into latent mean and log variance.
        reparameterize(mu, logvar): Samples latent vector using reparameterization trick.
        decode(z, voltage_idx, time_idx, agent_idx): Decodes latent vector and metadata 
            (excluding voltage) into predicted SNR.
        forward(img, voltage_idx, time_idx, agent_idx): Executes full VAE pipeline.
    """

    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        self.voltage_embed = nn.Embedding(voltage_classes, 16)
        self.time_embed = nn.Embedding(time_classes, 8)
        self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder with more depth and dropout
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Flattened size for input image 9x9 → (128, 3, 3)
        conv_output_size = 128 * 3 * 3
        embed_size = 8 + 12

        # Latent space
        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img, voltage_idx, time_idx, agent_idx):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices of shape (B,). Unused
        time_idx : Tensor
            Time class indices of shape (B,).
        agent_idx : Tensor
            Contrast agent class indices of shape (B,).

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        #v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([ t_emb, a_emb], dim=1)

        h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h_combined)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, voltage_idx, time_idx, agent_idx):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).
        voltage_idx : Tensor
            Voltage class indices. Unused
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        #v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([ t_emb, a_emb], dim=1)

        z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z_combined)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.

        """

        mu, logvar = self.encode(img, voltage_idx, time_idx, agent_idx)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z, voltage_idx, time_idx, agent_idx)
        return snr_pred, mu, logvar


In [None]:
###############################################################################
# Below is one of CAPRI-CT Causal VAE model variant without Agent input
###############################################################################

class CapriCTWithoutAgentModel(nn.Module):
    """
    VAE-based model for predicting SNR from CT images using voltage and time embeddings,
    excluding agent as an input feature.

    The model encodes image features and metadata (voltage, time) into a latent space 
    and decodes it to predict the SNR value.

    Args:
        latent_dim (int): Size of the latent vector (default: 64).
        voltage_classes (int): Number of voltage categories.
        time_classes (int): Number of time categories .
        agent_classes (int): Number of contrast agent types (unused)

    Methods:
        encode(img, voltage_idx, time_idx, agent_idx): Encodes the image and metadata 
            (excluding agent) into latent mean and log variance.
        reparameterize(mu, logvar): Samples latent vector using reparameterization trick.
        decode(z, voltage_idx, time_idx, agent_idx): Decodes latent vector and metadata 
            (excluding agent) into predicted SNR.
        forward(img, voltage_idx, time_idx, agent_idx): Executes full VAE pipeline.
    """
    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        self.voltage_embed = nn.Embedding(voltage_classes, 16)
        self.time_embed = nn.Embedding(time_classes, 8)
        self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder with more depth and dropout
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Flattened size for input image 9x9 → (128, 3, 3)
        conv_output_size = 128 * 3 * 3
        embed_size = 16 + 8

        # Latent space
        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img, voltage_idx, time_idx, agent_idx):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices of shape (B,).
        time_idx : Tensor
            Time class indices of shape (B,).
        agent_idx : Tensor
            Contrast agent class indices of shape (B,). Unused

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        #a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, t_emb], dim=1)

        h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h_combined)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, voltage_idx, time_idx, agent_idx):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices. Unused

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        #a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, t_emb], dim=1)

        z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z_combined)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.

        """

        mu, logvar = self.encode(img, voltage_idx, time_idx, agent_idx)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z, voltage_idx, time_idx, agent_idx)
        return snr_pred, mu, logvar


In [None]:
###############################################################################
# Below is one of CAPRI-CT Causal VAE model variant with added noise
###############################################################################

class CapriCTRobustCheckModel(nn.Module):
    """
    A Causal Variational Autoencoder (VAE) model for predicting SNR from CT images and metadata + noise.

    This model combines image features with categorical metadata embeddings (Voltage, Time, Agent)
    and uses a VAE structure to learn a low-dimensional latent representation for regression tasks.

    Parameters
    ----------
    latent_dim : int, optional
        Dimensionality of the latent space. Default is 64.
    voltage_classes : int, optional
        Number of distinct voltage classes. Default is 4.
    time_classes : int, optional
        Number of distinct time classes. Default is 2.
    agent_classes : int, optional
        Number of distinct contrast agent classes. Default is 3.
    """
    def __init__(self, latent_dim=64, voltage_classes=4, time_classes=2, agent_classes=3):
        super().__init__()

        # Embeddings
        self.voltage_embed = nn.Embedding(voltage_classes, 16)
        self.time_embed = nn.Embedding(time_classes, 8)
        self.agent_embed = nn.Embedding(agent_classes, 12)

        # CNN Encoder with more depth and dropout
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout_cnn = nn.Dropout2d(0.25)

        # Flattened size for input image 9x9 → (128, 3, 3)
        conv_output_size = 128 * 3 * 3
        embed_size = 16 + 8 + 12

        # Latent space
        self.fc1 = nn.Linear(conv_output_size + embed_size, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(0.3)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim + embed_size, 128)
        self.bn_dec1 = nn.BatchNorm1d(128)
        self.decoder_fc2 = nn.Linear(128, 64)
        self.bn_dec2 = nn.BatchNorm1d(64)
        self.decoder_out = nn.Linear(64, 1)

    def encode(self, img, voltage_idx, time_idx, agent_idx):
        """
        Encodes the input image and metadata into latent mean and log-variance.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices of shape (B,).
        time_idx : Tensor
            Time class indices of shape (B,).
        agent_idx : Tensor
            Contrast agent class indices of shape (B,).

        Returns
        -------
        mu : Tensor
            Mean of the latent Gaussian distribution.
        logvar : Tensor
            Log-variance of the latent Gaussian distribution.
        """
        h = F.relu(self.bn1(self.conv1(img)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.dropout_cnn(h)
        h = torch.flatten(h, start_dim=1)

        # Embeddings
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)
        emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        h_combined = torch.cat([h, emb], dim=1)
        h1 = F.relu(self.bn_fc1(self.fc1(h_combined)))
        h1 = self.dropout_fc(h1)

        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using standard normal.

        Parameters
        ----------
        mu : Tensor
            Mean of the latent distribution.
        logvar : Tensor
            Log-variance of the latent distribution.

        Returns
        -------
        Tensor
            Sampled latent vector z.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, voltage_idx, time_idx, agent_idx):
        """
        Decodes the latent vector and metadata into the predicted SNR.

        Parameters
        ----------
        z : Tensor
            Latent vector of shape (B, latent_dim).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        Tensor
            Predicted SNR values of shape (B,).
        """
        v_emb = self.voltage_embed(voltage_idx)
        t_emb = self.time_embed(time_idx)
        a_emb = self.agent_embed(agent_idx)

        # Adding Noise for Robustness Check
        v_emb = v_emb + torch.randn_like(v_emb) * 0.5
        t_emb = t_emb + torch.randn_like(t_emb) * 0.5
        a_emb = a_emb + torch.randn_like(a_emb) * 0.5
        
        emb = torch.cat([v_emb, t_emb, a_emb], dim=1)

        z_combined = torch.cat([z, emb], dim=1)
        h = F.relu(self.bn_dec1(self.decoder_fc1(z_combined)))
        h = F.relu(self.bn_dec2(self.decoder_fc2(h)))
        snr_pred = self.decoder_out(h)
        return snr_pred.squeeze(1)

    def forward(self, img, voltage_idx, time_idx, agent_idx):
        """
        Forward pass of the model. Encodes input, samples from latent space, and decodes to predict SNR.

        Parameters
        ----------
        img : Tensor
            Input image tensor of shape (B, 1, H, W).
        voltage_idx : Tensor
            Voltage class indices.
        time_idx : Tensor
            Time class indices.
        agent_idx : Tensor
            Contrast agent class indices.

        Returns
        -------
        snr_pred : Tensor
            Predicted SNR values of shape (B,).
        mu : Tensor
            Mean of latent distribution.
        logvar : Tensor
            Log-variance of latent distribution.

        """

        mu, logvar = self.encode(img, voltage_idx, time_idx, agent_idx)
        z = self.reparameterize(mu, logvar)
        snr_pred = self.decode(z, voltage_idx, time_idx, agent_idx)
        return snr_pred, mu, logvar


In [None]:
############################################################################################
# get_data_loaders function:
# Loads a CT dataset with images and metadata.
# Applies quantile binning to stratify based on SNR values.
# Splits the dataset into training and validation sets.
# Augments extreme SNR samples to better handle edge cases.
# Applies image transformations.
# Balances the training set using weighted sampling to address SNR distribution imbalance.
############################################################################################

def get_data_loaders(seed, batch_size=16, n_bins=30, extreme_percentile=5):
    """
    Prepares stratified, augmented, and balanced DataLoaders for training and validation.

    Parameters
    ----------
    seed : int
        Random seed for reproducibility.
    
    batch_size : int, optional
        Batch size for both training and validation DataLoaders. Default is 16.
    
    n_bins : int, optional
        Number of quantile bins to use for stratified splitting and weighted sampling. Default is 30.
    
    extreme_percentile : float, optional
        Percentile threshold to define "extreme" SNR values (low and high ends). Default is 5.

    Returns
    -------
    train_loader : DataLoader
        PyTorch DataLoader for the training set with data augmentation and weighted sampling.
    
    val_loader : DataLoader
        PyTorch DataLoader for the validation set with deterministic sampling and no augmentation.

    """

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    generator = torch.Generator().manual_seed(seed)

    # Transforms
    train_transform = transforms.Compose([
        transforms.Resize((9, 9)),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor()
    ])
    val_transform = transforms.Compose([
        transforms.Resize((9, 9)),
        transforms.ToTensor()
    ])

    # Load full dataset
    base_path = Path("../dataset")

    full_dataset = CapriCTDataset(
        metadata_csv= base_path / "final_dataset.csv",
        img_folder_path= base_path / "img" ,
        transform=None
    )

    # Get all SNR values
    all_snr = np.array([full_dataset[i][4] for i in range(len(full_dataset))])

    # Quantile binning for stratification
    snr_bins = pd.qcut(all_snr, q=n_bins, labels=False, duplicates='drop')

    # Train/Val split
    train_indices, val_indices = train_test_split(
        np.arange(len(full_dataset)),
        test_size=0.2,
        random_state=seed,
        stratify=snr_bins
    )

    # Build subsets
    class TransformedSubset(torch.utils.data.Dataset):
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __getitem__(self, idx):
            img, voltage, time, agent, snr = self.subset[idx]
            if self.transform:
                img = self.transform(img)
            return img, voltage, time, agent, snr

        def __len__(self):
            return len(self.subset)

    train_subset = Subset(full_dataset, train_indices)
    val_subset = Subset(full_dataset, val_indices)

    val_dataset = TransformedSubset(val_subset, val_transform)

    # --- Duplicate extreme values dynamically ---
    train_snr = np.array([full_dataset[i][4] for i in train_indices])
    lower_thresh = np.percentile(train_snr, extreme_percentile)
    upper_thresh = np.percentile(train_snr, 100 - extreme_percentile)

    # Identify extreme samples
    extreme_indices = [i for i in train_indices if full_dataset[i][4] < lower_thresh or full_dataset[i][4] > upper_thresh]

    duplicated_extremes = Subset(full_dataset, extreme_indices)
    duplicated_extremes = TransformedSubset(duplicated_extremes, train_transform)

    # Wrap the base train set with transform
    train_dataset_base = TransformedSubset(train_subset, train_transform)

    # Combine datasets: original + duplicated extremes
    combined_train_dataset = ConcatDataset([train_dataset_base,
                                            duplicated_extremes
                                            ] )

    # --- Weighted Sampling ---
    combined_snr = np.array([combined_train_dataset[i][4] for i in range(len(combined_train_dataset))])
    combined_bins = pd.qcut(combined_snr, q=n_bins, labels=False, duplicates='drop')
    bin_counts = np.bincount(combined_bins)
    bin_weights = 1.0 / bin_counts
    weights = [bin_weights[b] for b in combined_bins]

    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_loader = torch.utils.data.DataLoader(combined_train_dataset, batch_size=batch_size, sampler=sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


In [None]:
###############################################################
# Setting the seed value for each training loop
###############################################################

def set_seed(seed):
    """
    Sets the random seed across Python, NumPy, and PyTorch (CPU and GPU) for reproducibility.

    Parameters
    ----------
    seed : int
        The seed value to ensure deterministic behavior across runs.
    """
    
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
#############################################################################
# This method calculates the vae LOSS with the preds and targets provided
#############################################################################

def vae_loss(snr_pred, snr_true, mu, logvar):
    """
    Computes the total loss for a Variational Autoencoder (VAE), combining 
    reconstruction loss and KL divergence.

    Parameters
    ----------
    snr_pred : Tensor
        Predicted SNR values from the decoder. Shape: (B,)
    
    snr_true : Tensor
        Ground-truth SNR values. Shape: (B,)
    
    mu : Tensor
        Latent mean vector from the encoder. Shape: (B, latent_dim)
    
    logvar : Tensor
        Log-variance vector from the encoder. Shape: (B, latent_dim)

    Returns
    -------
    total_loss : Tensor
        Sum of reconstruction loss and KL divergence.
    
    recon_loss : Tensor
        Mean squared error (MSE) between predicted and true SNR values.
    
    kld : Tensor
        KL divergence between the learned latent distribution and standard normal.
    
    Notes
    -----
    - KL divergence is scaled by the batch size to ensure stability across varying batch sizes.
    - This function assumes `snr_pred` and `snr_true` are both 1D tensors.
    """
    
    recon_loss = F.mse_loss(snr_pred, snr_true, reduction='mean')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / snr_true.size(0)
    return recon_loss + kld, recon_loss, kld

In [None]:
##########################################################################
# Below validate model method calculates the model performance against 
# validation dataset and returns the metrics 
##########################################################################

def validate_model(model, val_loader, device):
    """
    Evaluates the VAE model on a validation set and computes loss and performance metrics.

    Parameters
    ----------
    model : torch.nn.Module
        The trained VAE model to be evaluated.
    
    val_loader : DataLoader
        PyTorch DataLoader containing the validation dataset.
    
    device : torch.device
        The device on which computation will be performed (e.g., 'cuda' or 'cpu').

    Returns
    -------
    val_loss : float
        Average Smooth L1 loss over the entire validation set.
    
    r2 : float
        Coefficient of determination (R² score) between predicted and true SNR values.
    
    rmse : float
        Root Mean Squared Error (RMSE) between predicted and true SNR values.

    Notes
    -----
    - The model is set to evaluation mode during validation (`model.eval()`).
    - No gradient computation is performed (`torch.no_grad()` context).
    - `Smooth L1 loss` is used as a robust regression loss function.
    - Predictions and ground truths are collected across all batches to compute R² and RMSE.
    """

    model.eval()
    val_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for image, voltage, time, agent_idx, snr_true in val_loader:
            image = image.to(device)
            voltage = voltage.to(device)
            time = time.to(device)
            agent_idx = agent_idx.to(device)
            snr_true = snr_true.to(device)

            snr_pred, mu, logvar = model(image, voltage, time, agent_idx)

            # Use consistent loss (Smooth L1 or any custom one)
            loss = F.smooth_l1_loss(snr_pred.squeeze(), snr_true.squeeze())
            val_loss += loss.item()

            all_preds.extend(snr_pred.squeeze().cpu().numpy())
            all_targets.extend(snr_true.squeeze().cpu().numpy())

    val_loss /= len(val_loader)
    r2 = r2_score(all_targets, all_preds)
    rmse = mean_squared_error(all_targets, all_preds)

    return val_loss, r2, rmse


In [None]:
#################################################################################################
# This method trains a single causal model in Ensemble modelling by calculating the training and 
# validation losses implemented with the Early stopping method 
# if the validation loss starts increasing
#################################################################################################

def train_vae_model(model, train_loader, val_loader, optimizer, scheduler, device, epochs=100):
    """
    Trains a Variational Autoencoder (VAE) model using a custom loss function (reconstruction + KL divergence),
    and evaluates it on a validation set after each epoch.

    Implements early stopping based on validation loss.

    Parameters
    ----------
    model : torch.nn.Module
        The VAE model to be trained.

    train_loader : DataLoader
        PyTorch DataLoader for the training dataset.

    val_loader : DataLoader
        PyTorch DataLoader for the validation dataset.

    optimizer : torch.optim.Optimizer
        Optimizer for training (e.g., Adam).

    scheduler : torch.optim.lr_scheduler._LRScheduler
        Learning rate scheduler to step using validation loss.

    device : torch.device
        Device to run training on (e.g., 'cuda' or 'cpu').

    epochs : int, optional
        Maximum number of training epochs. Default is 100.

    Returns
    -------
    model : torch.nn.Module
        The trained model with parameters from the best validation performance.

    Notes
    -----
    - Uses a combined loss: MSE (reconstruction) + KL divergence.
    - Gradients are clipped to `max_norm=1.0` to stabilize training.
    - Early stopping is triggered after `patience` epochs of no validation improvement.
    - Validation metrics include R² score and RMSE, in addition to loss.
    - Logs training and validation metrics for each epoch.
    """

    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0
    
    for epoch in range(epochs):
        model.train()
        total_loss, total_recon, total_kl = 0.0, 0.0, 0.0

        for image, voltage, time, agent_idx, snr in train_loader:
            image = image.to(device)
            voltage = voltage.to(device)
            time = time.to(device)
            agent_idx = agent_idx.to(device)
            snr_true = snr.to(device)

            optimizer.zero_grad()
            snr_pred, mu, logvar = model(image, voltage, time, agent_idx)

            loss, recon_loss, kl_loss = vae_loss(snr_pred, snr_true, mu, logvar)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_kl += kl_loss.item()

        avg_loss = total_loss / len(train_loader)

        # --- Validation ---
        val_loss, val_r2, val_rmse = validate_model(model, val_loader, device)
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        #if epoch % 10 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f} | R²: {val_r2:.4f} | RMSE: {val_rmse:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return model


In [None]:
#############################################################################
# The below method trains the ensemble models together
# it calls the train_single_model method with given parameters
#############################################################################

def train_ensemble(num_models=5, seeds=None, device='cpu'):
    """
    Trains an ensemble of VAE models with different random seeds for improved robustness and generalization.

    Each model in the ensemble is:
    - Initialized with a different seed.
    - Trained independently using a unique data split (stratified by SNR).
    - Appended to the returned list for later ensembling or evaluation.

    Parameters
    ----------
    num_models : int, optional
        Number of models to train in the ensemble. Default is 5.
    
    seeds : list of int, optional
        List of seeds to use for each model. If None, generates seeds starting from 42 with a step of 10.
    
    device : str or torch.device, optional
        Device on which to train the models (e.g., 'cpu' or 'cuda'). Default is 'cpu'.

    Returns
    -------
    ensemble_models : list of torch.nn.Module
        List of trained VAE models, each trained with a different seed and dataset split.

    Notes
    -----
    - Uses `train_vae_model()` for training individual models.
    - Applies `set_seed()` for reproducibility across data splits and weight initialization.
    - Each model uses a fresh instance of `CapriCTCausalVAEModel` with `latent_dim=256`.
    - Uses AdamW optimizer and ReduceLROnPlateau scheduler for stability.
    """

    ensemble_models = []

    if seeds is None:
        seeds = [42 + i * 10 for i in range(num_models)]

    for i, seed in enumerate(seeds):
        print(f"\nTraining model {i+1}/{num_models} with seed {seed}")
        set_seed(seed)

        # Load data for the given seed
        train_loader, val_loader = get_data_loaders(seed, batch_size=16)

        # Initialize the model
        # model = CapriCTWithoutAgentModel(latent_dim=256).to(device)
        # model = CapriCTWithoutVoltageModel(latent_dim=256).to(device)
        # model = CapriCTWithoutTimeModel(latent_dim=256).to(device)
        # model = CapriCTWithoutEmbedModel(latent_dim=256).to(device)
        model =CapriCTRobustCheckModel(latent_dim=256).to(device)

        # Optimizer and learning rate scheduler
        optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        # Train the model
        trained_model = train_vae_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epochs=100
        )

        ensemble_models.append(trained_model)

    return ensemble_models


In [None]:
#############################################################################
# This method uses the trained ensemble models to predict SNR values 
# along with the respective targets
#############################################################################

def get_ensemble_predictions(models, val_loader, device):
    """
    Generates predictions from an ensemble of trained models on a validation set.

    Each model makes independent predictions on the same validation data,
    and all predictions are collected for later aggregation (e.g., mean, median).

    Parameters
    ----------
    models : list of torch.nn.Module
        List of trained VAE models forming the ensemble.
    
    val_loader : DataLoader
        DataLoader for the validation dataset.
    
    device : str or torch.device
        Device to use for inference (e.g., 'cuda' or 'cpu').

    Returns
    -------
    all_preds : list of list of float
        A list containing per-model predictions. Each inner list has predictions of one model 
        for the full validation set.
    
    targets : list of float
        Ground-truth SNR values from the validation set (shared across all models).

    Notes
    -----
    - Assumes all models are compatible with the same input format and produce scalar SNR predictions.
    - Targets are extracted only once from the first pass through the loader.
    - Output can be post-processed (e.g., averaged) to get the final ensemble prediction.
    """
    
    all_preds = []
    targets = []

    for model in models:
        model.eval()
        model_preds = []

        with torch.no_grad():
            for image, voltage_idx, time_idx, agent_idx, snr in val_loader:
                image = image.to(device)
                voltage_idx = voltage_idx.to(device).long()
                time_idx = time_idx.to(device).long()
                agent_idx = agent_idx.to(device).long()

                snr_pred, _, _ = model(image, voltage_idx, time_idx, agent_idx)
                model_preds.extend(snr_pred.squeeze().cpu().numpy())

                # Store targets only once
                if len(targets) < len(val_loader.dataset):
                    targets.extend(snr.squeeze().cpu().numpy())

        all_preds.append(model_preds)

    return all_preds, targets


In [None]:
#################################################################################
# This method evaluates the predicted SNR values with the targets values
#################################################################################

def evaluate_models(all_preds, targets):
    """
    Evaluate individual model predictions and the ensemble.

    Args:
        all_preds (list or np.ndarray or torch.Tensor): Shape [num_models, num_samples]
        targets (list or np.ndarray or torch.Tensor): Shape [num_samples]

    Returns:
        dict: {
            'individual_metrics': List[dict],
            'best_model_index': int,
            'ensemble_metrics': dict
        }
    """
    # Convert inputs to tensors
    if isinstance(all_preds, np.ndarray):
        all_preds = torch.from_numpy(all_preds)
    elif isinstance(all_preds, list):
        all_preds = torch.stack([torch.tensor(p) for p in all_preds])
    elif not isinstance(all_preds, torch.Tensor):
        raise TypeError("all_preds must be a list, np.ndarray, or torch.Tensor")

    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    elif isinstance(targets, list):
        targets = torch.tensor(targets)
    elif not isinstance(targets, torch.Tensor):
        raise TypeError("targets must be a list, np.ndarray, or torch.Tensor")

    targets = targets.squeeze()
    num_models = all_preds.shape[0]

    individual_metrics = []
    best_r2 = -np.inf
    best_model_idx = -1

    # Evaluate each model
    for i in range(num_models):
        preds = all_preds[i].squeeze().cpu().numpy()
        targs = targets.cpu().numpy()

        r2 = r2_score(targs, preds)
        rmse = np.sqrt(mean_squared_error(targs, preds))
        mae = mean_absolute_error(targs, preds)

        individual_metrics.append({'model_idx': i, 'r2': r2, 'rmse': rmse, 'mae': mae})

        if r2 > best_r2:
            best_r2 = r2
            best_model_idx = i

    # Ensemble (mean of all predictions)
    mean_preds = all_preds.mean(dim=0).squeeze().cpu().numpy()
    std_preds = all_preds.std(dim=0, unbiased=False).squeeze().cpu().numpy()
    targets_np = targets.cpu().numpy()

    r2_ens = r2_score(targets_np, mean_preds)
    rmse_ens = np.sqrt(mean_squared_error(targets_np, mean_preds))
    mae_ens = mean_absolute_error(targets_np, mean_preds)

    ensemble_metrics = {
        'r2': r2_ens,
        'rmse': rmse_ens,
        'mae': mae_ens,
        'mean_preds': mean_preds,
        'std_preds': std_preds,
        'targets': targets_np
    }

    return {
        'individual_metrics': individual_metrics,
        'best_model_index': best_model_idx,
        'ensemble_metrics': ensemble_metrics
    }


In [None]:
###########################################################################
# Initializing the dataset for Intervention and counterfactual inference
###########################################################################

def get_dataset():
    """
    Loads and returns the CapriCTDataset with predefined image transformations.

    Applies a consistent transform to resize CT images and convert them to tensors.
    Assumes the dataset CSV and image directory are located under a relative `../dataset/` path.

    Returns
    -------
    dataset : CapriCTDataset
        An instance of the custom dataset with image and metadata fields prepared
        for model training or inference.

    Notes
    -----
    - Images are resized to 9x9 pixels and converted to tensors.
    - Paths:
        - Metadata CSV: ../dataset/final_dataset.csv
        - Image folder: ../dataset/img/
    """
    
    transform = transforms.Compose([
    transforms.Resize((9, 9)),  
    transforms.ToTensor()
    ])
    base_path = Path("../dataset")

    dataset = CapriCTDataset(
        metadata_csv= base_path / "final_dataset.csv",
        img_folder_path= base_path / "img" ,
        transform=transform
    )

    return dataset

In [None]:
##########################################
# Setting the device parameter
##########################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

##########################################
# Train ensemble models without Agent
##########################################
models = train_ensemble(num_models=5, device=device)




Training model 1/5 with seed 42
Epoch 1/100 | Train Loss: 79485.0954 | Val Loss: 150.6637 | R²: -0.1132 | RMSE: 62937.3320
Epoch 2/100 | Train Loss: 84095.8979 | Val Loss: 150.8500 | R²: -0.0994 | RMSE: 62152.8242
Epoch 3/100 | Train Loss: 83578.2689 | Val Loss: 150.0239 | R²: -0.1035 | RMSE: 62385.9492
Epoch 4/100 | Train Loss: 83358.7015 | Val Loss: 150.8885 | R²: -0.0901 | RMSE: 61629.2734
Epoch 5/100 | Train Loss: 81290.0173 | Val Loss: 153.6316 | R²: -0.0568 | RMSE: 59746.0195
Epoch 6/100 | Train Loss: 87824.4175 | Val Loss: 153.0011 | R²: -0.0606 | RMSE: 59961.5078
Epoch 7/100 | Train Loss: 76196.3481 | Val Loss: 156.7072 | R²: -0.0294 | RMSE: 58195.1406
Epoch 8/100 | Train Loss: 78213.4565 | Val Loss: 159.3780 | R²: -0.0275 | RMSE: 58088.1172
Epoch 9/100 | Train Loss: 75934.6944 | Val Loss: 159.1071 | R²: -0.0044 | RMSE: 56784.6797
Epoch 10/100 | Train Loss: 79389.4736 | Val Loss: 162.9792 | R²: 0.0025 | RMSE: 56392.2188
Epoch 11/100 | Train Loss: 72835.8719 | Val Loss: 159.900

In [None]:
################################################
# Get predictions and targets
################################################

_, val_loader = get_data_loaders(seed=42, batch_size=16)
all_preds, targets = get_ensemble_predictions(models, val_loader, device)

In [None]:
#####################################################################
# Evaluate the ensemble models for our CAPRI-CT without agent
#####################################################################

result = evaluate_models(all_preds, targets)

# View metrics
print(f"***************Individual Model Metrics********************************")
for m in result['individual_metrics']:
    print(f"Model {m['model_idx']+1} → R2: {m['r2']:.3f}, RMSE: {m['rmse']:.3f}, MAE: {m['mae']:.3f}")
print(f"***********************************************************************")

print(f"\nBest Model : Model {result['best_model_index']+1}")
print(f"***********************************************************************")
print(f"Ensemble R2: {result['ensemble_metrics']['r2']:.3f}")
print(f"Ensemble RMSE: {result['ensemble_metrics']['rmse']:.3f}")
print(f"Ensemble MAE: {result['ensemble_metrics']['mae']:.3f}")
print(f"***********************************************************************")


***************Individual Model Metrics********************************
Model 1 → R2: 0.007, RMSE: 236.926, MAE: 164.079
Model 2 → R2: -0.486, RMSE: 289.824, MAE: 170.957
Model 3 → R2: -0.019, RMSE: 240.078, MAE: 165.522
Model 4 → R2: 0.021, RMSE: 235.202, MAE: 160.408
Model 5 → R2: 0.032, RMSE: 233.931, MAE: 162.385
***********************************************************************

Best Model : Model 5
***********************************************************************
Ensemble R2: 0.005
Ensemble RMSE: 237.137
Ensemble MAE: 164.256
***********************************************************************


In [None]:
##########################################
# Train ensemble models without Voltage
##########################################

models_v = train_ensemble(num_models=5, device=device)


Training model 1/5 with seed 42
Epoch 1/100 | Train Loss: 81786.7895 | Val Loss: 145.3927 | R²: -0.0640 | RMSE: 60153.8086
Epoch 2/100 | Train Loss: 81699.3956 | Val Loss: 143.5181 | R²: -0.0409 | RMSE: 58848.9844
Epoch 3/100 | Train Loss: 74642.0978 | Val Loss: 135.1902 | R²: 0.0794 | RMSE: 52046.7891
Epoch 4/100 | Train Loss: 71125.3283 | Val Loss: 132.1728 | R²: 0.1454 | RMSE: 48316.2891
Epoch 5/100 | Train Loss: 65198.5932 | Val Loss: 125.8901 | R²: 0.2104 | RMSE: 44638.6445
Epoch 6/100 | Train Loss: 62667.3133 | Val Loss: 138.2948 | R²: 0.1105 | RMSE: 50285.5273
Epoch 7/100 | Train Loss: 49364.3505 | Val Loss: 115.1842 | R²: 0.3554 | RMSE: 36441.8047
Epoch 8/100 | Train Loss: 46859.2113 | Val Loss: 116.6540 | R²: 0.3783 | RMSE: 35147.9727
Epoch 9/100 | Train Loss: 40591.7773 | Val Loss: 135.6205 | R²: 0.2360 | RMSE: 43193.2891
Epoch 10/100 | Train Loss: 38984.4558 | Val Loss: 104.2064 | R²: 0.5259 | RMSE: 26803.6270
Epoch 11/100 | Train Loss: 34568.9005 | Val Loss: 120.8455 | R²:

In [None]:
################################################
# Get predictions and targets
################################################

_, val_loader_v = get_data_loaders(seed=42, batch_size=16)
all_preds_v, targets_v = get_ensemble_predictions(models_v, val_loader_v, device)

In [None]:
#####################################################################
# Evaluate the ensemble models for our CAPRI-CT without voltage
#####################################################################

resultv = evaluate_models(all_preds_v, targets_v)

# View metrics
print(f"***************Individual Model Metrics********************************")
for m in resultv['individual_metrics']:
    print(f"Model {m['model_idx']+1} → R2: {m['r2']:.3f}, RMSE: {m['rmse']:.3f}, MAE: {m['mae']:.3f}")
print(f"***********************************************************************")

print(f"\nBest Model : Model {resultv['best_model_index']+1}")
print(f"***********************************************************************")
print(f"Ensemble R2: {resultv['ensemble_metrics']['r2']:.3f}")
print(f"Ensemble RMSE: {resultv['ensemble_metrics']['rmse']:.3f}")
print(f"Ensemble MAE: {resultv['ensemble_metrics']['mae']:.3f}")
print(f"***********************************************************************")


***************Individual Model Metrics********************************
Model 1 → R2: 0.656, RMSE: 139.511, MAE: 93.568
Model 2 → R2: 0.645, RMSE: 141.753, MAE: 92.991
Model 3 → R2: 0.662, RMSE: 138.286, MAE: 91.725
Model 4 → R2: 0.653, RMSE: 140.055, MAE: 94.373
Model 5 → R2: 0.549, RMSE: 159.700, MAE: 106.748
***********************************************************************

Best Model : Model 3
***********************************************************************
Ensemble R2: 0.655
Ensemble RMSE: 139.720
Ensemble MAE: 91.421
***********************************************************************


In [None]:
##########################################
# Train ensemble models without time
##########################################

models_t = train_ensemble(num_models=5, device=device)


Training model 1/5 with seed 42
Epoch 1/100 | Train Loss: 82066.4489 | Val Loss: 144.2555 | R²: -0.0414 | RMSE: 58874.7500
Epoch 2/100 | Train Loss: 78887.7163 | Val Loss: 141.6847 | R²: -0.0122 | RMSE: 57227.5469
Epoch 3/100 | Train Loss: 72884.5531 | Val Loss: 144.4774 | R²: -0.0373 | RMSE: 58642.7383
Epoch 4/100 | Train Loss: 70664.4004 | Val Loss: 135.4380 | R²: 0.0825 | RMSE: 51871.7969
Epoch 5/100 | Train Loss: 61308.8272 | Val Loss: 128.4142 | R²: 0.1884 | RMSE: 45885.8789
Epoch 6/100 | Train Loss: 61332.8769 | Val Loss: 145.3295 | R²: 0.0552 | RMSE: 53412.3906
Epoch 7/100 | Train Loss: 49873.8410 | Val Loss: 112.9105 | R²: 0.4131 | RMSE: 33182.0508
Epoch 8/100 | Train Loss: 44315.7408 | Val Loss: 123.3406 | R²: 0.3355 | RMSE: 37568.9102
Epoch 9/100 | Train Loss: 37972.0620 | Val Loss: 109.0942 | R²: 0.4753 | RMSE: 29665.8906
Epoch 10/100 | Train Loss: 35385.0142 | Val Loss: 101.6780 | R²: 0.5902 | RMSE: 23170.0625
Epoch 11/100 | Train Loss: 31936.4215 | Val Loss: 101.1761 | R²

In [None]:
################################################
# Get predictions and targets
################################################
_, val_loader_t = get_data_loaders(seed=42, batch_size=16)
all_preds_t, targets_t = get_ensemble_predictions(models_t, val_loader_t, device)

In [None]:
#####################################################################
# Evaluate the ensemble models for our CAPRI-CT without time
#####################################################################

resultt = evaluate_models(all_preds_t, targets_t)

# View metrics
print(f"***************Individual Model Metrics********************************")
for m in resultt['individual_metrics']:
    print(f"Model {m['model_idx']+1} → R2: {m['r2']:.3f}, RMSE: {m['rmse']:.3f}, MAE: {m['mae']:.3f}")
print(f"***********************************************************************")

print(f"\nBest Model : Model {resultt['best_model_index']+1}")
print(f"***********************************************************************")
print(f"Ensemble R2: {resultt['ensemble_metrics']['r2']:.3f}")
print(f"Ensemble RMSE: {resultt['ensemble_metrics']['rmse']:.3f}")
print(f"Ensemble MAE: {resultt['ensemble_metrics']['mae']:.3f}")
print(f"***********************************************************************")


***************Individual Model Metrics********************************
Model 1 → R2: 0.536, RMSE: 162.020, MAE: 106.878
Model 2 → R2: 0.686, RMSE: 133.207, MAE: 88.050
Model 3 → R2: 0.685, RMSE: 133.491, MAE: 87.294
Model 4 → R2: 0.692, RMSE: 132.015, MAE: 87.603
Model 5 → R2: 0.697, RMSE: 130.843, MAE: 86.152
***********************************************************************

Best Model : Model 5
***********************************************************************
Ensemble R2: 0.684
Ensemble RMSE: 133.561
Ensemble MAE: 87.209
***********************************************************************


In [None]:
##########################################
# Train ensemble models without embeddings
##########################################

models_e = train_ensemble(num_models=5, device=device)


Training model 1/5 with seed 42
Epoch 1/100 | Train Loss: 83281.6300 | Val Loss: 150.0876 | R²: -0.1295 | RMSE: 63856.8398
Epoch 2/100 | Train Loss: 88608.0228 | Val Loss: 150.5495 | R²: -0.1151 | RMSE: 63044.9141
Epoch 3/100 | Train Loss: 80431.7794 | Val Loss: 150.4671 | R²: -0.1075 | RMSE: 62612.7695
Epoch 4/100 | Train Loss: 88540.6658 | Val Loss: 151.4975 | R²: -0.0783 | RMSE: 60961.1875
Epoch 5/100 | Train Loss: 83023.4114 | Val Loss: 152.9523 | R²: -0.0520 | RMSE: 59476.3945
Epoch 6/100 | Train Loss: 82167.3363 | Val Loss: 154.4995 | R²: -0.0444 | RMSE: 59047.3242
Epoch 7/100 | Train Loss: 83034.7643 | Val Loss: 154.2074 | R²: -0.0396 | RMSE: 58776.6367
Epoch 8/100 | Train Loss: 80734.0245 | Val Loss: 158.7559 | R²: -0.0073 | RMSE: 56949.9727
Epoch 9/100 | Train Loss: 78076.3537 | Val Loss: 164.0019 | R²: 0.0097 | RMSE: 55988.1602
Epoch 10/100 | Train Loss: 78460.5496 | Val Loss: 159.6242 | R²: -0.0007 | RMSE: 56573.7031
Epoch 11/100 | Train Loss: 80916.2811 | Val Loss: 162.479

In [None]:
################################################
# Get predictions and targets
################################################

_, val_loader_e = get_data_loaders(seed=42, batch_size=16)
all_preds_e, targets_e = get_ensemble_predictions(models_e, val_loader_e, device)

In [None]:
#####################################################################
# Evaluate the ensemble models for our CAPRI-CT without embeddings
#####################################################################

resulte = evaluate_models(all_preds_e, targets_e)

# View metrics
print(f"***************Individual Model Metrics********************************")
for m in resulte['individual_metrics']:
    print(f"Model {m['model_idx']+1} → R2: {m['r2']:.3f}, RMSE: {m['rmse']:.3f}, MAE: {m['mae']:.3f}")
print(f"***********************************************************************")

print(f"\nBest Model : Model {resulte['best_model_index']+1}")
print(f"***********************************************************************")
print(f"Ensemble R2: {resulte['ensemble_metrics']['r2']:.3f}")
print(f"Ensemble RMSE: {resulte['ensemble_metrics']['rmse']:.3f}")
print(f"Ensemble MAE: {resulte['ensemble_metrics']['mae']:.3f}")
print(f"***********************************************************************")


***************Individual Model Metrics********************************
Model 1 → R2: 0.006, RMSE: 237.097, MAE: 163.837
Model 2 → R2: 0.016, RMSE: 235.908, MAE: 166.959
Model 3 → R2: 0.012, RMSE: 236.353, MAE: 164.456
Model 4 → R2: -0.161, RMSE: 256.172, MAE: 220.691
Model 5 → R2: 0.003, RMSE: 237.374, MAE: 160.566
***********************************************************************

Best Model : Model 2
***********************************************************************
Ensemble R2: 0.021
Ensemble RMSE: 235.216
Ensemble MAE: 172.722
***********************************************************************


In [None]:
##########################################
# Train ensemble models with added noise
##########################################

models_r = train_ensemble(num_models=5, device=device)


Training model 1/5 with seed 42
Epoch 1/100 | Train Loss: 84864.0635 | Val Loss: 145.4918 | R²: -0.0662 | RMSE: 60275.0703
Epoch 2/100 | Train Loss: 78565.0575 | Val Loss: 143.2604 | R²: -0.0278 | RMSE: 58108.3125
Epoch 3/100 | Train Loss: 77237.8451 | Val Loss: 138.8447 | R²: 0.0409 | RMSE: 54222.1953
Epoch 4/100 | Train Loss: 76129.1559 | Val Loss: 132.6336 | R²: 0.1370 | RMSE: 48790.2109
Epoch 5/100 | Train Loss: 66365.2471 | Val Loss: 126.6488 | R²: 0.2203 | RMSE: 44082.2031
Epoch 6/100 | Train Loss: 55886.1469 | Val Loss: 117.9415 | R²: 0.3399 | RMSE: 37318.7656
Epoch 7/100 | Train Loss: 55338.3519 | Val Loss: 116.4466 | R²: 0.3636 | RMSE: 35979.0781
Epoch 8/100 | Train Loss: 44526.1115 | Val Loss: 106.1938 | R²: 0.5208 | RMSE: 27089.4355
Epoch 9/100 | Train Loss: 38197.5289 | Val Loss: 97.4817 | R²: 0.5810 | RMSE: 23689.0469
Epoch 10/100 | Train Loss: 35930.1165 | Val Loss: 101.6316 | R²: 0.5523 | RMSE: 25313.4863
Epoch 11/100 | Train Loss: 30967.9114 | Val Loss: 139.1190 | R²: 

In [None]:
################################################
# Get predictions and targets
################################################
_, val_loader_r = get_data_loaders(seed=42, batch_size=16)
all_preds_r, targets_r = get_ensemble_predictions(models_r, val_loader_r, device)

In [None]:
#####################################################################
# Evaluate the ensemble models for our CAPRI-CT with added noise
#####################################################################

resultr = evaluate_models(all_preds_r, targets_r)

# View metrics
print(f"***************Individual Model Metrics********************************")
for m in resultr['individual_metrics']:
    print(f"Model {m['model_idx']+1} → R2: {m['r2']:.3f}, RMSE: {m['rmse']:.3f}, MAE: {m['mae']:.3f}")
print(f"***********************************************************************")

print(f"\nBest Model : Model {resultr['best_model_index']+1}")
print(f"***********************************************************************")
print(f"Ensemble R2: {resultr['ensemble_metrics']['r2']:.3f}")
print(f"Ensemble RMSE: {resultr['ensemble_metrics']['rmse']:.3f}")
print(f"Ensemble MAE: {resultr['ensemble_metrics']['mae']:.3f}")
print(f"***********************************************************************")


***************Individual Model Metrics********************************
Model 1 → R2: 0.783, RMSE: 110.673, MAE: 72.906
Model 2 → R2: 0.798, RMSE: 106.981, MAE: 68.327
Model 3 → R2: 0.794, RMSE: 108.044, MAE: 69.872
Model 4 → R2: 0.790, RMSE: 108.855, MAE: 70.707
Model 5 → R2: 0.752, RMSE: 118.435, MAE: 76.341
***********************************************************************

Best Model : Model 2
***********************************************************************
Ensemble R2: 0.797
Ensemble RMSE: 107.192
Ensemble MAE: 68.704
***********************************************************************


### Table: Performance of different CAPRI-CT model variants

| **Versions of CAPRI-CT model**                          | **MAE**   | **RMSE**   | **R²**   |
|----------------------------------------------------------|-----------|------------|----------|
| Image (*i*) + metadata (*v, t, a*)                       | 68.028    | 106.493    | 0.799    |
| Image (*i*) + metadata (*v, t, a, noise*)                | 68.704    | 107.192    | 0.797    |
| Image (*i*) + metadata (*v, a*)                          | 87.209    | 133.561    | 0.684    |
| Image (*i*) + metadata (*t, a*)                          | 91.421    | 139.720    | 0.655    |
| Image (*i*) only                                          | 172.722   | 235.216    | 0.021    |
| Image (*i*) + metadata (*v, t*)                          | 164.256   | 237.137    | 0.005    |


**Ablation Study and Causal Perturbation Analysis:**  
The above table reports the results of causal structure perturbation through ablation studies, evaluating the impact of removing individual input variables on the model's performance. The full CAPRI-CT model achieves the best predictive accuracy (MAE: 68.03, RMSE: 106.49, R²: 0.799). Ablating current (*t*) and voltage (*v*) led to moderate degradation in performance, with R² decreasing to 0.684 and 0.655, respectively. This indicates their contribution to the prediction task but suggests limited causal influence relative to other variables.

By contrast, removing the contrast agent (*a*) resulted in a substantial drop in performance (R²: 0.005), nearly equivalent to removing all three inputs (R²: 0.021). To further test model robustness, we introduced an additional noise variable to each input parameter; performance remained stable (MAE: 68.70, RMSE: 107.19, R²: 0.797). This highlights the contrast agent as a dominant causal factor for SNR in CT imaging. These results support the causal assumptions embedded in the model and demonstrate its sensitivity to disruptions in key parent nodes of the causal graph.
