In [1]:
import pandas as pd
import numpy as np
import torch
from typing import Literal
import torch.nn as nn
import ast
import torch.optim as optim

import random

import torchvision.transforms as transforms
import torchvision.models as models

from IPython.display import clear_output

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.experimental import enable_iterative_imputer  # This is necessary to use IterativeImputer
from sklearn.impute import IterativeImputer
import copy

from scipy.stats import spearmanr, pearsonr
from scipy import stats
import pathlib

from tqdm import tqdm

import matplotlib.pyplot as plt

from IPython.display import clear_output


## Prepare dataset

In [None]:
# Relative folder for the training data.
# NOTE(review): no other cell visible in this notebook reads `input_folder` — confirm it is still needed.
input_folder = pathlib.Path('./training-data')

### Prepare clinical data

In [None]:
##### Load the clinical data stored in "clinical.csv" and impute missing data

clinical_df = pd.read_csv('./clinical.csv', index_col=0)

# Encode 'Sex' as a binary variable (1 for Male, 0 for Female).
# Values other than 'Male'/'Female' are passed through unchanged.
clinical_df['Sex'] = np.where(clinical_df['Sex'] == 'Male', 1, clinical_df['Sex'])
clinical_df['Sex'] = np.where(clinical_df['Sex'] == 'Female', 0, clinical_df['Sex'])

# Columns to impute (everything except the 'PatientID' key).
# Note: Index.difference returns the labels sorted; the same ordering is used on
# both sides of the assignment below, so columns stay aligned.
columns_to_impute = clinical_df.columns.difference(['PatientID'])

# Multivariate imputation with a fixed seed so the notebook is reproducible.
imputer = IterativeImputer(max_iter=10, random_state=0)

# Apply the imputer only to the columns that need imputation
clinical_df[columns_to_impute] = imputer.fit_transform(clinical_df[columns_to_impute])

# Snapshot of the imputed table; `clinical_df` is modified again by a later cell.
clinical_df_orig = clinical_df.copy()

clinical_df

In [None]:
##### Load the body composition data stored in "body_composition.csv"

known_patient_ids = clinical_df.PatientID

targets_df_new = pd.read_csv('./body_composition.csv', index_col=0)
# Keep only rows whose index label AND PersonId both appear in the clinical table.
targets_df_new = targets_df_new[targets_df_new.index.isin(known_patient_ids)]
targets_df_new = targets_df_new[targets_df_new.PersonId.isin(known_patient_ids)]

# Attach each patient's height (right join keeps every body-composition row)
# and re-index the result by PatientID.
targets_df_new = (
    clinical_df[['PatientID', 'Height_in_meters']]
    .merge(
        targets_df_new.rename(columns={'PersonId': 'PatientID'}),
        on='PatientID',
        how='right',
    )
    .set_index('PatientID')
)

# Height-normalised indices: measurement / height^2.
# 'Area' columns are processed first; when 'Volume' columns are also present
# they overwrite the same *Index columns.
height_squared = targets_df_new['Height_in_meters'] * targets_df_new['Height_in_meters']
for measurement in ('Area', 'Volume'):
    if f'VisceralFat{measurement}' in targets_df_new.columns:
        for tissue in ('VisceralFat', 'SkeletalMuscle', 'SubcutaneousFat', 'FatFree'):
            targets_df_new[f'{tissue}Index'] = targets_df_new[f'{tissue}{measurement}'] / height_squared

# Height was only needed for the normalisation above.
targets_df_new = targets_df_new.drop(columns=['Height_in_meters'])
targets_df_new

In [None]:
#### Load the table with radiograph paths
# NOTE(review): the file name has no extension — confirm it matches the file on disk.
radiograph_paths = pd.read_csv('./radiograph_paths',index_col = 0)
# NOTE(review): this assigns the column to itself and therefore changes nothing;
# presumably a placeholder for a path rewrite — confirm or remove.
radiograph_paths['DX_path'] = radiograph_paths['DX_path']
radiograph_paths

### Train test split

In [None]:
# Split the target rows 80/20 with a fixed seed; only the index labels of the
# two halves are used to build the final frames below.
split_train_rows, split_test_rows = train_test_split(
    targets_df_new,
    test_size=0.2,
    random_state=25,
)

in_train = targets_df_new.index.isin(split_train_rows.index)
in_test = targets_df_new.index.isin(split_test_rows.index)
train_df = targets_df_new[in_train]
test_df = targets_df_new[in_test]

### Confirm no shared Ids
assert not test_df.index.isin(train_df.index).any()
print('No data-leak detected. yay...')

train_df.head()

### Define dataset

In [7]:
def normalize_image(image):
    """Linearly rescale an array to the 0-255 range and return it as uint8.

    Parameters
    ----------
    image : np.ndarray
        Numeric array of any shape.

    Returns
    -------
    np.ndarray
        uint8 array of the same shape, with the minimum mapped to 0 and the
        maximum to 255. A constant image (max == min) yields all zeros instead
        of dividing by zero.
    """
    min_val = np.min(image)
    max_val = np.max(image)
    value_range = max_val - min_val
    if value_range == 0:
        # Flat image: avoid 0/0 -> NaN -> undefined uint8 cast.
        return np.zeros(np.shape(image), dtype=np.uint8)
    normalized_image = (image - min_val) / value_range * 255
    return normalized_image.astype(np.uint8)


def to_tensor(img):
    """Return a float32 NumPy copy of `img`.

    Despite the name, the result is a NumPy array, not a torch tensor; no
    transposition or value scaling is applied.
    """
    as_float32 = np.array(img, dtype=np.float32)
    return as_float32

