In [1]:
import pandas as pd
import numpy as np

# File paths
data_dir = "./brca_metabric/"  # Directory containing the files

# Metadata and file mapping
files = {
    "clinical_patient": "data_clinical_patient.txt",
    "clinical_sample": "data_clinical_sample.txt",
    "cna": "data_cna.txt",
    "methylation": "data_methylation_promoters_rrbs.txt",
    "mrna_expression": "data_mrna_illumina_microarray.txt",
    "mrna_zscores": "data_mrna_illumina_microarray_zscores_ref_diploid_samples.txt",
    "mutations": "data_mutations.txt",
}

# Load clinical data
def load_clinical_data():
    # Load the clinical data files, skipping metadata rows (first few rows)
    clinical_patient = pd.read_csv(data_dir + files["clinical_patient"], sep="\t", skiprows=4)  # Skip metadata rows
    clinical_sample = pd.read_csv(data_dir + files["clinical_sample"], sep="\t", skiprows=4)  # Skip metadata rows
    
    # Strip spaces from column names
    clinical_patient.columns = clinical_patient.columns.str.strip()
    clinical_sample.columns = clinical_sample.columns.str.strip()
    
    # Check for the presence of 'PATIENT_ID' column in both files
    print("Clinical patient columns:", clinical_patient.columns)
    print("Clinical sample columns:", clinical_sample.columns)
    
    # Merge patient and sample data on the patient identifier
    if 'PATIENT_ID' in clinical_patient.columns and 'PATIENT_ID' in clinical_sample.columns:
        clinical_data = pd.merge(clinical_patient, clinical_sample, on="PATIENT_ID", how="outer")
        print("Clinical data shape:", clinical_data.shape)
        return clinical_data
    else:
        raise KeyError("PATIENT_ID column is missing from one or both clinical data files")

# Load copy number alteration (CNA) data
def load_cna_data():
    cna_data = pd.read_csv(data_dir + files["cna"], sep="\t", comment="#", index_col=0)
    print("Initial CNA data shape:", cna_data.shape)

    # Transpose the data
    cna_data = cna_data.T
    print("Transposed CNA data shape:", cna_data.shape)
    return cna_data

# Load mRNA expression data
def load_mrna_expression_data():
    mrna_data = pd.read_csv(data_dir + files["mrna_expression"], sep="\t", comment="#", index_col=0)
    print("Initial mRNA data shape:", mrna_data.shape)
    
    # Transpose the data
    mrna_data = mrna_data.drop(columns=["Entrez_Gene_Id"]).T  # Remove `Entrez_Gene_Id` before transposing
    print("Transposed mRNA data shape:", mrna_data.shape)
    return mrna_data

# Load mutation data
def load_mutation_data():
    mutation_data = pd.read_csv(data_dir + files["mutations"], sep="\t", comment="#")
    print("Mutation data shape:", mutation_data.shape)
    return mutation_data

def check_mergeability(clinical_data, cna_data, mrna_data):
    # Normalize patient identifiers in all datasets
    clinical_data['PATIENT_ID'] = clinical_data['PATIENT_ID'].str.strip().str.upper()
    cna_data.index = cna_data.index.str.strip().str.upper()
    mrna_data.index = mrna_data.index.str.strip().str.upper()
    
    # Convert clinical_data to use PATIENT_ID as the index for consistency
    clinical_index = set(clinical_data['PATIENT_ID'])
    cna_index = set(cna_data.index)
    mrna_index = set(mrna_data.index)

    print("Clinical data patient IDs:", len(clinical_index))
    print("CNA data patient IDs:", len(cna_index))
    print("mRNA data patient IDs:", len(mrna_index))
    
    # Check intersections
    clinical_cna_intersection = clinical_index.intersection(cna_index)
    all_intersection = clinical_cna_intersection.intersection(mrna_index)
    
    # Print debugging information
    print(f"Patients in clinical and CNA data: {len(clinical_cna_intersection)}")
    print(f"Patients in all three datasets: {len(all_intersection)}")
    
    # Check if merge is feasible
    if len(all_intersection) > 0:
        print("Datasets are mergeable!")
    else:
        print("Datasets are not mergeable. Check your patient identifiers.")
    
    return all_intersection


