In [1]:
import numpy as np
import pandas as pd
import torch
from skorch.net import NeuralNet
from sklearn.preprocessing import OrdinalEncoder
from hyperband import HyperbandSearchCV
from sklearn.utils.fixes import loguniform
from scipy.stats import uniform
from sklearn.metrics import make_scorer
from sklearn.model_selection import StratifiedKFold
import os 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from skorch.helper import SliceDataset, SliceDict

In [3]:
from survival_benchmark.python.utils.utils import StratifiedSurvivalKFold

In [4]:
from survival_benchmark.python.modules.MultiSurv.multisurv import MultiSurv, MultiSurvModel
from survival_benchmark.python.modules.MultiSurv.loss import Loss
from survival_benchmark.python.modules.MultiSurv.dataset_benchmark import MultimodalDataset

In [5]:
# Root of the local survival-benchmark data checkout. Set the
# SURVIVAL_BENCHMARK_DATA environment variable to run on another machine; the
# default reproduces the original hard-coded location.
DATA_ROOT = os.environ.get(
    'SURVIVAL_BENCHMARK_DATA',
    '/Users/nja/Desktop/survival-benchmark/data',
)
data_wt = os.path.join(
    DATA_ROOT, 'TARGET', 'CBioPortal', 'wt_target_2018_pub', 'processed_v6',
    'WT_data_complete_modalities_preprocessed.csv',
)
data_nbl = os.path.join(
    DATA_ROOT, 'TARGET', 'CBioPortal', 'nbl_target_2018_pub', 'processed_v4',
    'NBL_data_complete_modalities_preprocessed.csv',
)
# Several later cells read `data_location`, which was never assigned in this
# notebook (NameError on a fresh kernel). Their outputs (categorical
# cardinalities [2, 4, 7]) match the dataset built from `data_wt`, so alias it
# here. TODO: confirm it was the WT file.
data_location = data_wt

In [6]:
# Build the multimodal dataset from the WT file. The cell output shows the
# loader falls back to the CSV index because there is no `patient_id` column.
dataset = MultimodalDataset(data_wt)
# NBL cohort loader, currently disabled:
# datasetnbl = MultimodalDataset(data_nbl)

patient_id not found in data, using index


In [7]:
# Output time grid for MultiSurv: integer breakpoints 0, 1, ..., 20 (21 points).
# Named once here instead of repeating the literal torch.arange(0, 21, 1).
OUTPUT_INTERVALS = torch.arange(0, 21, 1)

# skorch wrapper around MultiSurv. skorch forwards `module__*` keyword
# arguments to the module and `criterion__*` keyword arguments to the criterion.
multisurv_skorch = MultiSurvModel(
    module=MultiSurv,
    criterion=Loss,
    optimizer=torch.optim.Adam,
    module__data_modalities=dataset.input_size,
    module__output_intervals=OUTPUT_INTERVALS,
    criterion__aux_criterion=None,
    # Multimodal loss only when the dataset exposes more than one modality.
    criterion__is_multimodal=len(dataset.input_size) > 1,
    # One epoch by default; the Hyperband search below varies max_epochs as its
    # resource parameter.
    max_epochs=1,
)

In [8]:
# Hyperband search space: learning rate sampled log-uniformly in [1e-4, 1e-2].
param_spaces = [
    {
        "lr": loguniform(0.0001, 0.01),
    },
]
# ms_loss is used as a loss: greater_is_better=False makes sklearn negate it.
# `breaks` is read from the estimator itself, so the scorer always uses the
# same time grid the model was configured with. Previously it used a second,
# hand-copied torch.arange(0, 21, 1).
# NOTE(review): score_func is a method bound to this (unfitted)
# `multisurv_skorch` instance, not to the clones fitted inside CV. Confirm that
# ms_loss does not read fitted state from `self`.
ms_loss_scorer = make_scorer(
    score_func=multisurv_skorch.ms_loss,
    greater_is_better=False,
    breaks=multisurv_skorch.module__output_intervals,
)

In [9]:
# Hyperband over `param_spaces[0]`, using training epochs as the budget.
# With max_iter=1, the fit output shows five 1-epoch trainings: one sampled
# configuration evaluated on each of the 5 CV folds.
hyperband_options = dict(
    resource_param="max_epochs",
    scoring=ms_loss_scorer,
    cv=StratifiedSurvivalKFold(),
    random_state=42,
    refit=False,  # no final refit on the full data after the search
    max_iter=1,
    n_jobs=1,  # TODO: Change if multiple GPUs?
)
grid = HyperbandSearchCV(
    estimator=multisurv_skorch,
    param_distributions=param_spaces[0],
    **hyperband_options,
)

In [10]:
# Sklearn-indexable views of the dataset: element 0 of each item is the model
# input (X), element 1 is the label (y).
X_sl, y_sl = (SliceDataset(dataset, idx=i) for i in (0, 1))

In [11]:
# Run the hyperparameter search; per-fold skorch training logs print below.
grid.fit(X_sl,y_sl)

  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3617[0m        [32m0.3478[0m  135.6153
  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3621[0m        [32m0.3475[0m  118.2235
  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3459[0m        [32m0.3446[0m  118.1394
  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3646[0m        [32m0.3450[0m  134.1435
  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3605[0m        [32m0.3470[0m  120.0637




HyperbandSearchCV(cv=StratifiedSurvivalKFold(n_splits=5, random_state=None, shuffle=False),
                  error_score='raise',
                  estimator=<class 'survival_benchmark.python.modules.MultiSurv.multisurv.MultiSurvModel'>[uninitialized](
  module=<class 'survival_benchmark.python.modules.MultiSurv.multisurv.MultiSurv'>,
  module__data_modalities={'clinical': {'categorical': [2, 4, 7]...
                  eta=3, iid=True, max_iter=1, min_iter=1, n_jobs=1,
                  param_distributions={'lr': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7fb2206a9940>},
                  pre_dispatch='2*n_jobs', random_state=42, refit=False,
                  resource_param='max_epochs', return_train_score=False,
                  scoring=make_scorer(ms_loss, greater_is_better=False, breaks=tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20])),
                  skip_last=0, verbose=0)

In [13]:
# Build module/criterion/optimizer without training, to inspect the
# architecture (the repr below lists the per-modality submodels).
multisurv_skorch.initialize()

<class 'survival_benchmark.python.modules.MultiSurv.multisurv.MultiSurvModel'>[initialized](
  module_=MultiSurv(
    (clinical_submodel): ClinicalNet(
      (embedding_layers): ModuleList(
        (0): Embedding(2, 1)
        (1): Embedding(4, 2)
        (2): Embedding(7, 4)
      )
      (linear): Linear(in_features=8, out_features=256, bias=True)
      (embedding_dropout): Dropout(p=0.5, inplace=False)
      (bn_layer): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (output_layer): FC(
        (fc): Sequential(
          (0): Dropout(p=0.5, inplace=False)
          (1): Linear(in_features=256, out_features=512, bias=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (gex_submodel): FC(
      (fc): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=22263, out_features=65536, bias=True)
        (2): ReLU(inplace=True)
        (3): BatchNorm1d(65536, eps=1e-05, momentum=0.1, affine=True, track_ru

In [14]:
# Forward pass over the whole dataset.
# NOTE(review): multisurv_skorch was only initialized in the previous cell (the
# search above used refit=False), so these outputs come from untrained weights.
# The values clustered around 0.5 are consistent with that.
multisurv_skorch.predict(dataset)

tensor([[0.5067, 0.4996, 0.4966,  ..., 0.5018, 0.5082, 0.4918],
        [0.5062, 0.4994, 0.4970,  ..., 0.5016, 0.5079, 0.4918],
        [0.5067, 0.4991, 0.4966,  ..., 0.5013, 0.5082, 0.4913],
        ...,
        [0.5064, 0.4988, 0.4965,  ..., 0.5020, 0.5080, 0.4913],
        [0.5063, 0.4993, 0.4969,  ..., 0.5021, 0.5081, 0.4920],
        [0.5065, 0.4996, 0.4975,  ..., 0.5014, 0.5086, 0.4920]])

In [15]:
# Read only the index column plus the first six data columns (the clinical and
# survival fields listed by df.columns below). Previously every omics column
# was parsed and then discarded with .iloc[:, :6].
df = pd.read_csv(data_location, index_col=0, usecols=range(7))

In [18]:
# Check which columns the clinical/survival subset contains.
df.columns

Index(['clinical_AGE_IN_DAYS', 'OS_days', 'OS', 'clinical_SEX',
       'clinical_RACE', 'clinical_CLINICAL_STAGE'],
      dtype='object')

In [35]:
# NOTE(review): `a` is not defined in any visible cell, so this fails on a fresh
# kernel. The -2147483648 in the saved output is the int32 minimum, which is
# what casting NaN to int32 typically yields. TODO: recover how `a` was built,
# or delete this cell.
np.unique(a)

array([-2147483648,           0,           1,           2], dtype=int32)

In [47]:
# Number of distinct clinical-stage values (7, matching Embedding(7, 4) in the
# model repr).
len(np.unique(df['clinical_CLINICAL_STAGE']))

7

In [52]:
# Distinct-value count per categorical clinical column (stage first, then sex).
categorical_columns = ['clinical_CLINICAL_STAGE', 'clinical_SEX']
np.fromiter(
    (len(np.unique(df[name])) for name in categorical_columns),
    dtype=np.int64,
    count=len(categorical_columns),
)

array([7, 2])

In [46]:
# `enc` was not assigned above this cell. The only visible assignment is a toy
# encoder further down, with 3 categories per column. Fit one explicitly on the
# categorical clinical columns instead; this is expected to reproduce the
# cardinalities [2, 4, 7] in dataset.input_size. TODO: confirm the column order.
enc = OrdinalEncoder().fit(
    df[['clinical_SEX', 'clinical_RACE', 'clinical_CLINICAL_STAGE']]
)
# (number of categories, ceil(n / 2)) per column, mirroring the
# Embedding(n, ceil(n / 2)) layers shown in the model repr.
[(len(categories), int(np.ceil(len(categories) / 2))) for categories in enc.categories_]

[(2, 1), (4, 2), (7, 4)]

In [4]:
# import os
# import random
# import csv
# import warnings
# import pandas as pd 

# import torch
# from torch.utils.data import Dataset, DataLoader

# from typing import List, Tuple

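# NOTE: stale, commented-out reference copy. The MultimodalDataset actually used
# in this notebook is imported from dataset_benchmark and accepts extra
# arguments (categorical_encoder, cnv_encoder, scaler_test, mode) that this copy lacks.
# class MultimodalDataset(Dataset):
#     """Dataset class for MultiSurv. Each item is a pair: a dictionary mapping
#     each available modality name to that patient's data, and the
#     (time, event) survival label.
#     """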

#     def __init__(self, data_path:str,label_path:str=None, modalities:List[str] = ['clinical','gex','mirna','cnv','meth','mut'], dropout:int=0, device:torch.device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')) -> None:
#         super().__init__()

#         self.data = pd.read_csv(data_path,index_col=0)
#         if label_path:
#             self.labels = pd.read_csv(label_path)
#         else:
#             try:
#                 self.labels = self.data[['OS_days','OS']]
#             except KeyError:
#                 print("Survival event and time not available in data. Please provide a path to label file instead.")

    
#         try:
#             self.patient_ids = self.data['patient_id']
#         except KeyError:
#             print("patient_id not found in data, using index")
#             self.patient_ids = self.data.index

#         self.available_modalities = [m for m in modalities if any(self.data.columns.str.contains(m))]

#         assert 0 <= dropout <= 1, '"dropout" must be in [0, 1].'
#         self.dropout = dropout
    
#         # assert all(any(self.data.columns.str.contains(m)) for m in modalities), "One or more modalities not present in the data"
#         assert all(any(self.data.columns.str.contains(m)) for m in self.available_modalities), "One or more modalities not present in the data"
    
#     def _get_modality(self, modality, patient_id):
#         columns_to_subset = self.data.columns[self.data.columns.str.contains(modality)]
#         subset = self.data.loc[patient_id,columns_to_subset]
#         # return subset.to_numpy()
#         if modality == 'clinical':
#             # return torch.zeros(1)
#             # TODO: add a transformation here for clinical -> tensor
#             return subset.to_numpy()
#         elif all(subset.isna()):
#             print("error, found missing data")
#             return self._set_missing_modality(subset)
#         else:
#             return torch.from_numpy(np.array(subset,dtype=np.float32))
    
#     def _set_missing_modality(self,data,value:float=0.0):
        
#         return torch.from_numpy(data.fillna(value).to_numpy())
    
#     def _drop_data(self,data):
        
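#         # for clinical, multisurv only uses continuous features for dropout

#         # Drop data modality
#         n_mod = len(self.available_modalities)
#         # BUG (if revived): this aliases self.available_modalities, so .remove()
#         # mutates the dataset itself; the next call then raises ValueError because
#         # 'clinical' is already gone. Copy first: list(self.available_modalities).
#         modalities_to_drop = self.available_modalities
#         modalities_to_drop.remove('clinical')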
#         if n_mod > 1:
#             if random.random() < self.dropout:
#                 drop_modality = random.choice(modalities_to_drop)
                
#                 data[drop_modality] = torch.zeros_like(data[drop_modality])

#         return data 
    
#     def get_patient_dict(self,patient_id):
#         time, event = self.labels.loc[patient_id]
#         data = {}

#         # Load selected patient's data
#         for modality in self.available_modalities:
#             data[modality] = self._get_modality(modality,patient_id)

#         # Data dropout
#         if self.dropout > 0:
#             n_modalities = len([k for k in data])
#             if n_modalities > 1:
#                 data = self._drop_data(data)

#         return data, time, event

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         patient_id = self.patient_ids[idx]
#         data, time, event = self.get_patient_dict(patient_id)
#         # target = np.array([f"{int(event)}|{time}"])
#         return data, (time, event)


In [6]:
# Rebuild `dataset`, rebinding the name used by earlier cells.
# NOTE(review): depends on `data_location` being defined earlier in the notebook.
dataset = MultimodalDataset(data_location)

patient_id not found in data, using index


In [9]:
# Test-mode dataset that reuses the categorical encoder, CNV encoder and scaler
# already fitted on `dataset`. Presumably this transforms held-out data with
# training statistics; confirm in MultimodalDataset.
# NOTE(review): both datasets read `data_location`, so this is not a separate
# held-out split.
fitted_transforms = dict(
    categorical_encoder=dataset.cat_encoder,
    cnv_encoder=dataset.cnv_encoder,
    scaler_test=dataset.scaler,
)
dataset2 = MultimodalDataset(data_location, **fitted_transforms, mode='test')

patient_id not found in data, using index


In [None]:
# Per-modality input sizes, passed to MultiSurv as module__data_modalities
# (e.g. gex 22263 matches the Linear(in_features=22263, ...) layer).
dataset.input_size

{'clinical': {'categorical': [2, 4, 7], 'continuous': 1},
 'gex': 22263,
 'mirna': 1430,
 'cnv': {'categories': 5, 'length': 20641}}

In [32]:
# Modality count; the model cells pass `len(dataset.input_size) > 1` as
# criterion__is_multimodal.
len(dataset.input_size)

4

In [25]:
# Rebuild the skorch wrapper against the re-created `dataset`, with the same
# settings as the earlier MultiSurvModel cell. skorch forwards the `module__*`
# and `criterion__*` entries to the module and criterion respectively.
multisurv_config = {
    'module': MultiSurv,
    'criterion': Loss,
    'optimizer': torch.optim.Adam,
    'module__data_modalities': dataset.input_size,
    'module__output_intervals': torch.arange(0, 21, 1),
    'criterion__aux_criterion': None,
    'criterion__is_multimodal': len(dataset.input_size) > 1,
    'max_epochs': 1,
}
multisurv_skorch = MultiSurvModel(**multisurv_config)

In [8]:
# Plain (non-CV) fit on the whole dataset. The log shows both train and valid
# loss, so skorch holds out a validation split internally.
# NOTE(review): valid_loss reads 0.0000. Check whether the validation loss is
# being computed as intended.
multisurv_skorch.fit(dataset)

  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3529[0m        [32m0.0000[0m  131.9492


<class 'survival_benchmark.python.modules.MultiSurv.multisurv.MultiSurvModel'>[initialized](
  module_=MultiSurv(
    (clinical_submodel): ClinicalNet(
      (embedding_layers): ModuleList(
        (0): Embedding(2, 1)
        (1): Embedding(4, 2)
        (2): Embedding(7, 4)
      )
      (linear): Linear(in_features=8, out_features=256, bias=True)
      (embedding_dropout): Dropout(p=0.5, inplace=False)
      (bn_layer): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (output_layer): FC(
        (fc): Sequential(
          (0): Dropout(p=0.5, inplace=False)
          (1): Linear(in_features=256, out_features=512, bias=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (gex_submodel): FC(
      (fc): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=22263, out_features=65536, bias=True)
        (2): ReLU(inplace=True)
        (3): BatchNorm1d(65536, eps=1e-05, momentum=0.1, affine=True, track_ru

In [26]:
# Search space: log-uniform learning rate plus a categorical batch size.
param_spaces = [
    {
        "lr": loguniform(0.0001, 0.01),
        "batch_size": [64, 128, 256],
    },
]
# Pass `breaks`, as the first ms_loss scorer did. That search completed, but
# the search using this scorer without `breaks` (In [28]) failed with
# "TypeError: expected Tensor ..." right after the first fold's training.
# TODO: confirm the missing `breaks` was the cause; batch_size is the other
# difference between the two searches.
ms_loss_scorer = make_scorer(
    score_func=multisurv_skorch.ms_loss,
    greater_is_better=False,
    breaks=multisurv_skorch.module__output_intervals,
)

In [27]:
# Second Hyperband search, over lr and batch_size, with the same budget and CV
# settings as the first search.
search_settings = {
    'resource_param': "max_epochs",
    'scoring': ms_loss_scorer,
    'cv': StratifiedSurvivalKFold(),
    'random_state': 42,
    'refit': False,
    'max_iter': 1,
    'n_jobs': 1,  # TODO: Change if multiple GPUs?
}
grid = HyperbandSearchCV(
    estimator=multisurv_skorch,
    param_distributions=param_spaces[0],
    **search_settings,
)

In [22]:
# Rebuild the X / y slice views so they point at the re-created `dataset`.
X_sl, y_sl = [SliceDataset(dataset, idx=position) for position in range(2)]

In [28]:
# Second search (lr + batch_size). The saved output shows a TypeError after the
# first fold's 1-epoch training, presumably during scoring/prediction rather
# than training. TODO: investigate.
grid.fit(X_sl,y_sl)

  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3633[0m        [32m0.2174[0m  171.2530


TypeError: expected Tensor as element 0 in argument 0, but got int

In [None]:
# NOTE(review): `test` is not defined in any visible cell (NameError on a fresh
# kernel); presumably an iterable of (x, y) batches, e.g. built from
# `dataset2`. This loop also rebinds the global names `x` and `y`.
for x,y in test:
    print(multisurv_skorch.predict_survival_function(x))
    # TODO: need to aggregate the probs, and create a df with output_intervals
    # TODO also save best model from train

In [16]:
# Toy check: concatenating a list of equal-length tuples along axis 0 flattens
# them into a single 1-D array. NOTE(review): rebinds the global name `y`.
y = [(1, 2), (3, 4), (5, 6)]
np.concatenate(y, axis=0)

array([1, 2, 3, 4, 5, 6])

In [14]:
# NOTE(review): `apply_` is not defined here, so a bare call raises NameError.
# The saved ValueError output does not match this code, so the cell was edited
# after it ran. TODO: delete it or restore the original expression.
apply_()

ValueError: too many values to unpack (expected 2)

In [11]:
# Plain fit again. The log shows two epochs although the visible constructors
# set max_epochs=1, so the estimator was presumably reconfigured in a cell that
# is not shown. valid_loss again reads 0.0000.
multisurv_skorch.fit(dataset)

  epoch    train_loss    valid_loss       dur
-------  ------------  ------------  --------
      1        [36m0.3581[0m        [32m0.0000[0m  157.1996
      2        [36m0.0882[0m        0.0000  193.0398


<class 'survival_benchmark.python.modules.MultiSurv.multisurv.MultiSurvModel'>[initialized](
  module_=MultiSurv(
    (clinical_submodel): ClinicalNet(
      (embedding_layers): ModuleList(
        (0): Embedding(2, 1)
        (1): Embedding(4, 2)
        (2): Embedding(7, 4)
      )
      (linear): Linear(in_features=8, out_features=256, bias=True)
      (embedding_dropout): Dropout(p=0.5, inplace=False)
      (bn_layer): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (output_layer): FC(
        (fc): Sequential(
          (0): Dropout(p=0.5, inplace=False)
          (1): Linear(in_features=256, out_features=512, bias=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (gex_submodel): FC(
      (fc): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=22263, out_features=65536, bias=True)
        (2): ReLU(inplace=True)
        (3): BatchNorm1d(65536, eps=1e-05, momentum=0.1, affine=True, track_ru

In [284]:
# Toy OrdinalEncoder demo: each column's categories are sorted and mapped to
# 0..n-1 (e.g. Bull -> 0, Cat -> 1, Dog -> 2).
# NOTE(review): this rebinds the notebook-wide `df` and `enc`.
df = pd.DataFrame(dict(a=['A', 'B', 'C'], b=['Dog', 'Cat', 'Bull']))
enc = OrdinalEncoder()
enc.fit_transform(df)

array([[0., 2.],
       [1., 1.],
       [2., 0.]])

In [288]:
# Same ordinal codes as an int32 tensor (the float codes are cast to integers).
encoded_codes = enc.fit_transform(df)
torch.as_tensor(encoded_codes, dtype=torch.int32)

tensor([[0, 2],
        [1, 1],
        [2, 0]], dtype=torch.int32)