def pad_to_square(img):
    """Pad a 2-D array along its shorter axis so the result is square.

    The padding is split as evenly as possible between both sides (the extra
    row/column, if any, goes to the trailing side) and is filled with the mean
    value of the input.
    """
    fill_value = img.mean()
    rows, cols = img.shape

    # Number of rows/columns missing on the short axis.
    shortfall = abs(rows - cols)
    leading, trailing = shortfall // 2, (shortfall + 1) // 2

    if rows > cols:
        # Taller than wide: add columns.
        pad_widths = ((0, 0), (leading, trailing))
    else:
        # Wider than tall (or already square): add rows.
        pad_widths = ((leading, trailing), (0, 0))

    return np.pad(img, pad_widths, mode='constant', constant_values=fill_value)

# Side length (in pixels) every image is resized to.
image_size = 512

# Training pipeline: resize, then light random augmentation.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(image_size, image_size)),
        # Random rotation of up to +/-10 degrees.
        transforms.RandomRotation(degrees=10),
        # Random left-right flip.
        transforms.RandomHorizontalFlip(),
        # Small random brightness/contrast/saturation/hue perturbation.
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: deterministic resize to image_size x image_size only.
test_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
    ]
)

In [8]:
class InMemoryDictDatasetV2(Dataset):
    """Dataset that preloads every radiograph into memory, keyed by patient id.

    Parameters
    ----------
    targets_df : pd.DataFrame
        Regression targets, indexed by patient id.
    cxr_paths_df : pd.DataFrame
        Indexed by patient id; the columns ``DX_good`` and ``DX_PA_reverse``
        are read and iterated as collections of ``.npy`` paths (or ``None``).
    clinical_df : pd.DataFrame
        Clinical covariates with a ``PatientID`` column (it becomes the index).
    transform : callable, optional
        Applied to the image after normalisation/padding.
    pad : bool
        If true, the image is padded to a square with ``pad_to_square``.
    normalize : {'minmax', 'zscore'}
        Per-image intensity normalisation; any other value disables it.
    random_image : bool
        If true a random image of the patient is returned on each access,
        otherwise always the first one.

    Patients without an entry in ``cxr_paths_df`` or without a single loadable
    image are removed from ``self.targets_df`` so that ``__getitem__`` never
    hits an empty image list.
    """

    def __init__(self, targets_df: pd.DataFrame, cxr_paths_df: pd.DataFrame, clinical_df: pd.DataFrame, transform=None, pad = True, normalize: Literal['minmax', 'zscore'] = "minmax", random_image = True):
        self.targets_df = targets_df
        self.cxr_paths_df = cxr_paths_df
        self.clinical_df = clinical_df.set_index('PatientID')
        self.transform = transform
        self.normalize = normalize
        self.pad = pad
        self.target_cols = targets_df.columns.to_list()
        # Full (un-dropped) target table; lookups are by patient id, so extra rows are harmless.
        self.trimmed_targets_df = targets_df
        self.images = {}
        self.random_image = random_image

        # Build the membership set once instead of converting the index to a
        # list on every iteration.
        ids_with_paths = set(self.cxr_paths_df.index)
        unusable_ids = []
        for pt_id in tqdm(self.targets_df.index):
            if pt_id not in ids_with_paths:
                unusable_ids.append(pt_id)
                continue

            path_row = self.cxr_paths_df.loc[pt_id]
            np_files = path_row.DX_good
            np_files_reverse = path_row.DX_PA_reverse

            loaded = []
            if np_files is not None:
                for file in np_files:
                    loaded.append(np.load(file))
            if np_files_reverse is not None:
                for file in np_files_reverse:
                    # These arrays are stored with inverted sign; flip them back.
                    loaded.append(-1 * (np.load(file)))

            if not loaded:
                # No image at all: indexing/choosing from [] would fail later.
                unusable_ids.append(pt_id)
                continue
            self.images[pt_id] = loaded

        if unusable_ids:
            # dict.fromkeys de-duplicates while keeping order.
            self.targets_df = self.targets_df.drop(index=list(dict.fromkeys(unusable_ids)))

    def __len__(self):
        """Number of patients that have at least one image."""
        return len(self.targets_df)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'target', 'clinical' and 'id'."""
        pt_id = self.targets_df.index[idx]
        if self.random_image:
            image = random.choice(self.images[pt_id])
        else:
            image = self.images[pt_id][0]

        if self.normalize == "zscore":
            std = np.std(image)
            centred = image - np.mean(image)
            # Constant image: return zeros rather than dividing by zero.
            image = centred / std if std > 0 else centred * 0.0
        elif self.normalize == "minmax":
            lowest = np.min(image)
            span = np.max(image) - lowest
            shifted = image - lowest
            # Constant image: return zeros rather than dividing by zero.
            image = shifted / span if span > 0 else shifted * 0.0

        if self.pad:
            image = pad_to_square(image)
        if self.transform:
            image = self.transform(image)

        target = self.trimmed_targets_df.loc[pt_id].values.astype(np.float32)
        # Missing clinical values are replaced by 0 before conversion.
        clinical_vars = torch.tensor(self.clinical_df.loc[pt_id].fillna(0).values, dtype=torch.float32).flatten()
        return {
            'image' : image,
            'target' : target,
            'clinical': clinical_vars,
            'id' : pt_id,
        }

    def set_target_columns(self, target_cols):
        """Restrict the returned targets to `target_cols` (a list of column names)."""
        self.target_cols =  target_cols
        self.trimmed_targets_df = self.targets_df[target_cols]

In [None]:
# Clinical rows that belong to training patients; used only to fit the scaler
# so that no test-set statistics leak into the scaling.
clinical_train_df = clinical_df.loc[clinical_df.PatientID.isin(train_df.index)]

# Standardise the continuous clinical covariates (fit on train, applied to all rows).
# Note: this overwrites the columns of `clinical_df` in place.
scaler = StandardScaler()
list_cont_features = ['XRAge', 'Height_in_meters', 'Weight_in_kg']
scaler.fit(clinical_train_df[list_cont_features])
clinical_df[list_cont_features] = scaler.transform(clinical_df[list_cont_features])

# Standardise the regression targets with statistics from the training split.
# `target_scaler` is also read by LateModelWrapper.combine to undo this scaling.
target_scaler = StandardScaler()
target_scaler.fit(train_df)
scaled_train_array = target_scaler.transform(train_df)
scaled_train_df = pd.DataFrame(scaled_train_array, columns=train_df.columns, index=train_df.index)
scaled_test_array= target_scaler.transform(test_df)
scaled_test_df = pd.DataFrame(scaled_test_array, columns=train_df.columns, index=test_df.index)

# NOTE(review): `cleaned_xr_v2` is not defined in any cell visible in this notebook;
# on a fresh kernel these lines raise NameError. Presumably it is derived from
# `radiograph_paths` (the dataset reads its `DX_good` / `DX_PA_reverse` columns) — confirm and add the cell.
# Train datasets use the augmenting transform and a random image per access;
# test datasets use the deterministic transform and always the first image.
train_minmax_dataset = InMemoryDictDatasetV2(scaled_train_df, cleaned_xr_v2, clinical_df, transform=transform, pad = False, normalize = "minmax", random_image = True)
test_minmax_dataset = InMemoryDictDatasetV2(scaled_test_df, cleaned_xr_v2, clinical_df, transform=test_transform, pad = False, normalize = "minmax", random_image = False)

train_zscore_dataset = InMemoryDictDatasetV2(scaled_train_df, cleaned_xr_v2, clinical_df, transform=transform, pad = False, normalize = "zscore", random_image = True)
test_zscore_dataset = InMemoryDictDatasetV2(scaled_test_df, cleaned_xr_v2, clinical_df, transform=test_transform, pad = False, normalize = "zscore", random_image = False)


# plt.imshow(normalize_image(test_dataset[0]['image'].squeeze().numpy()), cmap='gray')

In [None]:
# Report how many patients ended up in each split after the dataset dropped
# patients without radiograph entries.
print(f"""
train: {len(train_minmax_dataset)}
test: {len(test_minmax_dataset)}
""")

In [None]:
# sample image
# plt.imshow(normalize_image(test_dataset[2]['image'].squeeze().numpy()), cmap='gray')

## Define models

In [14]:
class RegressionModel_EarlyFusion(nn.Module):
    """Early-fusion regressor: clinical features are projected to an image-sized
    plane and added pixel-wise to the radiograph before the CNN backbone.

    Parameters
    ----------
    base_model : nn.Module
        Backbone exposing ``conv1`` (a Conv2d) and ``fc`` (a Linear); its first
        convolution is replaced by a single-input-channel one and its ``fc`` by
        an identity.
    output_size : int
        Number of regression outputs.
    num_additional_features : int
        Length of the clinical feature vector.
    clin_dropout : float
        Probability, per sample and only in training mode, of zeroing the
        projected clinical plane.
    image_dropout : float
        Stored on the module but not used by ``forward``.

    Relies on the module-level ``image_size`` for the plane's side length.
    """

    def __init__(self,
                 base_model,
                 output_size,
                 num_additional_features,
                 clin_dropout = 0,
                 image_dropout = 0,
                ):
        super().__init__()
        self.model = base_model

        # Swap the stem for a single-channel convolution with the same geometry.
        stem = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            1,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )

        backbone_width = self.model.fc.in_features
        # Clinical vector -> one value per pixel.
        self.clin_fc = nn.Linear(num_additional_features, image_size*image_size)
        self.fc = nn.Linear(backbone_width, output_size)
        # The backbone now returns its pooled features unchanged.
        self.model.fc = nn.Identity()

        self.clin_dropout = clin_dropout
        self.image_dropout = image_dropout

    def forward(self, x_image, x_clinical):
        """Return predictions of shape (batch, output_size)."""
        plane_shape = (-1, 1, image_size, image_size)
        clinical_plane = self.clin_fc(x_clinical).view(*plane_shape)
        image_plane = x_image.view(*plane_shape)

        if self.training:
            # One Bernoulli keep/drop decision per (sample, channel), broadcast
            # over the spatial dimensions.
            keep_prob = torch.full((clinical_plane.size(0), clinical_plane.size(1)), 1 - self.clin_dropout)
            keep_mask = torch.bernoulli(keep_prob).to(clinical_plane.device)
            clinical_plane = clinical_plane * keep_mask.unsqueeze(2).unsqueeze(3)

        fused = image_plane + clinical_plane
        return self.fc(self.model(fused))

In [15]:
class RegressionModel_InterFusion(nn.Module):
    """Intermediate-fusion regressor combining CNN image features with a linear
    projection of the clinical features.

    Parameters
    ----------
    base_model : nn.Module
        Backbone exposing ``conv1`` (Conv2d) and ``fc`` (Linear). ``conv1`` is
        replaced by a single-input-channel convolution and ``fc`` by an identity.
    output_size : int
        Number of regression outputs.
    num_additional_features : int
        Length of the clinical feature vector.
    clin_dropout, image_dropout : float
        In training mode, probability (per forward call) of ignoring that
        modality. A value of exactly 1 disables the modality in every mode.
        Both are ignored when ``concatenate_modes`` is true.
    freeze_image_net : bool
        If true, every backbone parameter (including the new ``conv1``) is frozen.
    concatenate_modes : bool
        If true, the two feature vectors are concatenated and fed to ``fc_con``;
        otherwise they are summed and fed to ``fc``.
    """

    def __init__(self, base_model,
                 output_size,
                 num_additional_features,
                 clin_dropout = 0,
                 image_dropout = 0,
                 freeze_image_net = False,
                 concatenate_modes = False,
                ):
        super().__init__()
        self.model = base_model

        # Single-channel stem with the same geometry as the original one.
        stem = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            1,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )

        if freeze_image_net:
            for backbone_param in self.model.parameters():
                backbone_param.requires_grad = False

        feature_width = self.model.fc.in_features
        # Project the clinical vector to the backbone's feature width.
        self.clin_fc = nn.Linear(num_additional_features, feature_width)
        self.fc = nn.Linear(feature_width, output_size)
        self.model.fc = nn.Identity()

        self.clin_dropout = clin_dropout
        self.image_dropout = image_dropout
        self.concatenate_modes = concatenate_modes
        if concatenate_modes:
            print("concatenating instead of adding.")
            self.fc_con = nn.Linear(feature_width*2, output_size)

    def forward(self, x_image, x_clinical):
        """Return predictions of shape (batch, output_size)."""
        if self.concatenate_modes:
            image_features = self.model(x_image)
            clinical_features = self.clin_fc(x_clinical)
            return self.fc_con(torch.cat([image_features, clinical_features], dim =1))

        # Image-only path.
        ignore_clinical = (self.training and random.random() < self.clin_dropout) or self.clin_dropout == 1
        if ignore_clinical:
            return self.fc(self.model(x_image))

        # Clinical-only path (only considered when the clinical data was kept).
        ignore_image = (self.training and random.random() < self.image_dropout) or self.image_dropout == 1
        if ignore_image:
            return self.fc(self.clin_fc(x_clinical))

        # Both modalities: sum the feature vectors.
        image_features = self.model(x_image)
        clinical_features = self.clin_fc(x_clinical)
        return self.fc(image_features + clinical_features)

In [None]:
class RegressionModel_LateFusion(nn.Module):
    """Single-modality branch model used for late fusion.

    The implementation mirrors ``RegressionModel_InterFusion``: a CNN backbone
    for the image and a linear projection for the clinical vector, with
    ``clin_dropout`` / ``image_dropout`` set to 1 to switch a modality off.

    Parameters
    ----------
    base_model : nn.Module
        Backbone exposing ``conv1`` (Conv2d) and ``fc`` (Linear); ``conv1`` is
        replaced by a single-input-channel convolution, ``fc`` by an identity.
    output_size : int
        Number of regression outputs.
    num_additional_features : int
        Length of the clinical feature vector.
    clin_dropout, image_dropout : float
        Training-time probability of ignoring that modality for a forward
        call; exactly 1 ignores it in every mode. Unused when concatenating.
    freeze_image_net : bool
        Freeze every backbone parameter (including the replaced ``conv1``).
    concatenate_modes : bool
        Concatenate the feature vectors (head ``fc_con``) instead of summing
        them (head ``fc``).
    """

    def __init__(self, base_model,
                 output_size,
                 num_additional_features,
                 clin_dropout = 0,
                 image_dropout = 0,
                 freeze_image_net = False,
                 concatenate_modes = False,
                ):
        super().__init__()
        self.model = base_model

        first_conv = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            1,
            first_conv.out_channels,
            kernel_size=first_conv.kernel_size,
            stride=first_conv.stride,
            padding=first_conv.padding,
            bias=False,
        )

        if freeze_image_net:
            for weight in self.model.parameters():
                weight.requires_grad = False

        n_features = self.model.fc.in_features
        self.clin_fc = nn.Linear(num_additional_features, n_features)
        self.fc = nn.Linear(n_features, output_size)
        # Backbone now outputs its pooled features directly.
        self.model.fc = nn.Identity()

        self.clin_dropout = clin_dropout
        self.image_dropout = image_dropout
        self.concatenate_modes = concatenate_modes
        if concatenate_modes:
            print("concatenating instead of adding.")
            self.fc_con = nn.Linear(n_features*2, output_size)

    def forward(self, x_image, x_clinical):
        """Return predictions of shape (batch, output_size)."""
        if self.concatenate_modes:
            from_image = self.model(x_image)
            from_clinical = self.clin_fc(x_clinical)
            joined = torch.cat([from_image, from_clinical], dim =1)
            return self.fc_con(joined)

        drop_clinical = (self.training and random.random() < self.clin_dropout) or self.clin_dropout == 1
        if drop_clinical:
            # Image features only.
            return self.fc(self.model(x_image))

        drop_image = (self.training and random.random() < self.image_dropout) or self.image_dropout == 1
        if drop_image:
            # Clinical features only.
            return self.fc(self.clin_fc(x_clinical))

        from_image = self.model(x_image)
        from_clinical = self.clin_fc(x_clinical)
        return self.fc(from_image + from_clinical)


class LateFusionFC(nn.Module):
    """Linear head mapping two concatenated n-dimensional predictions to n outputs."""

    def __init__(self, n):
        super().__init__()
        # Input is twice as wide as the output: two prediction vectors side by side.
        self.fc = nn.Linear(2 * n, n)

    def forward(self, x):
        """Apply the linear layer to `x` (last dimension must be 2 * n)."""
        return self.fc(x)

In [None]:
def pearsonr_ci(x,y,alpha=0.05):
    ''' calculate Pearson correlation along with the confidence interval using scipy and numpy

    The interval is obtained with the Fisher z-transformation.

    Parameters
    ----------
    x, y : iterable object such as a list or np.array
      Input for correlation calculation
    alpha : float
      Significance level. 0.05 by default
    Returns
    -------
    r : float
      Pearson's correlation coefficient
    pval : float
      The corresponding p value
    lo, hi : float
      The lower and upper bound of confidence intervals. With three or fewer
      observations the standard error is undefined and the widest possible
      interval (-1, 1) is returned.
    '''
    # Accept plain lists/tuples as promised by the docstring (`.size` needs an array).
    x = np.asarray(x)
    y = np.asarray(y)

    r, p = pearsonr(x,y)
    if x.size <= 3:
        # 1/sqrt(n-3) is infinite or undefined: the data carry no interval information.
        return r, p, -1.0, 1.0

    # arctanh(+/-1) is +/-inf, which tanh maps back to +/-1; silence the warning.
    with np.errstate(divide='ignore'):
        r_z = np.arctanh(r)
    se = 1/np.sqrt(x.size-3)
    z = stats.norm.ppf(1-alpha/2)
    lo_z, hi_z = r_z-z*se, r_z+z*se
    lo, hi = np.tanh((lo_z, hi_z))
    return r, p, lo, hi

In [None]:
def get_model_by_name(model_name: str, pretrained=False):
    """
    Returns the corresponding model object given a model name.
    
    Parameters:
    - model_name (str): Name of the model (e.g., 'resnet18', 'resnet50').
    - pretrained (bool): If True, returns a model pre-trained on ImageNet
      (the 'IMAGENET1K_V1' weights, i.e. what the legacy ``pretrained=True``
      flag selected). Default is False.
    
    Returns:
    - model: The corresponding PyTorch model.

    Raises:
    - ValueError: If `model_name` is not one of the supported names.

    Note: the fusion models in this notebook read ``conv1`` and ``fc`` from the
    backbone, so only architectures exposing both attributes can be used there.
    """
    model_dict = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
        'alexnet': models.alexnet,
        'vgg16': models.vgg16,
        'squeezenet': models.squeezenet1_0,
        'densenet': models.densenet121,
        'inception': models.inception_v3,
        'googlenet': models.googlenet,
        'shufflenet': models.shufflenet_v2_x1_0,
        'mobilenet': models.mobilenet_v2,
        'resnext50_32x4d': models.resnext50_32x4d,
        'resnext101_32x8d': models.resnext101_32x8d,
        'wide_resnet50_2': models.wide_resnet50_2,
        'wide_resnet101_2': models.wide_resnet101_2,
        'mnasnet': models.mnasnet1_0,
        # Add more models as needed
    }

    if model_name not in model_dict:
        raise ValueError(f"Model '{model_name}' is not recognized. Available models are: {list(model_dict.keys())}")

    # `pretrained=` is deprecated in torchvision; `weights=` is the supported
    # way to request (or not request) ImageNet weights.
    weights = "IMAGENET1K_V1" if pretrained else None
    return model_dict[model_name](weights=weights)


## Evaluation logic

In [None]:

def load_model_weights(model, model_weights: pathlib.Path):
    """Best-effort load of a saved state dict into `model`.

    Parameters
    ----------
    model : nn.Module
        Target model; a ``nn.DataParallel`` wrapper is unwrapped first.
    model_weights : pathlib.Path
        Path of a file written with ``torch.save(state_dict, ...)``.

    Keys are matched non-strictly, so entries missing on either side are
    ignored. Any failure is printed rather than raised, so callers should
    check the printed message: on failure the model keeps its current weights.
    """
    try:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # host; load_state_dict copies values into the model's own tensors.
        state_dict = torch.load(model_weights, map_location='cpu', weights_only=True)
        if isinstance(model, nn.DataParallel):
            model.module.load_state_dict(state_dict, strict=False)
        else:
            # NOTE(review): 9 is presumably the number of targets the checkpoint
            # was trained with; with a different target count the output head is
            # discarded so its shape cannot clash — confirm.
            if len(train_df.columns)!=9:
                # pop() tolerates checkpoints that carry no 'fc' head at all.
                state_dict.pop('fc.weight', None)
                state_dict.pop('fc.bias', None)
            model.load_state_dict(state_dict, strict=False)
        print(f'Weights successfully loaded from {model_weights}')
    except Exception as e:
        print(f"Failed to load imaging weights.\n{e}")

def load_clinical_model_weights(model, clinical_model_weights: pathlib.Path, subset=True):
    """Best-effort load of clinical-branch weights into `model`.

    Parameters
    ----------
    model : nn.Module
        Model receiving the weights (matched non-strictly).
    clinical_model_weights : pathlib.Path
        Path of a saved state dict.
    subset : bool
        If true, only entries whose key contains 'clin_fc' are loaded;
        otherwise the whole state dict is loaded.

    Failures are printed, not raised.
    """
    try:
        saved_state = torch.load(clinical_model_weights, weights_only=True)

        if not subset:
            # for late
            model.load_state_dict(saved_state, strict=False)
        else:
            clin_fc_only = {}
            for key, tensor in saved_state.items():
                if 'clin_fc' in key:
                    clin_fc_only[key] = tensor
            model.load_state_dict(clin_fc_only, strict =False)

        print(f'Clinical weights successfully loaded from {clinical_model_weights}')
    except Exception as e:
        print(f"Failed to load clinical weights.\n{e}")

class SimpleModelWrapper:
    """Adapter that lets the training loop drive a single two-input model."""

    def __init__(self, model):
        self.model = model

    def to_device(self, device):
        """Wrap the model in DataParallel, move it to `device`, and return it."""
        parallel_model = nn.DataParallel(self.model)
        self.model = parallel_model.to(device)
        return self.model

    def evaluate(self, model, inputs, clinical_vars):
        """Run `model` on the image batch and the clinical batch."""
        predictions = model(inputs, clinical_vars)
        return predictions

    def combine(self, all_targets):
        """Stack a list of per-batch arrays along the first axis."""
        stacked = np.concatenate(all_targets, axis=0)
        return stacked

class LateModelWrapper:
    """Adapter for late fusion: two frozen branch models feed a trainable head.

    Parameters
    ----------
    imaging_model, clinical_model : nn.Module
        Branch models called as ``branch(inputs, clinical_vars)``. They are
        frozen and kept in eval mode.
    model : nn.Module
        Fusion head called on the concatenated branch outputs; this is the
        only module that is trained.
    """

    def __init__(self, imaging_model, clinical_model, model):
        self.imaging_model = imaging_model
        self.clinical_model = clinical_model
        self.model = model

    def to_device(self,device):
        """Freeze the branches, move everything to `device`, return the head."""
        for branch in (self.imaging_model, self.clinical_model):
            for param in branch.parameters():
                param.requires_grad = False

        self.model = nn.DataParallel(self.model).to(device)
        # The training loop only toggles train/eval on the returned head, so
        # the frozen branches are put in eval mode here; otherwise their
        # normalisation layers would keep using and updating batch statistics.
        self.clinical_model = nn.DataParallel(self.clinical_model).to(device).eval()
        self.imaging_model = nn.DataParallel(self.imaging_model).to(device).eval()
        return self.model
    
    def evaluate(self, model, inputs, clinical_vars):
        """Run both branches, concatenate their outputs and apply `model`."""
        # Branch parameters are frozen, so no graph is needed for them.
        with torch.no_grad():
            imaging_outputs = self.imaging_model(inputs, clinical_vars)
            clinical_outputs = self.clinical_model(inputs, clinical_vars)
        inputs_concat = torch.cat([imaging_outputs, clinical_outputs], dim = 1)
        return model(inputs_concat)
    
    def combine(self, all_targets):
        """Stack per-batch arrays and map them back to the original target scale.

        Uses the module-level ``target_scaler``.
        """
        all_targets = np.concatenate(all_targets, axis=0)
        return target_scaler.inverse_transform(all_targets)

In [None]:
# Train a wrapped model and evaluate it on the held-out set after every epoch
def train_and_evaluate(
    wrapper: SimpleModelWrapper | LateModelWrapper,
    output_folder: pathlib.Path,
    train_dataset,
    test_dataset,
    loss,
    epochs=25,
    lr=1e-3,
    weight_decay=1e-5,
    run_name: str = "",
    batch_size = 16):
    """Train the wrapper's model with Adam and track held-out performance.

    Parameters
    ----------
    wrapper : SimpleModelWrapper | LateModelWrapper
        Provides ``to_device`` (returns the trainable model), ``evaluate``
        (forward pass) and ``combine`` (stacks per-batch arrays).
    output_folder : pathlib.Path
        Destination for checkpoints and figures; created if missing.
    train_dataset, test_dataset
        Datasets yielding dicts with 'image', 'target' and 'clinical';
        ``train_dataset.target_cols`` supplies the plot titles.
    loss : callable
        Loss *class* (or factory); it is instantiated once with no arguments.
    epochs, lr, weight_decay, batch_size
        Usual optimisation settings.
    run_name : str
        Prefix of every output file. The special value "HP_tuning" is expanded
        to a name encoding lr, weight decay and batch size. If a checkpoint
        with the resulting name already exists, "_1" is appended.

    Returns
    -------
    dict
        'model' (with the best weights loaded), 'best_loss' (lowest held-out
        loss) and 'best_corrs' (Pearson/Spearman per target at that epoch).

    Note that the best epoch is selected on `test_dataset`, so the reported
    best loss is optimistically biased.
    """
    if run_name == "HP_tuning":
        run_name = f"hp_multimodal_lr-{lr}_wd-{weight_decay}_bs-{batch_size}_resnet18"

    output_folder.mkdir(parents=True, exist_ok=True)
    output_model_path = output_folder / f"{run_name}_model.pth"
    # Avoid overwriting the checkpoint of a previous run with the same name.
    if output_model_path.exists():
        run_name = f"{run_name}_1"
        output_model_path = output_folder / f"{run_name}_model.pth"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model = wrapper.to_device(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = loss()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    best_corrs = {}
    train_losses = []
    test_losses = []

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    for epoch in range(epochs):
        # ---- training pass -------------------------------------------------
        model.train()
        running_loss = 0.0
        for input_dict in tqdm(train_loader):
            inputs = input_dict['image'].to(device)
            targets = input_dict['target'].to(device)
            clinical_vars = input_dict['clinical'].to(device)
            
            optimizer.zero_grad()
            outputs = wrapper.evaluate(model, inputs, clinical_vars)

            batch_loss = criterion(outputs, targets)
            if torch.isnan(batch_loss).any():
                # Back-propagating a NaN loss would turn every weight into NaN;
                # skip the update (and the loss accumulation) for this batch.
                continue

            running_loss += batch_loss.item() * inputs.size(0)
            batch_loss.backward()
            optimizer.step()
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        
        # ---- held-out pass ---------------------------------------------------
        model.eval()
        test_loss = 0.0
        all_targets = []
        all_outputs = []
        with torch.no_grad():
            for input_dict in tqdm(test_loader):
                inputs = input_dict['image'].to(device)
                targets = input_dict['target'].to(device)
                clinical_vars = input_dict['clinical'].to(device)
                
                # Same forward path as in training (wrappers are not callable).
                outputs = wrapper.evaluate(model, inputs, clinical_vars)
    
                batch_loss = criterion(outputs, targets)
                test_loss += batch_loss.item() * inputs.size(0)
                all_targets.append(targets.cpu().numpy())
                all_outputs.append(outputs.cpu().numpy())
        
        test_loss = test_loss / len(test_loader.dataset)
        test_losses.append(test_loss)
        best_epoch = False
        if test_loss < best_loss:
            best_epoch = True
            best_loss = test_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            # Save the unwrapped weights so they load without DataParallel.
            if isinstance(model, nn.DataParallel):
                torch.save(model.module.state_dict(), output_model_path)
            else:
                torch.save(model.state_dict(), output_model_path)

        all_targets = wrapper.combine(all_targets)
        all_outputs = wrapper.combine(all_outputs)

        # ---- per-target correlations ---------------------------------------
        spearman_corrs = []
        pearson_corrs = []
        spearman_pvals = []
        pearson_pvals = []
        pearson_low = []
        pearson_high = []
        for i in range(all_targets.shape[1]):
            corr, p_val = spearmanr(all_targets[:, i], all_outputs[:, i])
            # pearsonr_ci already returns r and its p-value alongside the interval.
            pcorr, pp_val, low, high = pearsonr_ci(all_targets[:, i], all_outputs[:, i])
            spearman_corrs.append(corr)
            pearson_corrs.append(pcorr)
            spearman_pvals.append(p_val)
            pearson_pvals.append(pp_val)
            pearson_low.append(low)
            pearson_high.append(high)

        clear_output(wait=False)
        print(run_name)
        
        # ---- scatter plots: one panel per target, three panels per row -----
        num_rows = int(np.ceil(all_targets.shape[1]/3))
        fig, axes = plt.subplots(num_rows, 3, figsize=(15, num_rows*5))
        axes = axes.flatten()
        for i in range(all_targets.shape[1]):
            ax = axes[i]
            ax.scatter(all_targets[:, i], all_outputs[:, i], alpha=0.5)
            
            # Set same scale for both x and y axes
            min_val = min(all_targets[:, i].min(), all_outputs[:, i].min())
            max_val = max(all_targets[:, i].max(), all_outputs[:, i].max())
            ax.set_xlim([min_val, max_val])
            ax.set_ylim([min_val, max_val])
            
            # Draw a regression line
            z = np.polyfit(all_targets[:, i], all_outputs[:, i], 1)
            p = np.poly1d(z)
            ax.plot([min_val, max_val], p([min_val, max_val]), "r--")
        
            ax.set_title(f'{train_dataset.target_cols[i]}\nTarget vs Prediction\nSpearman: {spearman_corrs[i]:.2f} ({spearman_pvals[i]:.2f})\nPearson: {pearson_corrs[i]:.2f} ({pearson_pvals[i]:.2f}, 95CI: {pearson_low[i]:.2f}:{pearson_high[i]:.2f})')
            ax.set_xlabel('True Values')
            ax.set_ylabel('Predictions')
        
        # Hide unused subplots
        for i in range(all_targets.shape[1], len(axes)):
            fig.delaxes(axes[i])
        
        fig.tight_layout()
        fig.savefig(output_folder / f"{run_name}Predictions.png")
        if best_epoch: 
            fig.savefig(output_folder / f"{run_name}Predictions_Best.png")
            for i, col in enumerate(train_dataset.target_cols):
                best_corrs[col] = {
                    "pearson" : pearson_corrs[i],
                    "spearman" : spearman_corrs[i],
                }
            
        plt.show()
        # One figure per epoch would otherwise accumulate in memory.
        plt.close(fig)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, Test Loss: {test_loss:.4f}')

        # ---- loss curves -----------------------------------------------------
        loss_fig, loss_ax = plt.subplots()
        loss_ax.plot(range(len(train_losses)), train_losses, label='Train Loss', color = "blue")
        loss_ax.plot(range(len(test_losses)), test_losses, label='Test Loss', color = "red")
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_fig.savefig(output_folder / f"{run_name}lossCurves.png")
        plt.show()
        plt.close(loss_fig)
    
    # Weights after the last epoch (not necessarily the best ones).
    if isinstance(model, nn.DataParallel):
        torch.save(model.module.state_dict(),  output_folder/ f"{run_name}_model_final.pth")
        torch.save(model.module,  output_folder / f"{run_name}_model_final_full.pth")
    else:
        torch.save(model.state_dict(),  output_folder / f"{run_name}_model_final.pth")

    model.load_state_dict(best_model_wts)
    return {
        'model' : model,
        'best_loss' : best_loss,
        'best_corrs' : best_corrs
    }


In [None]:
# The loss *class* is stored, not an instance: train_and_evaluate calls `loss()`
# itself to build the criterion.
huber_loss = nn.SmoothL1Loss

## Early

In [None]:
# Early fusion: clinical plane added to the min-max normalised image.
output_folder = pathlib.Path("./output")
output_size = len(train_minmax_dataset.target_cols)

# NOTE(review): num_additional_features = 4 presumably equals the number of clinical
# columns left once PatientID becomes the index — confirm against clinical_df.
# `image_dropout` is stored by RegressionModel_EarlyFusion but not used in its forward pass.
model = RegressionModel_EarlyFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 0.2,
    image_dropout = 0,
)
# load_model_weights(model, output_folder / 'imaging_orig_vol_test_v1_model_final.pth')
# load_clinical_model_weights(model, output_folder / 'clinical_orig_vol_test_v1_model_final.pth')

trained_model = train_and_evaluate(
    SimpleModelWrapper(model),
    output_folder,
    train_minmax_dataset, 
    test_minmax_dataset, 
    loss = huber_loss, 
    epochs=30, 
    lr=1e-05, 
    weight_decay=3.162277660168379e-08, 
    run_name = "early_orig_l3_test_v1",
    batch_size = 16,
)

## Inter clinical only

In [None]:
# Clinical-only baseline on the z-score normalised datasets.
output_folder = pathlib.Path("./output")
output_size = len(train_zscore_dataset.target_cols)

# image_dropout = 1 makes the forward pass use only the clinical projection,
# in both training and evaluation mode.
model = RegressionModel_InterFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 0,
    image_dropout = 1,
    freeze_image_net = False,
    concatenate_modes = False,
    )

trained_model = train_and_evaluate(
    SimpleModelWrapper(model),
    output_folder,
    train_zscore_dataset, 
    test_zscore_dataset, 
    loss = huber_loss, 
    epochs=10, 
    lr=1e-4, 
    weight_decay=1e-8, 
    run_name = "clinical_orig_l3_test_v1",
    batch_size = 16,
)

## Inter Image only

In [None]:
# Image-only baseline on the z-score normalised datasets.
output_folder = pathlib.Path("./output")
output_size = len(train_zscore_dataset.target_cols)

# clin_dropout = 1 makes the forward pass use only the CNN image features,
# in both training and evaluation mode.
model = RegressionModel_InterFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 1,
    image_dropout = 0,
    freeze_image_net = False,
    concatenate_modes = False,
    )
# load_model_weights(model, pathlib.Path(output_folder / 'image_only_resnet18_scaled_v9_model.pth'))

trained_model = train_and_evaluate(
    SimpleModelWrapper(model),
    output_folder,
    train_zscore_dataset, 
    test_zscore_dataset, 
    loss = huber_loss, 
    epochs=23, 
    lr= 1.584893192461114e-05,
    weight_decay= 6.30957344480193e-08,
    run_name = "imaging_orig_l3_test_v1",
    batch_size = 16,
)


## Inter both

In [None]:
# Intermediate fusion with both modalities on the z-score normalised datasets.
output_folder = pathlib.Path("./output")
output_size = len(train_zscore_dataset.target_cols)

# concatenate_modes = True: image and clinical features are concatenated and
# passed to the `fc_con` head; the dropout arguments are not used on that path.
model = RegressionModel_InterFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 0,
    image_dropout = 0,
    freeze_image_net = False,
    concatenate_modes = True,
    )
# load_model_weights(model, pathlib.Path(output_folder / 'imaging_orig_l3_model_final.pth'))
# NOTE(review): the clinical-only cell of this notebook uses run_name "clinical_orig_l3_test_v1",
# which does not produce this file name — confirm the checkpoint exists. A failed
# load is only printed, so training would silently start from random clinical weights.
load_clinical_model_weights(model, pathlib.Path(output_folder / 'clinical_orig_vol_test_v1_model_final.pth'))

trained_model = train_and_evaluate(
    SimpleModelWrapper(model),
    output_folder,
    train_zscore_dataset, 
    test_zscore_dataset, 
    loss = huber_loss, 
    epochs=30, 
    lr=1.584893192461114e-05, 
    weight_decay=6.30957344480193e-08, 
    run_name = "multimodal_orig_l3_test",
    batch_size = 16,
)

## Late

In [None]:
# Late fusion: two pre-trained single-modality models feed a small linear head.
output_folder = pathlib.Path("./output")
output_size = len(train_minmax_dataset.target_cols)

# Image branch: clin_dropout = 1 restricts it to the CNN image features.
imaging_model = RegressionModel_LateFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 1,
    image_dropout = 0,
    freeze_image_net = False,
    concatenate_modes = False,
)
# NOTE(review): this checkpoint name is not produced by any run_name used in this
# notebook — confirm the file exists; a failed load is only printed, not raised.
load_model_weights(imaging_model,  output_folder / 'imaging_orig_vol_test_final_v3_model_final.pth')

# Clinical branch: image_dropout = 1 restricts it to the clinical projection.
clinical_model = RegressionModel_LateFusion(
    get_model_by_name('resnet18'),
    output_size,
    num_additional_features = 4,
    clin_dropout = 0,
    image_dropout = 1,
    freeze_image_net = False,
    concatenate_modes = False,
)
# subset=False loads the whole state dict, not only the 'clin_fc' entries.
load_clinical_model_weights(clinical_model,  output_folder / 'clinical_orig_vol_test_v1_model_final.pth', False)

# Fusion head: maps the two concatenated prediction vectors to one.
# NOTE(review): sized from targets_df_new rather than `output_size`; the two agree
# only while the dataset's target columns have not been restricted — confirm.
model = LateFusionFC(len(targets_df_new.columns))

trained_model = train_and_evaluate(
    LateModelWrapper(imaging_model, clinical_model, model),
    output_folder,
    train_minmax_dataset, 
    test_minmax_dataset, 
    loss = huber_loss, 
    epochs=30, 
    lr=1e-03, 
    weight_decay=1e-07, 
    run_name = "orig_multimodal_late_Vol_repeat_v2",
    batch_size = 16,
)