## !Use Google Colab:
https://colab.research.google.com/drive/1rt6yOA0omDJZ1wvGXLTa8UbiEU1QXd3e#scrollTo=4oludV5HDQDH

##### Code to fit DeepSurv with pretrained Autoencoder.

How to use this code:

1. Upload the data sets to the content pane. To run the code without modifications, the file names should be:
* exprs_intersect.csv for gene data
* merged_imputed_pData.csv for clinical data

2. In addition, upload the pretrained models to the content pane. The models should be taken from the models folder produced by this notebook: https://colab.research.google.com/drive/1kOvHaFqIrlJQg6Zgy395caf_GaKFTz-W?usp=sharing.

3. Adapt model parameters and modelling process parameters in MODEL_CONFIG:
* To run the DeepSurv with Autoencoder representations and with or without clinical data, please refer to chunk 5
* To perform nested resampling set 'do_nested_resampling' in MODEL_CONFIG to True
* To train final model set 'refit' in MODEL_CONFIG to True
* Adapt model hyperparameters to your liking

4. Run the Notebook

---



In [None]:
### Chunk 1
# Installing and loading packages
!pip install lifelines
!pip install scikit-learn==1.5.2
!pip install scikit-survival==0.23.1
!pip install --upgrade sympy
import numpy as np
import copy
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
import torch
from lifelines.utils import concordance_index
from sklearn.utils.validation import check_X_y, check_is_fitted
import logging
from sklearn.model_selection import train_test_split, LeaveOneGroupOut, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils import check_random_state
from sksurv.util import Surv
import os
import pickle
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder




[0mCollecting sympy
  Using cached sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Using cached sympy-1.13.3-py3-none-any.whl (6.2 MB)
[0mInstalling collected packages: sympy
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.5.1+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-nvrtc-cu12 12.5.82 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cuda-runtime-cu12==12.4.127; platform_system =

In [None]:

# Notebook-wide logging: INFO level, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger used by the helper functions and classes below.
logger = logging.getLogger(__name__)


class DeepSurvNet(nn.Module):
    """
    PyTorch based neural network architecture designed for survival prediction.

    The network stacks fully connected layers, each followed by a ReLU
    activation and dropout, and ends in a bias-free linear layer that outputs
    a single risk (log-hazard) value per sample.

    Args:
        n_features (int): Number of input features.
        hidden_layers (sequence of int): Width of each hidden layer, in order.
            A tuple default is used to avoid a shared mutable default argument.
        dropout (float): Dropout probability applied after every hidden layer.
    """
    def __init__(self, n_features, hidden_layers=(32, 16), dropout=0.2):
        super().__init__()
        layers = []
        prev_size = n_features

        # Build hidden layers: Linear -> ReLU -> Dropout for each requested width.
        for size in hidden_layers:
            layers.extend([
                nn.Linear(prev_size, size),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_size = size

        # Output layer (1 node for the risk prediction). It has no bias: the Cox
        # partial likelihood used by DeepSurvModel cancels any constant shift of
        # the risk, so a bias term would not be identifiable.
        layers.append(nn.Linear(prev_size, 1, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one risk score per row of `x` (shape (batch, 1) for 2-D input)."""
        return self.model(x)


class DeepSurvModel(BaseEstimator, RegressorMixin):
    """
    Implementation of the DeepSurv model that integrates with scikit-learn,
    providing a configurable architecture, training procedure and evaluation metric.

    The model includes:
    - Customizable neural network architecture (DeepSurvNet)
    - Mini-batch training with early stopping on a held-out validation split
    - CPU/GPU support
    - Concordance index evaluation
    - Compatibility with scikit-learn's cross-validation and pipeline features
    - Reproducible training through seed control (re-applied at every fit)

    The model follows scikit-learn's estimator interface by implementing
    fit(), predict(), score(), get_params() and set_params() methods.

    Args:
        n_features: Only stored (returned by get_params); the network input width
            is taken from X at fit time.
        hidden_layers (sequence of int): Hidden layer widths of DeepSurvNet.
        dropout (float): Dropout probability used in DeepSurvNet.
        learning_rate (float): Learning rate of the Adam optimizer.
        device (str): 'cuda' or 'cpu'; 'cpu' is used when CUDA is unavailable.
        random_state (int): Seed for torch and numpy.
        batch_size (int): Mini-batch size used for training.
        num_epochs (int): Maximum number of training epochs.
        patience (int): Training stops once the validation loss has failed to
            improve for more than `patience` consecutive epochs.
    """
    def __init__(self, n_features=None, hidden_layers=(16, 16), dropout=0.5,
                 learning_rate=0.01, device='cpu', random_state=123,
                 batch_size=128, num_epochs=100, patience=15):
        self.n_features = n_features
        self.hidden_layers = hidden_layers
        self.dropout = dropout
        self.learning_rate = learning_rate
        # Fall back to CPU unless CUDA was requested AND is available.
        self.device = device if torch.cuda.is_available() and device == 'cuda' else 'cpu'
        self.random_state = random_state
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.patience = patience

        torch.manual_seed(random_state)
        np.random.seed(random_state)

        self.scaler = StandardScaler()
        self.model = None
        self.is_fitted_ = False
        self.training_history_ = {'train_loss': [], 'val_loss': []}
        self.n_features_in_ = None

    def __sklearn_is_fitted__(self):
        """Tell sklearn's check_is_fitted whether fit() has completed successfully."""
        return bool(getattr(self, 'is_fitted_', False))

    def fit(self, X, y):
        """
        Train the network with early stopping on a 10% validation split.

        Args:
            X: Feature matrix (n_samples, n_features).
            y: Structured array with a 'time' field and a 'status' (or 'event') field.

        Returns:
            self
        """
        # Input validation for X and y
        X, y = check_X_y(X, y, accept_sparse=True)

        # Re-seed on every fit (not only in __init__) so that fits executed in
        # joblib worker processes during GridSearchCV are reproducible as well.
        torch.manual_seed(self.random_state)
        np.random.seed(self.random_state)

        self.n_features_in_ = X.shape[1]
        logger.debug("Fitting DeepSurv on %d features", self.n_features_in_)
        # Start from a clean state so repeated fits don't accumulate history.
        self.training_history_ = {'train_loss': [], 'val_loss': []}
        self.is_fitted_ = False
        self.init_network(self.n_features_in_)

        train_dataset, val_dataset = self._prepare_data(X, y, val_split=0.1)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        # The Cox loss depends on which samples share a batch, so the whole
        # validation set is evaluated as one fixed batch to obtain a stable
        # early-stopping criterion.
        val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)

        best_val_loss = float('inf')
        best_model_state = None
        epochs_without_improvement = 0
        for epoch in range(self.num_epochs):
            self.model.train()
            train_losses = [self._train_step(X_batch, time_batch, event_batch)
                            for X_batch, time_batch, event_batch in train_loader]
            avg_train_loss = sum(train_losses) / len(train_losses)
            self.training_history_['train_loss'].append(avg_train_loss)

            # Validation
            self.model.eval()
            with torch.no_grad():
                val_losses = [self._eval_step(X_batch, time_batch, event_batch)
                              for X_batch, time_batch, event_batch in val_loader]
            val_loss = sum(val_losses) / len(val_losses)
            self.training_history_['val_loss'].append(val_loss)

            # Keep a copy of the best weights seen so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(self.model.state_dict())
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement > self.patience:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

        # Restore best model
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        self.is_fitted_ = True
        return self

    def predict(self, X):
        """Return a flat array of predicted risk scores (higher = higher hazard)."""
        check_is_fitted(self)
        if isinstance(X, pd.DataFrame):
            X = X.values
        X = torch.FloatTensor(X).to(self.device)
        self.model.eval()
        with torch.no_grad():
            risk_scores = self.model(X).cpu().numpy()
        return risk_scores.flatten()

    def score(self, X, y):
        """
        Concordance index of the predictions on (X, y).

        Risk scores are negated because lifelines' concordance_index expects
        scores that increase with survival time.
        """
        check_is_fitted(self)
        preds = self.predict(X)
        return self.c_index(-preds, y)

    def get_params(self, deep=True):
        """Return the constructor parameters (scikit-learn estimator API)."""
        return {
            "n_features": self.n_features,
            "hidden_layers": self.hidden_layers,
            "dropout": self.dropout,
            "learning_rate": self.learning_rate,
            "device": self.device,
            "random_state": self.random_state,
            "batch_size": self.batch_size,
            "num_epochs": self.num_epochs,
            "patience": self.patience
        }

    def set_params(self, **parameters):
        """Set constructor parameters by name and return self."""
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    def clone(self):
        """Return an unfitted copy with identical parameters (sklearn.base.clone)."""
        # The previous `super(self).clone()` always raised TypeError: super()
        # needs a type, and BaseEstimator has no clone() method anyway.
        from sklearn.base import clone as _sk_clone
        return _sk_clone(self)

    def _prepare_data(self, X, y, val_split = 0.1):
        """Split (X, y) into training and validation TensorDatasets on self.device."""
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_split, random_state=42)
        return self._to_dataset(X_train, y_train), self._to_dataset(X_val, y_val)

    def _to_dataset(self, X, y):
        """Convert features and a structured survival array into a (X, time, event) TensorDataset."""
        event_field = 'status' if 'status' in y.dtype.names else 'event'
        times = np.ascontiguousarray(y['time']).astype(np.float32)
        events = np.ascontiguousarray(y[event_field]).astype(np.float32)
        return TensorDataset(
            torch.FloatTensor(X).to(self.device),
            torch.FloatTensor(times).to(self.device),
            torch.FloatTensor(events).to(self.device),
        )

    def _negative_log_likelihood(self, risk_pred, times, events):
        """
        Negative Cox partial log-likelihood, averaged over all samples in the batch.

        Args:
            risk_pred: Network output, shape (batch, 1) or (batch,).
            times: Observed times, shape (batch,).
            events: Event indicators (1 = event, 0 = censored), shape (batch,).
        """
        # Flatten to (batch,) so it lines up with `events`; multiplying a (batch,)
        # tensor with a (batch, 1) tensor would broadcast to a (batch, batch) matrix.
        risk_pred = risk_pred.reshape(-1)
        _, idx = torch.sort(times, descending=True)
        log_risk = risk_pred[idx]
        events = events[idx]
        # After sorting by descending time, position i's cumulative sum covers every
        # sample with time >= times[i] (its risk set; ties are not treated specially).
        # logcumsumexp avoids overflow of exp() for large risk scores.
        log_cumsum_risk = torch.logcumsumexp(log_risk, dim=0)
        event_loss = events * (log_risk - log_cumsum_risk)
        return -torch.mean(event_loss)

    def _train_step(self, X, times, events):
        """Run one optimizer step on a mini-batch and return its loss as a float."""
        self.optimizer.zero_grad()
        risk_pred = self.model(X)
        loss = self._negative_log_likelihood(risk_pred, times, events)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _eval_step(self, X, times, events):
        """Return the loss of a validation batch as a float (caller disables grad)."""
        risk_pred = self.model(X)
        loss = self._negative_log_likelihood(risk_pred, times, events)
        return loss.item()

    def _check_early_stopping(self, counter):
        # Not used by fit(); kept for backward compatibility. Resets the counter
        # when the last validation loss improved on the previous one.
        if len(self.training_history_['val_loss']) < 2:
            return 0.0

        if self.training_history_['val_loss'][-1] < self.training_history_['val_loss'][-2]:
            counter = 0.0
        else:
            counter += 1.0
        return counter

    def c_index(self, risk_pred, y):
        """Concordance index via lifelines; NaN if every prediction is NaN."""
        if not isinstance(y, np.ndarray):
            y = y.detach().cpu().numpy()
        event_field = 'status' if 'status' in y.dtype.names else 'event'
        time = y['time']
        event = y[event_field]
        if not isinstance(risk_pred, np.ndarray):
            risk_pred = risk_pred.detach().cpu().numpy()
        if np.isnan(risk_pred).all():
            return np.nan
        return concordance_index(time, risk_pred, event)

    def init_network(self, n_features):
        """Create a fresh DeepSurvNet on self.device and its Adam optimizer."""
        self.model = DeepSurvNet(n_features=n_features, hidden_layers=self.hidden_layers, dropout=self.dropout).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

In [None]:
def _get_survival_subset(y, indices):
    """
    Extracts a subset of the survival dataset

    Args:
        y (np.ndarray): Structured array containing survival data with fields 'time' and 'status' (or 'event').
        indices: Indices of the subset.

    Returns:
        np.ndarray: Array containing the prognostic endpoint (BCR, MONTH_TO_BCR)
    """
    subset = np.empty(len(indices), dtype=y.dtype)
    event_field = 'status' if 'status' in y.dtype.names else 'event'
    subset[event_field] = y[event_field][indices]
    subset['time'] = y['time'][indices]
    return subset

def _aggregate_results(results):
    """
    Summarise the outer-fold test scores of a nested cross-validation run.

    Args:
        results (list of dict): One dict per outer fold; each must contain 'test_score'.

    Returns:
        dict: 'mean_score' and 'std_score' (NaN-ignoring mean/std of the fold
        scores, or NaN for both if every score is NaN) plus the untouched
        'fold_results' list.
    """
    fold_scores = [fold['test_score'] for fold in results]

    only_nans = np.isnan(fold_scores).all()
    if only_nans:
        logger.warning(f"Found only NaN values in CV-results: {fold_scores}")
        mean_score = std_score = np.nan
    else:
        mean_score, std_score = np.nanmean(fold_scores), np.nanstd(fold_scores)

    logger.info("Aggregated results:")
    logger.info(f"Mean score: {mean_score:.3f} ± {std_score:.3f}")
    logger.info(f"Individual scores: {fold_scores}")

    return dict(mean_score=mean_score, std_score=std_score, fold_results=results)

def nested_resampling(estimator, X, y, groups, param_grid, monitor = None, ss = GridSearchCV, outer_cv = LeaveOneGroupOut(), inner_cv = LeaveOneGroupOut(), scoring = None):
    """
    Performs nested resampling (Leave-One-Group-Out in both loops by default).

    Args:
        estimator (sklearn estimator/sklearn pipeline): The base model/pipeline.
        X (pd.DataFrame): Feature matrix (indexed positionally with .iloc).
        y (np.ndarray): Structured survival outcome array.
        groups (array): Group labels for each sample.
        param_grid (dict): Hyperparameter grid for the inner cross-validation.
        monitor (optional): Passed to the pipeline's 'model' step as fit
            parameter `monitor` (default: None).
        ss (class): The search strategy class (default: GridSearchCV).
        outer_cv (sklearn CV splitter): Outer cross-validation splitter (default: LeaveOneGroupOut).
        inner_cv (sklearn CV splitter): Inner cross-validation splitter (default: LeaveOneGroupOut).
        scoring (optional): Scorer forwarded to the search strategy. When given,
            the outer test fold is scored with the same scorer (via the search
            object's score()); when None, the best estimator's own score() is
            used (default: None).

    Returns:
        dict: Aggregated results from nested cross-validation.
    """
    logger.info("Starting nested resampling...")
    logger.info(f"Data shape: X={X.shape}, groups={len(np.unique(groups))} unique")

    outer_results = []

    for i, (train_idx, test_idx) in enumerate(outer_cv.split(X, y, groups)):
        logger.info(f"\nOuter fold {i+1}")

        X_train = X.iloc[train_idx]
        X_test = X.iloc[test_idx]
        y_train = _get_survival_subset(y, train_idx)
        y_test = _get_survival_subset(y, test_idx)
        train_groups = groups[train_idx] if groups is not None else None

        # Reports the first group of the test fold; with LeaveOneGroupOut the
        # fold contains exactly one group.
        test_cohort = groups[test_idx][0] if groups is not None else None
        logger.info(f"Test cohort: {test_cohort}")

        inner_gcv = ss(estimator, param_grid, cv=inner_cv, scoring=scoring,
                       refit=True, n_jobs=4, verbose=2)
        if monitor is not None:
            inner_results = inner_gcv.fit(X_train, y_train, groups=train_groups, model__monitor=monitor)
            # n_estimators_ only exists on boosting-style models; don't crash on others.
            best_model = inner_results.best_estimator_.named_steps['model']
            logger.info(
                f"number of iterations early stopping: {getattr(best_model, 'n_estimators_', None)}")
        else:
            inner_results = inner_gcv.fit(X_train, y_train, groups=train_groups)

        inner_cv_results = inner_results.cv_results_
        inner_best_params = inner_results.best_params_

        outer_model = inner_results.best_estimator_
        if scoring is None:
            test_score = outer_model.score(X_test, y_test)
        else:
            # Score the outer fold with the same scorer used for model selection.
            test_score = inner_results.score(X_test, y_test)

        logger.info(f"Best parameters: {inner_best_params}")
        logger.info(f"Test score: {test_score:.3f}")

        outer_results.append({
            'test_cohort': test_cohort,
            'test_score': test_score,
            'best_params': inner_best_params,
            'inner_cv_results': inner_cv_results
        })

    return _aggregate_results(outer_results)


class ModellingProcess():
    """
    Class to handle the full modelling process for python models with sklearn-interface.
    Includes data preparation, cross-validation, hyperparameter tuning, model fitting,
    and result saving.
    """
    def __init__(self) -> None:
        self.outer_cv = LeaveOneGroupOut()
        self.inner_cv = LeaveOneGroupOut()
        self.ss = GridSearchCV
        self.pipe = None
        self.cmplt_model = None
        self.cmplt_pipeline = None
        self.resampling_cmplt = None
        self.nrs = None
        self.X = None
        self.y = None
        self.groups = None
        self.path = None
        self.fname_cv = None

    def prepare_survival_data(self, pdata):
        """Build a structured survival array ('status', 'time') from BCR_STATUS / MONTH_TO_BCR."""
        status = pdata['BCR_STATUS'].astype(bool).values
        time = pdata['MONTH_TO_BCR'].astype(float).values
        y = Surv.from_arrays(
            event=status,
            time=time,
            name_event='status',
            name_time='time'
        )
        return y

    def prepare_data(self, config):
        """
        Load expression and clinical data from /content and set X, y and groups.

        Args:
            config (dict): Optional keys:
                - 'clinical_covs': list of clinical columns to add to the features.
                - 'requires_ohenc' (default True): one-hot encode non-numeric clinical columns.
                - 'only_pData' (default False): use only clinical data as features.
        """
        X = pd.read_csv('/content/exprs_intersect.csv', index_col=0)
        pdata = pd.read_csv('/content/merged_imputed_pData.csv', index_col=0)

        if not pdata.index.equals(X.index):
            logger.warning("Row labels of clinical and expression data differ; y and groups "
                           "are built assuming both files list samples in the same order")

        self.y = self.prepare_survival_data(pdata)
        # The cohort name is the part of the sample id before the first '.'.
        self.groups = np.array([idx.split('.')[0] for idx in X.index])

        if config.get('clinical_covs', None) is None:
            self.X = X
            return

        logger.info('Found clinical data specification')
        pdata['AGE'] = pd.to_numeric(pdata['AGE'], errors='coerce')
        clin_data = pdata.loc[:, config['clinical_covs']]
        # Split into numeric vs. everything else. include=['number'] (rather than
        # exclude=['object']) keeps bool/category columns out of the numeric set,
        # so no column ends up in both groups.
        cat_cols = clin_data.select_dtypes(exclude=['number']).columns
        num_cols = clin_data.select_dtypes(include=['number']).columns
        clin_data_cat = clin_data.loc[:, cat_cols]
        # OneHotEncoder rejects inputs with zero columns, so only encode when needed.
        if len(cat_cols) > 0 and config.get('requires_ohenc', True) is True:
            ohc = OneHotEncoder()
            encoded = ohc.fit_transform(clin_data_cat)
            # Use the clinical index so the encoded and numeric blocks share row labels.
            clin_data_cat = pd.DataFrame.sparse.from_spmatrix(
                encoded, columns=ohc.get_feature_names_out()).set_index(clin_data.index)
        clin_data_num = clin_data.loc[:, num_cols]

        if config.get('only_pData', False) is not False:
            logger.info('Only uses pData')
            self.X = pd.concat([clin_data_cat, clin_data_num], axis = 1)
        else:
            self.X = pd.concat([clin_data_cat, clin_data_num, X], axis = 1)

    def do_modelling(self, pipeline_steps, config):
        """
        Executes the complete modeling process, including pipeline creation, nested resampling, and final model fitting.

        Args:
            pipeline_steps (list): List of (name, transformer) tuples for creating the pipeline --> objects need to adhere to scikit learn interface /API.
            config (dict): Configuration for the modeling process, including parameters for cross-validation,
                           hyperparameter tuning, and result saving.

        Returns:
            tuple: Nested resampling results, final model, and complete, final pipeline.
        """
        self._set_seed()

        if config.get("params_mp", None) is not None:
            self.set_params(config['params_mp'])

        if config.get("path", None) is None or config.get("fname_cv", None) is None:
            logger.warning("Didn't get sufficient path info for saving cv-results")
        else:
            self.path = config['path']
            self.fname_cv = config['fname_cv']

        err, mes = self._check_modelling_prerequs(pipeline_steps)
        if err:
            logger.error("Requirements setup error: %s", mes)
            raise Exception(mes)
        self.pipe = Pipeline(pipeline_steps)

        param_grid, monitor, do_nested_resampling, refit_hp_tuning = self._get_config_vals(config)
        can_save = (self.fname_cv is not None) and (self.path is not None)

        try:
            logger.info("Start model training...")
            logger.info(f"Input data shape: X={self.X.shape}")

            if do_nested_resampling:
                logger.info("Nested resampling...")
                self.nrs = nested_resampling(self.pipe, self.X, self.y, self.groups, param_grid, monitor, self.ss, self.outer_cv, self.inner_cv)
                if can_save:
                    self.save_results(self.path, self.fname_cv, model = None, cv_results = self.nrs, pipe = None)
        except Exception as e:
            logger.exception(f"Error during nested resampling: {str(e)}")
            raise

        if refit_hp_tuning:
            try:
                logger.info("Do HP Tuning for complete model; refit + set complete model")
                # Forward the monitor so the final fit matches the nested-resampling fits.
                self.fit_cmplt_model(param_grid, monitor)
                if can_save:
                    self.save_results(self.path, self.fname_cv, model = self.cmplt_model, cv_results = None, pipe = self.cmplt_pipeline)
            except Exception as e:
                logger.exception(f"Error during complete model training: {str(e)}")
                raise
        elif refit_hp_tuning is False and do_nested_resampling is False:
            logger.info("Fit complete model wo. HP tuning (on default params)")
            # Pipeline.fit returns the fitted pipeline itself; cmplt_model keeps that
            # (historical behaviour) and cmplt_pipeline is now populated as well.
            self.cmplt_model = self.pipe.fit(self.X, self.y)
            self.cmplt_pipeline = self.cmplt_model

        return self.nrs, self.cmplt_model, self.cmplt_pipeline

    def fit_cmplt_model(self, param_grid, monitor=None):
        """
        Performs hyperparameter tuning and fits the final model on the complete data set.

        Args:
            param_grid (dict): Parameter grid for the search strategy (GridSearchCV by default).
            monitor (optional): Passed to the pipeline's 'model' step as fit parameter `monitor`.

        Returns:
            tuple: The best model (pipeline step 'model') and the fitted search object.
        """
        logger.info("Do HP Tuning for complete model")
        res = self.ss(
            estimator=self.pipe,
            param_grid=param_grid,
            cv=self.outer_cv,
            n_jobs=1,  # Changed from -1 to 1
            verbose=2,
            refit=True
        )
        if monitor is not None:
            res.fit(self.X, self.y, groups=self.groups, model__monitor=monitor)
        else:
            res.fit(self.X, self.y, groups=self.groups)
        self.resampling_cmplt = res
        self.cmplt_pipeline = res.best_estimator_
        self.cmplt_model = res.best_estimator_.named_steps['model']
        return self.cmplt_model, res

    def save_results(self, path, fname, model=None, cv_results=None, pipe=None):
        """
        Saves the model, cross-validation results, and pipeline below `path`.

        Args:
            path (str): Directory path to save the results.
            fname (str): Base file name for the saved artefacts.
            model (optional): Fitted model exposing a torch network as `.model`; the
                network is moved to CPU (in place) and saved to <path>/model/<fname>.pth.
            cv_results (optional): Cross-validation results saved to <path>/results/<fname>_cv.csv.
            pipe (optional): Pipeline pickled to <path>/pipe/<fname>_pipe.pkl. This is
                best-effort: if the pipeline cannot be pickled, a warning is logged
                and no file is written.

        Returns:
            None
        """
        if model is None:
            logger.warning("Won't save any model, since its not provided")
        else:
            model_dir = os.path.join(path, 'model')
            os.makedirs(model_dir, exist_ok=True)
            model.model.to(torch.device('cpu'))
            torch.save(model.model, os.path.join(model_dir, f"{fname}.pth"))
            logger.info(f"Saved model to {model_dir}")

        if cv_results is None:
            logger.warning("Won't save any cv results, since its not provided")
        else:
            results_dir = os.path.join(path, 'results')
            os.makedirs(results_dir, exist_ok=True)
            results_file = os.path.join(results_dir, f"{fname}_cv.csv")
            pd.DataFrame(cv_results).to_csv(results_file)
            logger.info(f"Saved CV results to {results_file}")

        if pipe is not None:
            # Serialise first so an unpicklable pipeline never leaves a partial file.
            try:
                payload = pickle.dumps(pipe)
            except (pickle.PicklingError, AttributeError, TypeError) as err:
                logger.warning(f"Could not pickle pipeline, not saving it: {err}")
            else:
                pipe_dir = os.path.join(path, 'pipe')
                os.makedirs(pipe_dir, exist_ok=True)
                pipe_file = os.path.join(pipe_dir, f"{fname}_pipe.pkl")
                with open(pipe_file, 'wb') as fh:
                    fh.write(payload)
                logger.info(f"Saved pipeline to {pipe_file}")

    def _check_modelling_prerequs(self, pipeline_steps):
        """
        Checks whether the necessary prerequisites for the modeling process are met (data is prepared + model exists in pipeline).

        Args:
            pipeline_steps (list): List of (name, transformer) tuples representing the steps in the pipeline.

        Returns:
            tuple: (err, mes) - True if an error exists, plus the error message(s).
        """
        problems = []
        if self.X is None or self.y is None:
            problems.append("1) Please call prepare_data() with your preferred config or set X, y, and groups")
        if 'model' not in [step[0] for step in pipeline_steps]:
            problems.append("2) Caution! Your pipeline must include a step named 'model' for the model")
        return bool(problems), " ".join(problems)

    def _get_config_vals(self, config):
        """
        Extracts config values from the provided modelling dictionary.

        Args:
            config (dict): Configuration dictionary with keys such as 'params_cv', 'monitor',
                        'do_nested_resampling', and 'refit'.

        Returns:
            tuple: Contains the following extracted values:
                - param_grid (dict or None): Parameter grid for cross-validation.
                - monitor (object or None): Optional monitor object for early stopping.
                - do_nested_resampling (bool): Nested resampling should be performed?
                - refit_hp_tuning (bool): Refit the model with hyperparameter tuning?
        """
        if config.get("params_cv", None) is None:
            logger.warning("No param grid for (nested) resampling detected - will fit model with default HPs on complete data")
            return None, None, False, False
        if config.get('monitor', None) is None:
            logger.info("No additional monitoring detected")
        return config['params_cv'], config.get('monitor', None), config.get('do_nested_resampling', True), config.get('refit', True)

    def set_params(self, params):
        """Set attributes of the process from a {name: value} dict."""
        for key, value in params.items():
            setattr(self, key, value)

    def _set_seed(self, seed = 1234):
        """Seed numpy/torch (and CUDA) and enable deterministic cuDNN behaviour."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        global random_state
        random_state = check_random_state(seed)


In [None]:
from sklearn.feature_selection import SelectFromModel, SelectorMixin
from sklearn.base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone, is_classifier
import joblib
import os
import sys
from sklearn.base import BaseEstimator, TransformerMixin
import pandas as pd
import torch

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import Callable
import torch.nn.functional as F
# Fix the global torch RNG so later random operations (e.g. weight
# initialisation) are repeatable.
torch.manual_seed(42)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The code in this chunk is based on the autoencoder implementation
# in this repository https://github.com/phcavelar/pathwayae.
class MLP(nn.Module):
    """
    Multi-layer perceptron: a chain of Linear layers, with `nonlinearity` followed
    by dropout between consecutive layers (none after the last layer).

    Args:
        input_dim (int): Width of the input.
        hidden_dims (sequence of int): Widths of the hidden layers (list or tuple).
        output_dim (int): Width of the output.
        nonlinearity (Callable): Activation applied after every hidden layer.
        dropout_rate (float): Dropout probability applied after each activation.
        bias (bool): Whether the Linear layers have a bias term.
    """
    def __init__(
            self,
            input_dim:int,
            hidden_dims:list[int],
            output_dim:int,
            nonlinearity:Callable,
            dropout_rate:float=0.5,
            bias:bool=True,
            ):
        super().__init__()
        # Accept any sequence (e.g. a tuple); list concatenation below needs a list.
        hidden_dims = list(hidden_dims)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]

        self.layers = nn.ModuleList([nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(in_dims, out_dims)])
        self.nonlinearity = nonlinearity
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Apply all layers; the last layer's output is returned without activation."""
        for layer in self.layers[:-1]:
            x = self.dropout(self.nonlinearity(layer(x)))
        return self.layers[-1](x)

    def layer_activations(self, x:torch.Tensor) -> list[torch.Tensor]:
        """Return the output of every layer (hidden activations, then the final output)."""
        # To allow for activation normalisation
        activations = [x]
        for layer in self.layers[:-1]:
            activations.append(self.dropout(self.nonlinearity(layer(activations[-1]))))
        return activations[1:] + [self.layers[-1](activations[-1])]

class NopLayer(nn.Module):
    """
    Pass-through layer.

    Any constructor arguments are accepted and ignored. forward() returns its
    input object unchanged, update_temperature() does nothing, and
    layer_activations() reports no activations.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        # Identity: hand back the very same tensor object.
        return x

    def update_temperature(self, *args, **kwargs) -> None:
        # Intentionally a no-op.
        return None

    def layer_activations(self, *args, **kwargs) -> list[torch.Tensor]:
        # No internal layers, hence no activations to report.
        return list()

def _identity(x:torch.Tensor) -> torch.Tensor:
    """Identity function; a module-level def (unlike a lambda) can be pickled."""
    return x

class Autoencoder(nn.Module):
    """
    Symmetric MLP autoencoder.

    The encoder maps input_dim -> hidden_dims -> encoding_dim. The decoder
    mirrors it (encoding_dim -> reversed hidden_dims -> input_dim), and
    `final_nonlinearity` is applied to the decoder output.

    Args:
        input_dim (int): Number of input features (required).
        hidden_dims (sequence of int or int): Encoder hidden layer widths; a single
            int means one hidden layer. Defaults to one hidden layer of width 128.
        encoding_dim (int): Width of the latent representation.
        nonlinearity (Callable): Activation used inside encoder and decoder.
        final_nonlinearity (Callable): Applied to the reconstruction; identity by
            default (a named function so that the model stays picklable).
        dropout_rate (float): Dropout probability inside the MLPs.
        bias (bool): Whether Linear layers have a bias term.

    Raises:
        ValueError: If `input_dim` is None.
    """
    def __init__(
            self,
            input_dim:int=None,
            hidden_dims=(128,),
            encoding_dim:int=64,
            nonlinearity=F.relu,
            final_nonlinearity=_identity,
            dropout_rate:float=0.5,
            bias:bool=True,
            ):
        super().__init__()
        if input_dim is None:
            raise ValueError("Must specify input dimension before initialising the model")
        try:
            len(hidden_dims)
        except TypeError:
            # A bare int is shorthand for a single hidden layer.
            hidden_dims = [hidden_dims]
        # Normalise to a fresh list (tuples allowed; never aliases the caller's list).
        hidden_dims = list(hidden_dims)

        self.encoder = MLP(input_dim, hidden_dims, encoding_dim, nonlinearity, dropout_rate, bias)
        # The decoder uses the encoder's hidden widths in reverse order.
        self.decoder = MLP(encoding_dim, hidden_dims[::-1], input_dim, nonlinearity, dropout_rate, bias)
        self.final_nonlinearity = final_nonlinearity

    def encode(self,x:torch.Tensor) -> torch.Tensor:
        """Map inputs to the latent representation."""
        return self.encoder(x)

    def decode(self,x:torch.Tensor) -> torch.Tensor:
        """Reconstruct inputs from latent codes and apply final_nonlinearity."""
        return self.final_nonlinearity(self.decoder(x))

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Encode then decode `x`, returning the reconstruction."""
        z = self.encode(x)
        x_hat = self.decode(z)
        return x_hat

    def layer_activations(self,x:torch.Tensor) -> list[torch.Tensor]:
        """Return all encoder layer outputs followed by all decoder layer outputs."""
        # To allow for activation normalisation
        encoder_activations = self.encoder.layer_activations(x)
        decoder_activations = self.decoder.layer_activations(encoder_activations[-1])
        return encoder_activations + decoder_activations

    def get_feature_importance_matrix(self) -> torch.Tensor:
        """
        Product of the transposed encoder weight matrices, shape (input_dim, encoding_dim).

        This is a purely linear view of the encoder: activations and biases are ignored.
        """
        with torch.no_grad():
            feature_importance_matrix = self.encoder.layers[0].weight.T
            for layer in self.encoder.layers[1:]:
                feature_importance_matrix = torch.matmul(feature_importance_matrix, layer.weight.T)
        return feature_importance_matrix.detach()

In [None]:
class FoldAwareAE(BaseEstimator, TransformerMixin):
    """
    Custom transformer class, based on the transformer class from scikit-learn API.
    Integrates an Autoencoder model for feature transformation and
    loads a pretrained autoencoder model based on the cohorts present in the input dataset.

    The checkpoint is selected from the cohorts that are *absent* from the data passed to
    ``fit``: their names (in ``all_cohorts`` order) are joined with ``_`` to form the file
    name (presumably an autoencoder pretrained without those held-out cohorts). If no
    cohort is absent, or ``testing`` is not ``False``, ``pretrnd_cmplt`` is used.

    Attributes:
        all_cohorts (list): List of all known cohort names.
        model (Autoencoder): Instance of the autoencoder model (``None`` until fitted).
        testing (bool): Whether AE is used during testing or training.
        model_dir (str): Directory containing the ``<name>.pth`` checkpoints.
    """
    def __init__(self, testing = False, model_dir = '/content'):
        self.all_cohorts = ['Atlanta_2014_Long', 'Belfast_2018_Jain', 'CamCap_2016_Ross_Adams',
                            'CancerMap_2017_Luca', 'CPC_GENE_2017_Fraser', 'CPGEA_2020_Li',
                            'DKFZ_2018_Gerhauser', 'MSKCC_2010_Taylor', 'Stockholm_2016_Ross_Adams']
        self.model = None
        self.testing = testing
        self.model_dir = model_dir

    def _checkpoint_path(self, X):
        """Return the path of the ``.pth`` checkpoint matching the cohorts present in ``X``."""
        if self.testing is False:
            # The cohort is the part of each sample ID before the first '.'
            present = set(X.index.to_series().str.split('.').str[0])
            missing = [c for c in self.all_cohorts if c not in present]
            name = '_'.join(missing) or 'pretrnd_cmplt'
        else:
            name = 'pretrnd_cmplt'
        return os.path.join(self.model_dir, name + '.pth')

    def fit(self, X, y=None, **fit_params):
        """
        Dynamically loads a pretrained autoencoder model based on the cohorts present in the dataset.

        Args:
            X (DataFrame): Input dataset with cohort info in the index; its number of
                columns sets the autoencoder's input dimension.
            y (ignored): Included for compatibility with scikit-learn.
            **fit_params (ignored): Accepted so ``fit_transform`` (and scikit-learn
                meta-estimators) can forward fit parameters without a TypeError.

        Returns:
            self: The fitted instance of the FoldAwareAE class.
        """
        self.model = Autoencoder(input_dim=len(X.columns))
        # The checkpoint is a plain state dict, so the restricted weights-only unpickler
        # suffices (and avoids torch's FutureWarning); map_location='cpu' lets checkpoints
        # saved from a GPU load on CPU-only machines — transform() runs on CPU anyway.
        state_dict = torch.load(self._checkpoint_path(X), map_location='cpu', weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        return self

    def transform(self, X, y = None):
        """
        Transforms the input data into its corresponding latent representation using the encoder.
        Implemented to adhere to scikit-learn API.

        Args:
            X (DataFrame): Input dataset to be transformed.
            y (ignored): Included for compatibility with scikit-learn.

        Returns:
            DataFrame: Latent representation of the input data with original index of X.
        """
        X_t = torch.tensor(X.to_numpy(dtype=np.float32))
        # Inference only: no autograd graph is needed for the latent codes
        with torch.no_grad():
            latent = self.model.encoder(X_t).cpu().numpy()
        return pd.DataFrame(latent, index=X.index)

    def fit_transform(self, X, y=None, **fit_params):
        """
        Combines the fit and transform steps. Implemented to adhere to scikit-learn API.

        Args:
            X (DataFrame): Input dataset to be fitted and transformed.
            y (ignored): Included for compatibility with scikit-learn.
            **fit_params: Additional params for fit method (ignored by ``fit``).

        Returns:
            DataFrame: Latent representation of the input data after fitting + transforming.
        """
        return self.fit(X, y, **fit_params).transform(X)

In [None]:
# chunk 5
DATA_CONFIG = {
    'use_pca': False,
    'pca_threshold': 0.85,
    'gene_type': 'intersection',
    'use_imputed': True,
    'select_random' : False,
    'use_cohorts': False,
    'requires_ohenc' : True,
    'only_pData': False,
    'clinical_covs' : ["AGE", "TISSUE", "GLEASON_SCORE", 'PRE_OPERATIVE_PSA'] # remove if only the Autoencoder without clinical data is to be trained
}

mp = ModellingProcess()
mp.prepare_data(DATA_CONFIG)

MODEL_CONFIG = {
    'params_cv'  : {
        'model__hidden_layers': [[256, 128], [256]],
        'model__learning_rate': [0.00001],
        'model__batch_size': [64],
        'model__num_epochs': [500],
        'model__dropout': [0.2],
        'model__device': ['cuda']
    },
    'refit': True,
    'do_nested_resampling': True,
    # NOTE(review): relative path — resolves against the working directory; confirm
    # that '/content' was not intended.
    'path' : 'content',
    'fname_cv' : 'deepsurv_autoencoder_pdata'
}

# If the Autoencoder with clinical data is to be trained
pdata_cols = ['TISSUE_FFPE','TISSUE_Fresh_frozen', 'TISSUE_Snap_frozen', 'AGE', 'GLEASON_SCORE', 'PRE_OPERATIVE_PSA'] # remove if only the Autoencoder without clinical data is to be trained
# Keep the expression columns in their original order: the pretrained autoencoder's
# first layer maps input position -> weight, so reordering the genes (e.g. sorting them)
# would feed it a different gene order than it was (presumably) pretrained on — confirm
# against the pretraining notebook. The final-model cells pass the CSV order as-is.
pdata_col_set = set(pdata_cols)
exprs_cols = [c for c in mp.X.columns if c not in pdata_col_set]

# If the Autoencoder without clinical data is to be trained
# exprs_cols = mp.X.columns
# pdata_cols = []

ae = FoldAwareAE()
preprocessor = ColumnTransformer(
    transformers=[
        ('feature_selection', ae, exprs_cols),  # Encode expression data with the pretrained AE
        ('other_features', 'passthrough', pdata_cols)         # Pass through clinical columns
    ]
)

# Define the pipeline
pipe_steps = [
    ('preprocessor', preprocessor),
    ('model', DeepSurvModel())]


In [None]:
# Run the modelling process configured by MODEL_CONFIG (nested resampling when
# 'do_nested_resampling' is True, final refit when 'refit' is True — per the notebook
# header). Kept as the cell's last expression so the returned results are displayed.
mp.do_modelling(pipe_steps, MODEL_CONFIG)

Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 117
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 20
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 88
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 61
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 103
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 116
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 32
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 101
Fitting 8 folds for each of 2 candidates, totalling 16 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70




Early stopping at epoch 104
Fitting 9 folds for each of 2 candidates, totalling 18 fits


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 117
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   5.6s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 20
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.1s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 88
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   4.4s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 61
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   3.5s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 104
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   4.8s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 46
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   2.3s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 32
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.7s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 68
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=   6.5s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 104
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256, 128], model__learning_rate=1e-05, model__num_epochs=500; total time=  10.1s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 153
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   6.5s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 31
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.4s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 25
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.3s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 17
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.1s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 103
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   4.7s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 116
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   4.6s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 38
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   1.8s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 101
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   4.7s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70
Early stopping at epoch 123
[CV] END model__batch_size=64, model__device=cuda, model__dropout=0.2, model__hidden_layers=[256], model__learning_rate=1e-05, model__num_epochs=500; total time=   5.3s


  self.model.load_state_dict(torch.load(model_path + '.pth'))


70




Early stopping at epoch 20


({'mean_score': 0.6214761924447018,
  'std_score': 0.0824873974682325,
  'fold_results': [{'test_cohort': 'Atlanta_2014_Long',
    'test_score': 0.6676300578034682,
    'best_params': {'model__batch_size': 64,
     'model__device': 'cuda',
     'model__dropout': 0.2,
     'model__hidden_layers': [256, 128],
     'model__learning_rate': 1e-05,
     'model__num_epochs': 500},
    'inner_cv_results': {'mean_fit_time': array([15.6319668 ,  7.31603545]),
     'std_fit_time': array([6.53201347, 2.94259126]),
     'mean_score_time': array([0.17088866, 0.08237451]),
     'std_score_time': array([0.05691288, 0.0621777 ]),
     'param_model__batch_size': masked_array(data=[64, 64],
                  mask=[False, False],
            fill_value=999999),
     'param_model__device': masked_array(data=['cuda', 'cuda'],
                  mask=[False, False],
            fill_value='?',
                 dtype=object),
     'param_model__dropout': masked_array(data=[0.2, 0.2],
                  mask=[Fa

In [None]:
# 1. Process only the expression data using the autoencoder
exprs_data = pd.read_csv('/content/exprs_intersect.csv', index_col=0)
ae = FoldAwareAE()
ae.fit(exprs_data)
exprs_encoded = ae.transform(exprs_data)

# 2. Prepare clinical data
pdata = pd.read_csv('/content/merged_imputed_pData.csv', index_col=0)
clin_data = pdata.loc[:, DATA_CONFIG['clinical_covs']]

# Separate categorical and numerical columns. The two selections are disjoint:
# "exclude object" would also have put bool/category columns into the numeric group.
cat_cols = clin_data.select_dtypes(exclude=['number']).columns
num_cols = clin_data.select_dtypes(include=['number']).columns

# Apply one-hot encoding for categorical data if required
if DATA_CONFIG.get('requires_ohenc', True):
    ohc = OneHotEncoder(sparse_output=False)
    clin_data_cat = pd.DataFrame(
        ohc.fit_transform(clin_data[cat_cols]),
        columns=ohc.get_feature_names_out(),
        # The rows come from pdata, so they must carry pdata's labels; pd.concat
        # below aligns on these labels.
        index=clin_data.index,
    )
else:
    # No encoding requested: pass the categorical columns through unchanged
    clin_data_cat = clin_data[cat_cols]

clin_data_num = clin_data[num_cols]

# 3. Combine encoded expression data with clinical data (rows aligned by sample ID)
X_combined = pd.concat([clin_data_cat, clin_data_num, exprs_encoded], axis=1)

# 4. Train the DeepSurv model with the combined data, using the first value of
#    every hyper-parameter grid in MODEL_CONFIG
model_params = {k.replace('model__', ''): v[0] for k, v in MODEL_CONFIG['params_cv'].items()}
model = DeepSurvModel(**model_params)
# NOTE(review): mp.y comes from ModellingProcess and is assumed to be row-aligned
# with X_combined (built from the raw CSVs) — confirm.
model.fit(X_combined, mp.y)


In [None]:
# 1. Define configurations for the final model. Below, this cell reads
#    DATA_CONFIG['clinical_covs'] and the first value of every 'params_cv' grid.
DATA_CONFIG = dict(
    use_pca=False,
    gene_type='intersection',
    use_imputed=True,
    use_cohorts=False,
    requires_ohenc=True,
    only_pData=False,
    clinical_covs=['AGE', 'TISSUE', 'GLEASON_SCORE', 'PRE_OPERATIVE_PSA'],
)

MODEL_CONFIG = dict(
    params_cv=dict(
        model__hidden_layers=[[256, 128]],
        model__learning_rate=[1e-05],
        model__batch_size=[64],
        model__num_epochs=[500],
        model__dropout=[0.2],
        model__device=['cuda'],
    ),
    refit=False,
    do_nested_resampling=False,
    path='/content/saved_model',
    fname_cv='deepsurv_ae_final',
)

def breslow_baseline_hazard(model, X, times, events):
    """
    Computes the Breslow estimator for the (cumulative) baseline hazard.

    For each distinct event time t the hazard increment is
    ``d(t) / sum_{j : T_j >= t} exp(eta_j)``, where ``d(t)`` is the number of events at
    t and ``eta_j`` is the log-risk ``model.predict(X)`` returns for subject j.

    Args:
        model: Fitted model whose ``predict(X)`` returns one log-risk per row of ``X``
            (a column vector is flattened).
        X: Features passed to ``model.predict``.
        times (array-like): Observed times, one per row of ``X``.
        events (array-like): Event indicators (1/True = event, 0/False = censored).

    Returns:
        tuple: ``(unique_event_times, bhaz, cbhaz)`` — sorted distinct event times, the
        baseline-hazard increment at each of them, and its cumulative sum. All three are
        empty when there are no events.
    """
    times = np.asarray(times)
    events = np.asarray(events)
    log_risk = np.ravel(model.predict(X))
    risk = np.clip(np.exp(log_risk), 1e-10, None)  # Ensure numerical stability

    event_mask = events == 1
    event_times = times[event_mask]
    unique_event_times = np.unique(event_times)

    # Sort once by time. The reversed cumulative sum then gives, for every sorted
    # position k, the total risk of all subjects with time >= sorted_times[k]. This
    # replaces a full scan of the risk set per event time (O(n*m)). It is always
    # float, whereas zeros_like(times) truncated the risk sums for integer times.
    order = np.argsort(times, kind='stable')
    sorted_times = times[order]
    risk_from = np.cumsum(risk[order][::-1])[::-1]

    # First sorted position whose time is >= t, i.e. where t's risk set starts
    risk_set_start = np.searchsorted(sorted_times, unique_event_times, side='left')
    at_risk_sum = risk_from[risk_set_start]

    # Number of (tied) events at each distinct event time
    event_count = np.bincount(
        np.searchsorted(unique_event_times, event_times),
        minlength=unique_event_times.size,
    )

    bhaz = event_count / np.maximum(at_risk_sum, 1e-8)
    cbhaz = np.cumsum(bhaz)

    return unique_event_times, bhaz, cbhaz

def load_and_predict_survival(X_test, save_dir="/content/saved_model", n_clinical=6):
    """
    Load the saved DeepSurv + autoencoder models and predict survival curves.

    Args:
        X_test (DataFrame): Samples x features. The first ``n_clinical`` columns are the
            processed clinical covariates (in training order); the remaining columns are
            the expression values fed to the autoencoder.
        save_dir (str): Directory holding ``deep_surv_ae_state.pth`` and
            ``autoencoder_state.pth``.
        n_clinical (int): Number of leading clinical columns (6 in this notebook).

    Returns:
        tuple: ``(times, mean_surv, surv)``: the saved unique event times, the mean
        survival curve across samples, and the per-sample curves (samples x times).
    """
    clin_data = X_test.iloc[:, :n_clinical]
    exprs_data = X_test.iloc[:, n_clinical:]

    print("Clinical Features:", clin_data.shape[1])
    print("Expression Features:", exprs_data.shape[1])

    # Encode the expression data first so that both network input sizes are derived
    # from the data rather than hard-coded (previously 13214 genes and 6 + 64 = 70).
    ae = FoldAwareAE()
    ae.model = Autoencoder(input_dim=exprs_data.shape[1])
    ae_state = torch.load(os.path.join(save_dir, "autoencoder_state.pth"),
                          map_location='cpu', weights_only=True)
    ae.model.load_state_dict(ae_state)
    ae.model.eval()

    X_encoded = ae.transform(exprs_data)
    X_combined = pd.concat([clin_data, X_encoded], axis=1)

    # This checkpoint also stores numpy arrays (event times, cumulative hazard), which
    # torch's weights-only unpickler may reject (it is the default from torch 2.6), so a
    # full unpickle is requested explicitly. Only load files you created yourself.
    model_state = torch.load(os.path.join(save_dir, "deep_surv_ae_state.pth"),
                             map_location='cpu', weights_only=False)
    model = DeepSurvModel(**model_state['model_params'])
    model.init_network(X_combined.shape[1])
    model.model.load_state_dict(model_state['model_state'])
    model.model.eval()

    with torch.no_grad():
        log_risk = model.predict(X_combined)
    # Flatten in case predict returns a column vector
    risk = np.clip(np.exp(np.ravel(log_risk)), 1e-10, None)

    times = model_state['unique_event_times']
    cbhaz = np.asarray(model_state['cum_baseline_hazard'])

    # S_i(t) = exp(-H0(t) * risk_i): rows are samples, columns are event times
    surv = np.exp(-np.outer(risk, cbhaz))
    mean_surv = surv.mean(axis=0)

    return times, mean_surv, surv

# 2. Prepare data
exprs_data = pd.read_csv('/content/exprs_intersect.csv', index_col=0)
ae = FoldAwareAE()
ae.fit(exprs_data)
exprs_encoded = ae.transform(exprs_data)

pdata = pd.read_csv('/content/merged_imputed_pData.csv', index_col=0)
# Put the clinical rows, and therefore the survival labels built from them below, in
# the same sample order as the expression matrix. Raises KeyError if a sample has no
# clinical record instead of silently misaligning X and y.
pdata = pdata.loc[exprs_data.index]
clin_data = pdata[DATA_CONFIG['clinical_covs']]

# Disjoint split into categorical and numeric columns
cat_cols = clin_data.select_dtypes(exclude=['number']).columns
num_cols = clin_data.select_dtypes(include=['number']).columns
if DATA_CONFIG.get('requires_ohenc', True):
    ohc = OneHotEncoder(sparse_output=False)
    clin_data_cat = pd.DataFrame(
        ohc.fit_transform(clin_data[cat_cols]),
        columns=ohc.get_feature_names_out(),
        index=clin_data.index,
    )
else:
    # No encoding requested: pass the categorical columns through unchanged
    clin_data_cat = clin_data[cat_cols]
clin_data_num = clin_data[num_cols]
clin_data_processed = pd.concat([clin_data_cat, clin_data_num], axis=1)

X_combined = pd.concat([clin_data_processed, exprs_encoded], axis=1)
print("Final feature dimensions:", X_combined.shape[1])

status = pdata['BCR_STATUS'].astype(bool).values
time = pdata['MONTH_TO_BCR'].astype(float).values
y = Surv.from_arrays(event=status, time=time, name_event='status', name_time='time')

# Use the first value of every hyper-parameter grid in MODEL_CONFIG
model_params = {k.replace('model__', ''): v[0] for k, v in MODEL_CONFIG['params_cv'].items()}
model = DeepSurvModel(**model_params)
model.fit(X_combined, y)

times, bhaz, cbhaz = breslow_baseline_hazard(model=model, X=X_combined, times=y['time'], events=y['status'])

save_dir = "/content/saved_model"
os.makedirs(save_dir, exist_ok=True)

model_state = {
    'model_params': model_params,
    'model_state': model.model.state_dict(),
    'unique_event_times': times,
    'cum_baseline_hazard': cbhaz
}

torch.save(model_state, os.path.join(save_dir, "deep_surv_ae_state.pth"))
torch.save(ae.model.state_dict(), os.path.join(save_dir, "autoencoder_state.pth"))
print("Model saved.")

# Round-trip check: reload both models from disk and predict on the training data
X_complete = pd.concat([clin_data_processed, exprs_data], axis=1)
times, mean_surv, all_surv = load_and_predict_survival(X_complete)

plt.figure(figsize=(10,6))
plt.step(times, mean_surv, where='post', color='blue', label='Mean Survival')
# Shaded band: 25th-75th percentile of the per-sample survival curves
plt.fill_between(times, np.percentile(all_surv, 25, axis=0), np.percentile(all_surv, 75, axis=0), alpha=0.2, step='post', color='blue')
plt.xlabel('Time (months)')
plt.ylabel('Survival Probability')
plt.title('Predicted Survival Curves')
plt.grid(True)
plt.legend()
plt.show()


In [None]:
def predict_survival_curves(exprs_test, pdata_test, save_dir="/content/saved_model"):
    """
    Makes survival predictions for a test dataset.

    Args:
        exprs_test (DataFrame): Samples x genes expression matrix, with the same genes in
            the same order as the autoencoder's training data.
        pdata_test (DataFrame): Clinical data containing DATA_CONFIG['clinical_covs'],
            indexed by the same sample IDs as ``exprs_test`` (any row order).
        save_dir (str): Directory holding ``deep_surv_ae_state.pth`` and
            ``autoencoder_state.pth``.

    Returns:
        DataFrame: a ``time`` column with the saved event times, plus one
        ``patient_<id>`` column of survival probabilities per sample.

    Raises:
        KeyError: If a sample in ``exprs_test`` has no row in ``pdata_test``.
    """
    import warnings

    # 1. Prepare clinical data in the same row order as the expression matrix, so each
    #    predicted curve is attributed to the right patient.
    clin_data = pdata_test.loc[exprs_test.index, DATA_CONFIG['clinical_covs']]
    num_cols = ['AGE', 'GLEASON_SCORE', 'PRE_OPERATIVE_PSA']

    # One-hot encode TISSUE with the fixed set of columns the model was trained on
    tissue_levels = ['FFPE', 'Fresh_frozen', 'Snap_frozen']
    tissue = clin_data['TISSUE']
    clin_data_cat = pd.DataFrame(
        {f'TISSUE_{level}': (tissue == level).astype(int) for level in tissue_levels},
        index=clin_data.index,
    )
    unknown = ~tissue.isin(tissue_levels)
    if unknown.any():
        # Such samples are encoded as all-zero TISSUE columns; make that visible
        warnings.warn(
            f"Unknown TISSUE values {sorted(map(str, tissue[unknown].unique()))} "
            f"for {int(unknown.sum())} sample(s); encoded as all zeros.",
            stacklevel=2,
        )

    clin_data_processed = pd.concat([clin_data_cat, clin_data[num_cols]], axis=1)

    # 2. Load the autoencoder and encode the expression data
    ae = FoldAwareAE()
    ae.model = Autoencoder(input_dim=exprs_test.shape[1])
    ae_state = torch.load(os.path.join(save_dir, "autoencoder_state.pth"),
                          map_location='cpu', weights_only=True)
    ae.model.load_state_dict(ae_state)
    ae.model.eval()
    exprs_encoded = ae.transform(exprs_test)

    # 3. Combine encoded expression features with clinical features (training order:
    #    clinical first, then latent)
    X_combined = pd.concat([clin_data_processed, exprs_encoded], axis=1)

    # 4. Load the DeepSurv model. The input width is taken from the data instead of
    #    assuming a 64-dimensional latent space. The checkpoint also holds numpy arrays,
    #    which torch's weights-only unpickler may reject, so a full unpickle is
    #    requested explicitly — only load files you created yourself.
    model_state = torch.load(os.path.join(save_dir, "deep_surv_ae_state.pth"),
                             map_location='cpu', weights_only=False)
    model = DeepSurvModel(**model_state['model_params'])
    model.init_network(X_combined.shape[1])
    model.model.load_state_dict(model_state['model_state'])
    model.model.eval()

    # 5. Compute risk scores (flattened in case predict returns a column vector)
    with torch.no_grad():
        log_risk = model.predict(X_combined)
    risk = np.clip(np.exp(np.ravel(log_risk)), 1e-10, None)

    # 6. Survival S_i(t) = exp(-H0(t) * risk_i); rows = patients, columns = times
    times = model_state['unique_event_times']
    cbhaz = np.asarray(model_state['cum_baseline_hazard'])
    surv_curves = np.exp(-np.outer(risk, cbhaz))

    # 7. Build the result in one go (inserting columns one at a time fragments the frame)
    patient_cols = {
        f'patient_{patient_id}': surv_curves[i]
        for i, patient_id in enumerate(X_combined.index)
    }
    return pd.DataFrame({'time': times, **patient_cols})

# Example usage: score a new test cohort, save the predicted curves and plot them.
exprs_test = pd.read_csv('/content/intersect_genes_test_cohort1_low_risk.csv', index_col=0)
pdata_test = pd.read_csv('/content/low_risk_pData_test_cohort1.csv', index_col=0)

survival_curves = predict_survival_curves(exprs_test, pdata_test)

# Persist the predictions (output directory created on demand)
os.makedirs("/content/survival_data", exist_ok=True)
survival_curves.to_csv("/content/survival_data/predicted_survival_curves.csv", index=False)

# One translucent step curve per patient; every column except 'time' is a patient
plt.figure(figsize=(12, 8))
for patient in survival_curves.columns.drop('time'):
    plt.step(
        survival_curves['time'],
        survival_curves[patient],
        where='post',
        alpha=0.3,
        label=patient,
    )

plt.xlabel('Time (months)')
plt.ylabel('Survival Probability')
plt.title('Predicted Survival Curves for All Patients')
plt.grid(True)
# A legend is only readable for a handful of patients (at most 10 curves)
if len(survival_curves.columns) - 1 <= 10:
    plt.legend()
plt.show()