## CNN Baseline Non-Causal Model

The non-causal baseline model is used as a reference point to compare against the causal-aware Capri-CT model. Unlike Capri-CT, which integrates causal reasoning to understand how interventions affect outcomes, the baseline model relies solely on correlational patterns in the data without explicitly modeling causal relationships.

This baseline is a convolutional neural network (CNN) that processes CT image data alongside metadata inputs: voltage, time, and agent type. It predicts Signal-to-Noise Ratio (SNR), with Contrast-to-Noise Ratio (CNR) as an auxiliary output, from observed features, without accounting for causal interventions.

While effective at capturing associations, the non-causal baseline lacks robustness to changes caused by interventions or shifts in the data distribution. Therefore, it provides a meaningful benchmark to demonstrate the advantages of causal-aware models like Capri-CT in terms of interpretability, generalization, and handling of counterfactual scenarios.


In [1]:
######################################################################################
# Import the libraries required for the CNN baseline non-causal model
######################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import os
from PIL import Image
import random
import time as T
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from pathlib import Path
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


In [2]:
##########################################################################################
# CTDataset: a torch.utils.data.Dataset that pairs each CT image
# with its metadata (agent, voltage, time) and its CNR/SNR targets
# for SNR prediction
##########################################################################################

class CTDataset(Dataset):
    """
    PyTorch Dataset for loading CT scan images and associated metadata.

    Args:
        metadata_csv (str or Path): Path to the CSV file containing metadata.
        img_folder_path (str or Path): Directory containing CT scan image files.
        transform (callable, optional): Transformations to apply to the images.

    Attributes:
        img_data (pd.DataFrame): DataFrame containing the metadata.
        img_folder (Path): Path to the image folder.
        transform (callable or None): Optional transform to apply to images.

    Methods:
        __getitem__(idx): Returns a single data sample consisting of:
            - transformed image tensor (grayscale),
            - one-hot encoded agent vector (tensor),
            - voltage (tensor),
            - time (tensor),
            - CNR (tensor),
            - SNR (tensor).
        __len__(): Returns the total number of samples.
    """

    def __init__(self, metadata_csv, img_folder_path, transform=None):
        self.img_data = pd.read_csv(metadata_csv)
        self.img_folder = img_folder_path
        self.transform = transform

    def __getitem__(self, idx):
        row = self.img_data.iloc[idx]
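        img = Image.open(os.path.join(self.img_folder, row['Filename'])).convert('L')
        # Fall back to a plain tensor conversion so `image` is always defined
        image = self.transform(img) if self.transform else transforms.ToTensor()(img)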

        agent_dict = {'Iodine': 0, 'BiNPs 50nm': 1, 'BiNPs 100nm': 2}
        agent_vector = torch.zeros(len(agent_dict))
        agent_vector[agent_dict[row['Classification']]] = 1

        voltage = torch.tensor([row['Voltage']], dtype=torch.float32)
        time = torch.tensor([row['Time']], dtype=torch.float32)
        cnr = torch.tensor([row['CNR']], dtype=torch.float32)
        snr = torch.tensor([row['SNR']], dtype=torch.float32)

        return image, agent_vector, voltage, time, cnr, snr

    def __len__(self):
        return len(self.img_data)
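
As a minimal sketch of the metadata encoding in `CTDataset`, the snippet below builds the one-hot agent vector and the 5-element metadata vector for a single row using plain Python lists instead of tensors. The `encode_row` helper and the sample voltage and time values are hypothetical. Only the agent-to-index mapping and the column names come from `CTDataset`, and the `[voltage, time, agent]` ordering follows the concatenation in the model's forward pass.

```python
# Agent-to-index mapping copied from CTDataset.__getitem__
AGENT_INDEX = {'Iodine': 0, 'BiNPs 50nm': 1, 'BiNPs 100nm': 2}

def encode_row(row):
    """Hypothetical helper: return (one_hot_agent, metadata_vector) for one CSV row."""
    one_hot = [0.0] * len(AGENT_INDEX)
    one_hot[AGENT_INDEX[row['Classification']]] = 1.0
    # Same ordering as the model input: voltage, time, then the one-hot agent
    metadata = [float(row['Voltage']), float(row['Time'])] + one_hot
    return one_hot, metadata

# Illustrative values only, not taken from the dataset
sample_row = {'Classification': 'BiNPs 50nm', 'Voltage': 80, 'Time': 5}
one_hot, metadata = encode_row(sample_row)
print(one_hot)   # [0.0, 1.0, 0.0]
print(metadata)  # [80.0, 5.0, 0.0, 1.0, 0.0]
```

The metadata vector has length 5 (two scalars plus three agent classes), which is why the model is constructed with `input_dim=5`.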

In [3]:
#################################################################################
# CNN (Convolutional Neural Network) baseline non-causal model
#################################################################################

class CNNBaselineModel(nn.Module):
    """
    Convolutional Neural Network baseline model for predicting SNR from CT images and metadata.

    Args:
        image_channels (int): Number of input image channels. Default is 1 (grayscale).
        input_dim (int): Dimension of metadata input vector (e.g., voltage, time, agent one-hot). Default is 5.

    Architecture:
        - Three convolutional layers with ReLU activations and max pooling.
        - Adaptive average pooling to reduce spatial dimensions to (1, 1).
        - Fully connected layers combining flattened CNN features and metadata.
        - Two separate output heads for multi-task learning:
            * SNR prediction (main task)
            * CNR prediction (auxiliary task)

    Methods:
        forward(image, agent_vector, voltage, time):
            Performs a forward pass through the model.
            
            Args:
                image (Tensor): Input image tensor of shape (B, C, H, W).
                agent_vector (Tensor): One-hot encoded agent vector of shape (B, N).
                voltage (Tensor): Voltage feature tensor of shape (B, 1).
                time (Tensor): Time feature tensor of shape (B, 1).
            
            Returns:
                snr_out (Tensor): Predicted SNR values (B, 1).
                cnr_out (Tensor): Predicted CNR values (B, 1).
    """
    
    def __init__(self, image_channels=1, input_dim=5):  
        super(CNNBaselineModel, self).__init__()

        # CNN layers
        self.conv1 = nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # (B, 128, 1, 1)

        # Fully connected layers after flattening CNN features + metadata
        self.fc1 = nn.Linear(128 + input_dim, 256)
        self.fc2 = nn.Linear(256, 128)

        # Two separate heads for multi-task output
        self.fc_snr = nn.Linear(128, 1)  
        self.fc_cnr = nn.Linear(128, 1) 

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, image, agent_vector, voltage, time):
        x = F.relu(self.conv1(image))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))

        x = self.adaptive_pool(x)  # (B, 128, 1, 1)
        x = torch.flatten(x, 1)    # (B, 128)

        # Metadata 
        meta_input = torch.cat([voltage, time, agent_vector], dim=1)  # (B, 5)

        combined = torch.cat([x, meta_input], dim=1)  # (B, 128 + 5)

        x = F.relu(self.fc1(combined))
        x = F.relu(self.fc2(x))

        # Multi-task outputs
        snr_out = self.fc_snr(x)
        cnr_out = self.fc_cnr(x)

        return snr_out, cnr_out
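
To see how the spatial size shrinks through the convolutional stack, the sketch below traces the side length of a 9x9 input (the size produced by the `Resize((9, 9))` transform in this notebook) using the standard output-size formulas for convolution and pooling. The helper names `conv_out`, `pool_out`, and `trace_spatial_size` are hypothetical and are not part of the model.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output side length of a square Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output side length of MaxPool2d(kernel, stride) without padding."""
    return (size - kernel) // stride + 1

def trace_spatial_size(size):
    """Return (stage, side_length) pairs for the layer order used in forward()."""
    stages = [('input', size)]
    size = conv_out(size); stages.append(('conv1', size))
    size = pool_out(size); stages.append(('pool1', size))
    size = conv_out(size); stages.append(('conv2', size))
    size = pool_out(size); stages.append(('pool2', size))
    size = conv_out(size); stages.append(('conv3', size))
    # AdaptiveAvgPool2d((1, 1)) always reduces to 1x1, whatever the input size
    stages.append(('adaptive_pool', 1))
    return stages

stages = trace_spatial_size(9)
print(stages)
# [('input', 9), ('conv1', 9), ('pool1', 4), ('conv2', 4), ('pool2', 2), ('conv3', 2), ('adaptive_pool', 1)]

fc1_inputs = 128 * 1 * 1 + 5  # 128 pooled channels plus 5 metadata values
print(fc1_inputs)  # 133
```

Because the adaptive pooling layer fixes the output at 1x1, `fc1` always receives 128 + `input_dim` features regardless of the input resolution.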


In [4]:
#################################################################################
# This method trains the baseline CNN model with the given model hyperparameters
#################################################################################

def train_model(model, train_loader, val_loader, num_epochs=150, lr=1e-3, patience=3):
    """
    Train the given model using training and validation data loaders with early stopping.

    Args:
        model (nn.Module): The PyTorch model to train.
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        num_epochs (int, optional): Maximum number of training epochs. Default is 150.
        lr (float, optional): Learning rate for the optimizer. Default is 1e-3.
        patience (int, optional): Number of epochs with no improvement on validation loss before stopping early. Default is 3.

    Returns:
        float: Total training time in seconds.

    Description:
        - Uses Smooth L1 loss for both SNR and CNR predictions.
        - Optimizes with AdamW optimizer with weight decay.
        - Applies cosine annealing learning rate scheduler.
        - Combines the losses as loss_snr + alpha * loss_cnr, with alpha fixed at 0.3 inside the function.
        - Clips gradients to a maximum norm of 1.0 to stabilize training.
        - Implements early stopping based on validation loss not improving for 'patience' epochs.
    """
    
    criterion = nn.SmoothL1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_loss = float('inf')
    patience_counter = 0
    start_time = T.time()
    alpha = 0.3
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0

        for image, agent_vector, voltage, time, cnr, snr in train_loader:
            optimizer.zero_grad()
            snr_pred, cnr_pred = model(image, agent_vector, voltage, time)

            loss_snr = criterion(snr_pred, snr)
            loss_cnr = criterion(cnr_pred, cnr)

            loss = loss_snr + alpha * loss_cnr
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for image, agent_vector, voltage, time, cnr, snr in val_loader:
                snr_pred, cnr_pred = model(image, agent_vector, voltage, time)

                loss_snr = criterion(snr_pred, snr)
                loss_cnr = criterion(cnr_pred, cnr)

                loss = loss_snr + alpha * loss_cnr
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"patience_counter : {patience_counter}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
    end_time = T.time()
    total_time = end_time - start_time
    print(f"⏱️ Training time for this model: {total_time:.2f} seconds")
    return total_time
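
The loss weighting and learning-rate schedule used in `train_model` can be sketched in plain Python for scalar inputs. This is an illustration under stated assumptions, not the PyTorch implementation: `smooth_l1` follows the documented `nn.SmoothL1Loss` formula with its default `beta=1.0` but omits the batch-mean reduction, and `cosine_lr` uses the closed-form cosine annealing expression with a minimum learning rate of 0. The function names and the sample prediction values are hypothetical.

```python
import math

def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 loss for one scalar pair: quadratic near zero, linear beyond beta."""
    diff = abs(pred - target)
    if diff < beta:
        return 0.5 * diff ** 2 / beta
    return diff - 0.5 * beta

def combined_loss(snr_pred, snr, cnr_pred, cnr, alpha=0.3):
    """SNR loss plus the alpha-weighted auxiliary CNR loss, as in train_model."""
    return smooth_l1(snr_pred, snr) + alpha * smooth_l1(cnr_pred, cnr)

def cosine_lr(epoch, base_lr=1e-3, t_max=150, eta_min=0.0):
    """Closed-form cosine annealing from base_lr at epoch 0 down to eta_min at t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Illustrative values only: SNR error of 2.0 (linear regime), CNR error of 0.5 (quadratic regime)
loss = combined_loss(snr_pred=10.0, snr=12.0, cnr_pred=3.0, cnr=3.5)
print(round(loss, 4))  # 1.5375

for epoch in (0, 75, 150):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

With `alpha = 0.3`, the CNR term contributes less than a third of its raw loss, so SNR remains the dominant training signal.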

In [5]:
###############################################################
# Setting the seed value for each training loop
###############################################################

def set_seed(seed):
    """
    Sets the random seed across Python, NumPy, and PyTorch (CPU and GPU) for reproducibility.

    Parameters
    ----------
    seed : int
        The seed value to ensure deterministic behavior across runs.
    """
    
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [6]:
############################################################################################
# get_data_loaders function:
# Loads a CT dataset with images and metadata.
# Splits the dataset into training and validation sets.
# Applies image transformations.
############################################################################################

def get_data_loaders(seed):
    """
    Prepare and return training and validation DataLoaders for the CTDataset.

    Args:
        seed (int): Random seed for reproducibility of dataset splitting and shuffling.

    Returns:
        train_loader (DataLoader): DataLoader for the training dataset subset.
        val_loader (DataLoader): DataLoader for the validation dataset subset.

    Description:
        - Sets manual seed for PyTorch and DataLoader generator to ensure reproducibility.
        - Applies image transformations: resizing to 9x9 and conversion to tensor.
        - Loads dataset from specified CSV metadata and image folder.
        - Splits dataset into training (80%) and validation (20%) subsets using a fixed random seed.
        - Creates DataLoaders with batch size 8; training loader is shuffled, validation loader is not.
    """
    torch.manual_seed(seed)  
    generator = torch.Generator().manual_seed(seed)
    
    transform = transforms.Compose([
        transforms.Resize((9, 9)),
        transforms.ToTensor(),
    ])

    base_path = Path("../dataset")

    dataset = CTDataset(
        metadata_csv=base_path / "final_dataset.csv",
        img_folder_path=base_path / "img",
        transform=transform,
    )

    # Split
    train_indices, val_indices = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=seed)
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)

    # DataLoaders
    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, generator=generator)
    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)

    return train_loader, val_loader

In [7]:
#############################################################################
# Trains the ensemble of baseline models, one model per seed,
# by calling train_model for each with the default hyperparameters
#############################################################################

base_times = []
def train_base_model(num_models=5, base_seed=42):
    """
    Train multiple CNN baseline models with different seeds and return the trained models.

    Args:
        num_models (int): Number of models to train. Defaults to 5.
        base_seed (int): Starting seed for reproducibility. Each model uses base_seed + i.

    Returns:
        base_models_list (list): List of trained CNNBaselineModel instances.

    Description:
        - Iterates over num_models, setting a unique seed for each.
        - For each seed, initializes a CNNBaselineModel and moves it to CPU.
        - Loads training and validation data loaders with the current seed.
        - Trains the model and records training time.
        - Collects all trained models in a list and returns it.
    """
    base_models_list = []
    device = "cpu"
    for i in range(num_models):
        seed = base_seed + i
        print(f"\n🔁 Training model {i+1}/{num_models} with seed {seed}")
        set_seed(seed)

        image_channels = 1 
        base_model = CNNBaselineModel(image_channels=image_channels, input_dim=5).to(device)
        
        train_loader, val_loader = get_data_loaders(seed)

        train_time = train_model(base_model, train_loader, val_loader)
        base_times.append(train_time)
        base_models_list.append(base_model)

    return base_models_list
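
Because `train_base_model` passes a different seed to `get_data_loaders` for each model, each ensemble member is trained on a different 80/20 split. The sketch below illustrates one consequence. It uses a seeded shuffle from the standard library as a stand-in for `train_test_split` (an assumption, so the exact indices will differ from scikit-learn's) and a hypothetical dataset size of 100. Samples held out under one seed typically appear in the training set under another, so metrics computed on a single seed's validation split may include samples that other ensemble members saw during training.

```python
import random

def split_indices(n_samples, seed, test_fraction=0.2):
    """Seeded 80/20 split; a stand-in for train_test_split, not its exact algorithm."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_val = int(n_samples * test_fraction)
    return set(indices[n_val:]), set(indices[:n_val])  # (train, val)

N_SAMPLES = 100  # hypothetical dataset size

train_42, val_42 = split_indices(N_SAMPLES, seed=42)
train_43, val_43 = split_indices(N_SAMPLES, seed=43)

# Validation samples for seed 42 that a model trained with seed 43 has already seen
leaked = val_42 & train_43
print(f"{len(leaked)} of {len(val_42)} seed-42 validation samples are in the seed-43 training set")
```

One way to avoid this overlap would be to fix the split seed for all models and vary only the weight-initialisation and shuffling seed.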

In [8]:
#################################################################################
# Evaluates the ensemble: collects mean and std of predicted SNR plus the target values
#################################################################################

def evaluate_models(ensemble_models, seed, device='cpu'):
    """
    Evaluate an ensemble of models on the validation set and compute mean and std of predictions.

    Args:
        ensemble_models (list): List of trained models to evaluate.
        seed (int): Seed for data loader reproducibility.
        device (str): Device to run the evaluation on ('cpu' or 'cuda'). Defaults to 'cpu'.

    Returns:
        preds_mean (np.ndarray): Mean predictions of the ensemble for each sample.
        preds_std (np.ndarray): Standard deviation of the ensemble predictions for each sample.
        targets (np.ndarray): Ground truth SNR values from the validation set.

    Description:
        - Loads validation data using the given seed.
        - For each batch, moves data to the specified device.
        - Collects predictions from each model in the ensemble without gradient tracking.
        - Computes the mean and standard deviation of predictions across the ensemble.
        - Aggregates predictions and ground truth targets for all validation samples.
    """
    
    _, val_loader = get_data_loaders(seed)

    preds_mean, preds_std, targets = [], [], []

    for image, agent_vector, voltage, time, cnr, snr in val_loader:
        image = image.to(device)
        agent_vector = agent_vector.to(device)
        voltage = voltage.to(device)
        time = time.to(device)
        cnr = cnr.to(device)
        snr = snr.to(device)

        batch_preds = []
        with torch.no_grad():
            for model in ensemble_models:
                model.eval()
                model.to(device)
                output,_ = model(image, agent_vector, voltage, time)
                batch_preds.append(output.cpu())

        batch_preds = torch.stack(batch_preds)  # [num_models, B, 1]
        # squeeze(-1) drops only the trailing dimension, so a final batch of size 1 stays 1-D
        mean_pred = batch_preds.mean(dim=0).squeeze(-1).numpy()   # [B]
        std_pred = batch_preds.std(dim=0).squeeze(-1).numpy()     # [B]

        preds_mean.extend(mean_pred)
        preds_std.extend(std_pred)
        targets.extend(snr.squeeze(-1).cpu().numpy())

    preds_mean = np.array(preds_mean)
    preds_std = np.array(preds_std)
    targets = np.array(targets)

    return preds_mean, preds_std, targets
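
The ensemble aggregation in `evaluate_models` reduces a stack of per-model predictions along the model axis. The sketch below reproduces that reduction with NumPy on hypothetical prediction values. It passes `ddof=1` because `torch.std` applies Bessel's correction by default, whereas NumPy's default is `ddof=0`.

```python
import numpy as np

# Hypothetical SNR predictions from 3 models for a batch of 2 samples: shape (3, 2, 1)
batch_preds = np.array([
    [[1.0], [4.0]],
    [[2.0], [4.0]],
    [[3.0], [4.0]],
])

# Reduce over the model axis, then drop the trailing singleton dimension
mean_pred = batch_preds.mean(axis=0).squeeze(-1)
# ddof=1 matches the default (sample) standard deviation of torch.std
std_pred = batch_preds.std(axis=0, ddof=1).squeeze(-1)

print(mean_pred)  # [2. 4.]
print(std_pred)   # [1. 0.]
```

The per-sample standard deviation is zero where all models agree and grows with disagreement, which is why it serves as a simple uncertainty estimate for the ensemble.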


In [9]:
#################################################################################
# Evaluates each model individually by comparing its predicted SNR values
# against the target values (MAE, RMSE, R2)
#################################################################################

def evaluate_individual_models(ensemble_models, seed, device='cpu'):
    """
    Evaluate each model in an ensemble individually on the validation set and compute regression metrics.

    Args:
        ensemble_models (list): List of trained models to evaluate.
        seed (int): Seed for data loader reproducibility.
        device (str): Device to run the evaluation on ('cpu' or 'cuda'). Defaults to 'cpu'.

    Returns:
        model_metrics (list of dicts): List containing metrics (MAE, RMSE, R2) for each model.

    Description:
        - Loads validation data using the given seed.
        - For each model, moves it to the specified device and sets to evaluation mode.
        - Iterates over validation batches, making predictions without gradient computation.
        - Collects predictions and ground truth targets.
        - Calculates Mean Absolute Error (MAE), Root Mean Squared Error (RMSE), and R² score for each model.
        - Prints the metrics for each model and returns a list summarizing all metrics.
    """

    _, val_loader = get_data_loaders(seed)
    
    model_metrics = []

    for model_index, model in enumerate(ensemble_models):
        model.eval()
        model.to(device)

        all_preds = []
        all_targets = []

        with torch.no_grad():
            for image, agent_vector, voltage, time, _, snr in val_loader:
                image = image.to(device)
                agent_vector = agent_vector.to(device)
                voltage = voltage.to(device)
                time = time.to(device)
                snr = snr.to(device)

                pred, _ = model(image, agent_vector, voltage, time)
                # squeeze(-1) keeps a 1-D array even when the last batch holds one sample
                all_preds.extend(pred.squeeze(-1).cpu().numpy())
                all_targets.extend(snr.squeeze(-1).cpu().numpy())

        # Convert to NumPy arrays
        all_preds = np.array(all_preds)
        all_targets = np.array(all_targets)

        # Compute metrics
        mae = mean_absolute_error(all_targets, all_preds)
        rmse = np.sqrt(mean_squared_error(all_targets, all_preds))
        r2 = r2_score(all_targets, all_preds)

        model_metrics.append({
            'Model': model_index + 1,
            'MAE': mae,
            'RMSE': rmse,
            'R2': r2
        })

        print(f"Model {model_index+1} | MAE: {mae:.3f}, RMSE: {rmse:.3f}, R2: {r2:.3f}")

    return model_metrics
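
For reference, the three metrics reported by `evaluate_individual_models` can be computed directly from their definitions. The sketch below uses NumPy and a small set of made-up targets and predictions. The `regression_metrics` helper is hypothetical and stands in for the scikit-learn functions used in the notebook.

```python
import numpy as np

def regression_metrics(targets, preds):
    """Return (MAE, RMSE, R2) computed from their textbook definitions."""
    targets = np.asarray(targets, dtype=float)
    preds = np.asarray(preds, dtype=float)
    errors = targets - preds
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    ss_res = np.sum(errors ** 2)                       # residual sum of squares
    ss_tot = np.sum((targets - targets.mean()) ** 2)   # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return mae, rmse, r2

# Made-up values: three exact predictions and one that is off by 2
mae, rmse, r2 = regression_metrics([1, 2, 3, 4], [1, 2, 3, 6])
print(f"MAE: {mae:.3f}, RMSE: {rmse:.3f}, R2: {r2:.3f}")
# MAE: 0.500, RMSE: 1.000, R2: 0.200
```

RMSE exceeds MAE here because squaring gives the single large error more weight, the same pattern seen in the per-model results where RMSE is consistently higher than MAE.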


In [10]:
##########################################
# Train ensemble baseline models
##########################################

base_models_list = train_base_model(num_models=5)


🔁 Training model 1/5 with seed 42
Epoch 1/150, Train Loss: 159.2284, Validation Loss: 170.3130
Epoch 2/150, Train Loss: 157.1248, Validation Loss: 170.4923
patience_counter : 1
Epoch 3/150, Train Loss: 155.6950, Validation Loss: 167.2182
Epoch 4/150, Train Loss: 151.9552, Validation Loss: 160.7836
Epoch 5/150, Train Loss: 145.9545, Validation Loss: 152.8650
Epoch 6/150, Train Loss: 140.9492, Validation Loss: 173.6997
patience_counter : 1
Epoch 7/150, Train Loss: 132.5173, Validation Loss: 138.0866
Epoch 8/150, Train Loss: 126.2785, Validation Loss: 131.3831
Epoch 9/150, Train Loss: 124.1043, Validation Loss: 128.1587
Epoch 10/150, Train Loss: 115.2055, Validation Loss: 114.3625
Epoch 11/150, Train Loss: 109.7634, Validation Loss: 127.5375
patience_counter : 1
Epoch 12/150, Train Loss: 107.3180, Validation Loss: 112.2695
Epoch 13/150, Train Loss: 105.7968, Validation Loss: 109.7242
Epoch 14/150, Train Loss: 105.2045, Validation Loss: 104.7327
Epoch 15/150, Train Loss: 105.1341, Validat
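
The `patience_counter` lines in the log above can be reproduced by replaying the early-stopping rule from `train_model` on the printed validation losses. The sketch below is a minimal pure-Python version of that rule. The `trace_patience` helper is hypothetical, and only epochs 1 to 14 are used because the log for epoch 15 is cut off.

```python
def trace_patience(val_losses, patience=3):
    """Replay the early-stopping rule; return (events, stop_epoch) with 1-based epochs.

    events lists (epoch, counter) each time the validation loss fails to improve.
    stop_epoch is None if the patience limit is never reached.
    """
    best = float('inf')
    counter = 0
    events = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, counter = loss, 0   # improvement resets the counter
        else:
            counter += 1
            events.append((epoch, counter))
            if counter >= patience:
                return events, epoch
    return events, None

# Validation losses for epochs 1-14 of model 1, copied from the log above
val_losses = [170.3130, 170.4923, 167.2182, 160.7836, 152.8650, 173.6997, 138.0866,
              131.3831, 128.1587, 114.3625, 127.5375, 112.2695, 109.7242, 104.7327]

events, stop_epoch = trace_patience(val_losses)
print(events)      # [(2, 1), (6, 1), (11, 1)]
print(stop_epoch)  # None
```

Each non-improving epoch is followed by a new best loss, so the counter never passes 1 in this stretch and training continues.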

In [11]:
#####################################################################
# Evaluate the individual models for CNN baseline model
#####################################################################

ind_metric = evaluate_individual_models(ensemble_models=base_models_list,seed=42, device='cpu')

Model 1 | MAE: 97.321, RMSE: 142.713, R2: 0.671
Model 2 | MAE: 96.105, RMSE: 141.306, R2: 0.677
Model 3 | MAE: 101.242, RMSE: 143.305, R2: 0.668
Model 4 | MAE: 120.492, RMSE: 154.623, R2: 0.613
Model 5 | MAE: 119.472, RMSE: 161.688, R2: 0.577


In [12]:
#####################################################################
# Display the metrics for Baseline CNN model 
#####################################################################

base_preds_list, base_std_list, base_targets_list = evaluate_models(
    ensemble_models=base_models_list, seed=42, device='cpu'
)

base_mae = mean_absolute_error(base_targets_list, base_preds_list)
base_rmse = np.sqrt(mean_squared_error(base_targets_list, base_preds_list))
base_r2 = r2_score(base_targets_list, base_preds_list)
base_avg_uncertainty = base_std_list.mean()
print("*****************CNN Baseline Model*********************")
print("\nEvaluation results:")
print(f"MAE: {base_mae:.4f} \nRMSE: {base_rmse:.4f} \nR²: {base_r2:.4f}")
print("********************************************************")

*****************CNN Baseline Model*********************

Evaluation results:
MAE: 94.7015 
RMSE: 141.2704 
R²: 0.6773
********************************************************