# Preprocess clinical data
def preprocess_clinical_data(data):
    # Handle missing values
    data = data.fillna("Unknown")
    # Exclude PATIENT_ID from one-hot encoding
    categorical_columns = data.select_dtypes(include=['object']).columns
    categorical_columns = categorical_columns.drop('PATIENT_ID', errors='ignore')
    # One-hot encode categorical variables
    data = pd.get_dummies(data, columns=categorical_columns, drop_first=True)
    return data

# Merge datasets
def merge_datasets(clinical_data, cna_data, mrna_data):
    # Ensure consistent identifiers
    clinical_data['PATIENT_ID'] = clinical_data['PATIENT_ID'].str.strip().str.upper()
    cna_data.index = cna_data.index.str.strip().str.upper()
    mrna_data.index = mrna_data.index.str.strip().str.upper()

    # Merge datasets
    clinical_data = clinical_data.set_index("PATIENT_ID")
    merged_data = clinical_data.join(cna_data, how="inner", rsuffix="_CNA").join(mrna_data, how="inner", rsuffix="_mRNA")
    
    print("Merged data shape:", merged_data.shape)
    return merged_data

# Main pipeline
def main_pipeline():
    # Load datasets
    clinical_data = load_clinical_data()
    cna_data = load_cna_data()
    mrna_data = load_mrna_expression_data()
    mutation_data = load_mutation_data()

    # Check if datasets are mergeable
    mergeable_ids = check_mergeability(clinical_data, cna_data, mrna_data)
    if not mergeable_ids:
        raise ValueError("No common patient IDs across the datasets. Cannot merge.")

    # Preprocess clinical data
    clinical_data = preprocess_clinical_data(clinical_data)

    # Merge datasets
    merged_data = merge_datasets(clinical_data, cna_data, mrna_data)

    # Save preprocessed data
    merged_data.to_csv(data_dir + "preprocessed_metabric.csv", index=True)
    print("Preprocessed data saved as 'preprocessed_metabric.csv'")
    return merged_data

# Run the pipeline
preprocessed_data = main_pipeline()


Clinical patient columns: Index(['PATIENT_ID', 'LYMPH_NODES_EXAMINED_POSITIVE', 'NPI', 'CELLULARITY',
       'CHEMOTHERAPY', 'COHORT', 'ER_IHC', 'HER2_SNP6', 'HORMONE_THERAPY',
       'INFERRED_MENOPAUSAL_STATE', 'SEX', 'INTCLUST', 'AGE_AT_DIAGNOSIS',
       'OS_MONTHS', 'OS_STATUS', 'CLAUDIN_SUBTYPE', 'THREEGENE',
       'VITAL_STATUS', 'LATERALITY', 'RADIO_THERAPY', 'HISTOLOGICAL_SUBTYPE',
       'BREAST_SURGERY', 'RFS_MONTHS', 'RFS_STATUS'],
      dtype='object')
Clinical sample columns: Index(['PATIENT_ID', 'SAMPLE_ID', 'CANCER_TYPE', 'CANCER_TYPE_DETAILED',
       'ER_STATUS', 'HER2_STATUS', 'GRADE', 'ONCOTREE_CODE', 'PR_STATUS',
       'SAMPLE_TYPE', 'TUMOR_SIZE', 'TUMOR_STAGE', 'TMB_NONSYNONYMOUS'],
      dtype='object')
Clinical data shape: (2509, 36)
Initial CNA data shape: (22544, 2174)
Transposed CNA data shape: (2174, 22544)
Initial mRNA data shape: (20603, 1981)
Transposed mRNA data shape: (1980, 20603)
Mutation data shape: (17272, 45)
Clinical data patient IDs: 2509
CNA d

In [9]:
# Calculate evaluation metrics
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, confusion_matrix, matthews_corrcoef, average_precision_score

def calculate_metrics(all_labels, all_predictions):
    accuracy = accuracy_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)
    auc = roc_auc_score(all_labels, all_predictions)
    prec = precision_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    cm = confusion_matrix(all_labels, all_predictions)
    mcc = matthews_corrcoef(all_labels, all_predictions)
    auprc = average_precision_score(all_labels, all_predictions)
    specificity = cm[0, 0] / (cm[0, 0] + cm[0, 1])

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")
    print(f"AUPRC: {auprc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"MCC: {mcc:.4f}")



In [2]:
# Process and save the data
import pandas as pd
import numpy as np

# File paths
data_dir = "./brca_metabric/"  # Directory containing the files

# Metadata and file mapping
files = {
    "clinical_patient": "data_clinical_patient.txt",
    "clinical_sample": "data_clinical_sample.txt",
    "cna": "data_cna.txt",
    "mrna_expression": "data_mrna_illumina_microarray.txt",
    "mrna_zscores": "data_mrna_illumina_microarray_zscores_ref_diploid_samples.txt",
}

# Define survival threshold for binary classification
SURVIVAL_THRESHOLD = 60  # 60 months

def preprocess_clinical_data(clinical_data):
    """Preprocess clinical data: one-hot encoding, handle missing values."""
    clinical_data = clinical_data.fillna("Unknown")  # Handle missing values
    columns_to_drop = [
    "OS_STATUS", "RFS_MONTHS", "RFS_STATUS"
    , "VITAL_STATUS", "CHEMOTHERAPY", "HORMONE_THERAPY", "RADIO_THERAPY"
    , "SAMPLE_ID", "CANCER_TYPE", "CANCER_TYPE_DETAILED", "ONCOTREE_CODE"
    , "SAMPLE_TYPE"
    , "TMB_NONSYNONYMOUS", "HER2_SNP6"
    ]

    clinical_data = clinical_data.drop(columns=columns_to_drop)

    categorical_columns = clinical_data.select_dtypes(include=["object"]).columns
    categorical_columns = categorical_columns.drop("PATIENT_ID", errors="ignore")

    # One-hot encode categorical variables
    clinical_data = pd.get_dummies(clinical_data, columns=categorical_columns, drop_first=True)
    return clinical_data

def load_and_process_data():
    # Load clinical data
    clinical_patient = pd.read_csv(data_dir + files["clinical_patient"], sep="\t", skiprows=4)
    clinical_sample = pd.read_csv(data_dir + files["clinical_sample"], sep="\t", skiprows=4)
    clinical_patient.columns = clinical_patient.columns.str.strip()
    clinical_sample.columns = clinical_sample.columns.str.strip()

    # Merge clinical data
    clinical_data = pd.merge(clinical_patient, clinical_sample, on="PATIENT_ID", how="outer")

    # Load CNA data
    cna_data = pd.read_csv(data_dir + files["cna"], sep="\t", comment="#", index_col=0).T

    # Load mRNA data
    mrna_data = pd.read_csv(data_dir + files["mrna_expression"], sep="\t", comment="#", index_col=0)
    mrna_data = mrna_data.drop(columns=["Entrez_Gene_Id"]).T  # Remove non-feature columns

    # load mrna zscores
    mrna_zscores = pd.read_csv(data_dir + files["mrna_zscores"], sep="\t", comment="#", index_col=0)
    mrna_zscores = mrna_zscores.drop(columns=["Entrez_Gene_Id"]).T  # Remove non-feature columns

    # Ensure consistent identifiers
    clinical_data["PATIENT_ID"] = clinical_data["PATIENT_ID"].str.strip().str.upper()
    cna_data.index = cna_data.index.str.strip().str.upper()
    mrna_data.index = mrna_data.index.str.strip().str.upper()
    mrna_zscores.index = mrna_zscores.index.str.strip().str.upper()

    # Align data based on common patient IDs
    common_ids = sorted(list(set(clinical_data["PATIENT_ID"]).intersection(cna_data.index, mrna_data.index)))  # Sorted
    clinical_data = clinical_data[clinical_data["PATIENT_ID"].isin(common_ids)].sort_values(by="PATIENT_ID")
    cna_data = cna_data.loc[common_ids].sort_index()
    mrna_data = mrna_data.loc[common_ids].sort_index()
    mrna_zscores = mrna_zscores.loc[common_ids].sort_index()

    # Preprocess clinical data
    clinical_data = preprocess_clinical_data(clinical_data)

    # Extract the label for survival prediction
    labels = clinical_data[["OS_MONTHS"]].copy()
    labels["SURVIVAL_BINARY"] = (labels["OS_MONTHS"] >= SURVIVAL_THRESHOLD).astype(int)
    labels["PATIENT_ID"] = clinical_data.index  # Add PATIENT_ID as a column
    labels = labels[["PATIENT_ID", "SURVIVAL_BINARY"]]  # Include PATIENT_ID for alignment

    # Drop unnecessary columns from clinical data
    clinical_data = clinical_data.drop(columns=["OS_MONTHS"], errors="ignore")

    # Save processed files
    clinical_data.to_csv(data_dir + "dataset/clinical_data.csv", index=True)
    cna_data.to_csv(data_dir + "dataset/cna_data.csv", index=True)
    mrna_data.to_csv(data_dir + "dataset/mrna_data.csv", index=True)
    mrna_zscores.to_csv(data_dir + "dataset/mrna_zscores.csv", index=True)
    labels.to_csv(data_dir + "dataset/labels.csv", index=False)

    print("Processed files saved:")
    print(" - clinical_data.csv")
    print(" - cna_data.csv")
    print(" - mrna_data.csv")
    print(" - mrna_zscores.csv")
    print(" - labels.csv")

# Run the processing pipeline
load_and_process_data()



Processed files saved:
 - clinical_data.csv
 - cna_data.csv
 - mrna_data.csv
 - mrna_zscores.csv
 - labels.csv


In [1]:
# Dataset class for Breast Cancer data
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class BreastCancerDataset(Dataset):
    def __init__(self, clinical_path, cna_path, mrna_path, labels_path):
        # Load data
        self.clinical_data = pd.read_csv(clinical_path)
        self.cna_data = pd.read_csv(cna_path, index_col=0)  # Load CNA with index as PATIENT_ID
        self.mrna_data = pd.read_csv(mrna_path, index_col=0)  # Load mRNA with index as PATIENT_ID
        self.labels = pd.read_csv(labels_path)
        
        # Drop the PATIENT_ID column from clinical data if it exists
        if "PATIENT_ID" in self.clinical_data.columns:
            self.clinical_data = self.clinical_data.drop(columns=["PATIENT_ID"])
        
        # Reset the index for CNA and mRNA data to avoid including PATIENT_ID
        self.cna_data = self.cna_data.reset_index(drop=True)
        self.mrna_data = self.mrna_data.reset_index(drop=True)
        
        # Convert data to numeric arrays
        self.clinical_data = pd.get_dummies(self.clinical_data, drop_first=True).to_numpy(dtype=np.float32)
        self.cna_data = self.cna_data.to_numpy(dtype=np.float32)
        self.mrna_data = self.mrna_data.to_numpy(dtype=np.float32)
        
        # Ensure labels are aligned and converted to numeric arrays
        self.labels = self.labels["SURVIVAL_BINARY"].to_numpy(dtype=np.float32)  # Use only the label column

        print("Clinical data shape:", self.clinical_data.shape)
        print("CNA data shape:", self.cna_data.shape)
        print("mRNA data shape:", self.mrna_data.shape)
        print("Labels shape:", self.labels.shape)       

    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        clinical = torch.tensor(self.clinical_data[idx], dtype=torch.float32)
        cna = torch.tensor(self.cna_data[idx], dtype=torch.float32)
        mrna = torch.tensor(self.mrna_data[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return clinical, cna, mrna, label


In [2]:
# Load the dataset into a DataLoader
from torch.utils.data import DataLoader

# Paths to preprocessed files
clinical_path = "brca_metabric/dataset/clinical_data.csv"
cna_path = "brca_metabric/dataset/cna_data.csv"
mrna_path = "brca_metabric/dataset/mrna_zscores.csv"
labels_path = "brca_metabric/dataset/labels.csv"

# Create dataset and DataLoader
dataset = BreastCancerDataset(clinical_path, cna_path, mrna_path, labels_path)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Test the DataLoader
for clinical, cna, mrna, label in dataloader:
    print(f"Clinical shape: {clinical.shape}")
    print(f"CNA shape: {cna.shape}")
    print(f"mRNA shape: {mrna.shape}")
    print(f"Labels shape: {label.shape}")
    break


Clinical data shape: (1980, 196)
CNA data shape: (1980, 22544)
mRNA data shape: (1980, 20603)
Labels shape: (1980,)
Clinical shape: torch.Size([32, 196])
CNA shape: torch.Size([32, 22544])
mRNA shape: torch.Size([32, 20603])
Labels shape: torch.Size([32])


In [3]:
# Split the dataset into training, validation and testing sets and remove NaNs

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader
import numpy as np

# Check and remove NaNs from dataset
print("Clinical data NaNs:", np.any(np.isnan(dataset.clinical_data)))
print("CNA data NaNs:", np.any(np.isnan(dataset.cna_data)))
print("mRNA data NaNs:", np.any(np.isnan(dataset.mrna_data)))
print("Labels NaNs:", np.any(np.isnan(dataset.labels)))

# Remove columns with NaNs in CNA and mRNA data
cna_nan_columns = np.isnan(dataset.cna_data).any(axis=0)
dataset.cna_data = dataset.cna_data[:, ~cna_nan_columns]
mrna_nan_columns = np.isnan(dataset.mrna_data).any(axis=0)
dataset.mrna_data = dataset.mrna_data[:, ~mrna_nan_columns]

# Split dataset into train (70%), validation (10%), and test (20%) sets
dataset_size = len(dataset)
indices = np.arange(dataset_size)

# Perform the first split to separate training data
train_indices, test_indices, train_labels, temp_labels = train_test_split(
    indices, dataset.labels, test_size=0.15, random_state=42, stratify=dataset.labels
)

# Perform the second split to divide validation and test data
# val_indices, test_indices = train_test_split(
#     test_indices, test_size=0.5, random_state=42, stratify=temp_labels
# )

# Create subsets for each split
train_dataset = Subset(dataset, train_indices)
# val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Create DataLoaders for each subset
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Print dataset statistics
print(f"Training samples: {len(train_dataset)}")
# print(f"Validation samples: {len(val_dataset)}")
print(f"Testing samples: {len(test_dataset)}")



Clinical data NaNs: False
CNA data NaNs: True
mRNA data NaNs: True
Labels NaNs: False
Training samples: 1683
Testing samples: 297


In [4]:
# Define the model architecture for each modality
import torch
import torch.nn as nn

class ClinicalModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(ClinicalModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.fc(x)

class CNAModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CNAModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.fc(x)

class mRNAModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(mRNAModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.fc(x)


In [5]:
# Define the combined model
class CombinedModel(nn.Module):
    def __init__(self, clinical_dim, cna_dim, mrna_dim, hidden_dim, output_dim):
        super(CombinedModel, self).__init__()
        # Sub-models for each modality
        self.clinical_model = ClinicalModel(clinical_dim, hidden_dim)
        self.cna_model = CNAModel(cna_dim, hidden_dim)
        self.mrna_model = mRNAModel(mrna_dim, hidden_dim)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(3*hidden_dim, hidden_dim),  # Concatenate outputs
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()  # Binary classification
        )
        
    def forward(self, clinical, cna, mrna):
        # Get embeddings from each modality
        clinical_emb = self.clinical_model(clinical)
        cna_emb = self.cna_model(cna)
        mrna_emb = self.mrna_model(mrna)
        
        # Concatenate embeddings
        combined_emb = torch.cat((clinical_emb, cna_emb, mrna_emb), dim=1)
        
        # Final classification
        output = self.classifier(combined_emb)
        return output


In [7]:
# Initialize the model, loss function, and optimizer and train the model

from sklearn.metrics import accuracy_score

# Hyperparameters
clinical_dim = dataset.clinical_data.shape[1]
cna_dim = dataset.cna_data.shape[1]
mrna_dim = dataset.mrna_data.shape[1]
hidden_dim = 128
output_dim = 1
learning_rate = 0.0001
num_epochs = 100

# Initialize model, loss function, and optimizer
model = CombinedModel(clinical_dim, cna_dim, mrna_dim, hidden_dim, output_dim)
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Early stopping parameters
patience = 10
best_val_loss = np.inf
no_improve_epochs = 0
best_model_weights = None

# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    all_train_labels = []
    all_train_predictions = []

    # Training phase
    for clinical, cna, mrna, labels in train_dataloader:
        # Forward pass
        outputs = model(clinical, cna, mrna).squeeze()
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Save predictions and true labels for accuracy calculation
        predictions = (outputs > 0.5).float()
        all_train_labels.extend(labels.numpy())
        all_train_predictions.extend(predictions.numpy())

    train_loss /= len(train_dataloader)
    train_accuracy = accuracy_score(all_train_labels, all_train_predictions)

    # Validation phase
    model.eval()
    val_loss = 0.0
    all_val_labels = []
    all_val_predictions = []

    with torch.no_grad():
        for clinical, cna, mrna, labels in test_dataloader:
            outputs = model(clinical, cna, mrna).squeeze()
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            predictions = (outputs > 0.5).float()
            all_val_labels.extend(labels.numpy())
            all_val_predictions.extend(predictions.numpy())

    val_loss /= len(test_dataloader)
    val_accuracy = accuracy_score(all_val_labels, all_val_predictions)

    print(
        f"Epoch [{epoch + 1}/{num_epochs}]"
        f" - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}"
        f" - Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}"
    )

    # Check for improvement in validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve_epochs = 0
        best_model_weights = model.state_dict()  # Save best weights
    else:
        no_improve_epochs += 1

    # Early stopping
    if no_improve_epochs >= patience:
        print(f"Early stopping at epoch {epoch + 1}. Best validation loss: {best_val_loss:.4f}")
        break

# Restore the best model weights after early stopping
if best_model_weights:
    model.load_state_dict(best_model_weights)
    print("Best model weights restored.")



Epoch [1/100] - Train Loss: 0.5760, Train Accuracy: 0.7213 - Val Loss: 0.5162, Val Accuracy: 0.7508
Epoch [2/100] - Train Loss: 0.4727, Train Accuracy: 0.7748 - Val Loss: 0.5041, Val Accuracy: 0.7542
Epoch [3/100] - Train Loss: 0.3581, Train Accuracy: 0.8526 - Val Loss: 0.5133, Val Accuracy: 0.7542
Epoch [4/100] - Train Loss: 0.2042, Train Accuracy: 0.9352 - Val Loss: 0.5781, Val Accuracy: 0.7643
Epoch [5/100] - Train Loss: 0.0635, Train Accuracy: 0.9905 - Val Loss: 0.7023, Val Accuracy: 0.7576
Epoch [6/100] - Train Loss: 0.0152, Train Accuracy: 1.0000 - Val Loss: 0.8007, Val Accuracy: 0.7407
Epoch [7/100] - Train Loss: 0.0040, Train Accuracy: 1.0000 - Val Loss: 0.8800, Val Accuracy: 0.7340
Epoch [8/100] - Train Loss: 0.0020, Train Accuracy: 1.0000 - Val Loss: 0.9311, Val Accuracy: 0.7374
Epoch [9/100] - Train Loss: 0.0012, Train Accuracy: 1.0000 - Val Loss: 0.9861, Val Accuracy: 0.7306
Epoch [10/100] - Train Loss: 0.0008, Train Accuracy: 1.0000 - Val Loss: 1.0453, Val Accuracy: 0.7407

In [10]:
# Evaluate the model on the test set
model.eval()
test_loss = 0.0
all_labels = []
all_predictions = []

with torch.no_grad():
    for clinical, cna, mrna, labels in test_dataloader:
        # clinical, cna, mrna, labels = clinical.to('cuda'), cna.to('cuda'), mrna.to('cuda'), labels.to('cuda')       

        # Forward pass
        outputs = model(clinical, cna, mrna).squeeze()
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Save predictions for evaluation
        predictions = (outputs > 0.5).float()
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predictions.cpu().numpy())

print(f"Test Loss: {test_loss / len(test_dataloader):.4f}")

calculate_metrics(all_labels, all_predictions)

Test Loss: 1.1002
Accuracy: 0.7441
F1 Score: 0.8410
AUC: 0.5858
AUPRC: 0.7845
Precision: 0.7882
Recall: 0.9013
Specificity: 0.2703
MCC: 0.2130
