In [1]:
import numpy as np
import pandas as pd
import torch
from skorch.net import NeuralNet
import os 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from survival_benchmark.python.modules.MultiSurv.multisurv import MultiSurv
from survival_benchmark.python.modules.MultiSurv.loss import Loss
from survival_benchmark.python.modules.MultiSurv.dataset import MultimodalDataset

In [None]:
data_location = '/Users/nja/Desktop/survival-benchmark/data/TARGET/CBioPortal/wt_target_2018_pub/processed_v6/WT_data_complete_modalities_preprocessed.csv'

In [224]:
import os
import random
import csv
import warnings
import pandas as pd 

import torch
from torch.utils.data import Dataset, DataLoader

from typing import List, Tuple

class MultimodalDataset(Dataset):
    """Dataset class for MultiSurv; Returns a dictionary where each key is a modality
    and the corresponding value is the tensor 
    """

    def __init__(self, data_path:str,label_path:str=None, modalities:List[str] = ['clinical','gex','mirna','cnv','meth','mut'], dropout:int=0, device:torch.device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')) -> None:
        super().__init__()

        self.data = pd.read_csv(data_path,index_col=0)
        if label_path:
            self.labels = pd.read_csv(label_path)
        else:
            try:
                self.labels = self.data[['OS_days','OS']]
            except KeyError:
                print("Survival event and time not available in data. Please provide a path to label file instead.")

    
        try:
            self.patient_ids = self.data['patient_id']
        except KeyError:
            print("patient_id not found in data, using index")
            self.patient_ids = self.data.index

        self.available_modalities = [m for m in modalities if any(self.data.columns.str.contains(m))]

        assert 0 <= dropout <= 1, '"dropout" must be in [0, 1].'
        self.dropout = dropout
    
        # assert all(any(self.data.columns.str.contains(m)) for m in modalities), "One or more modalities not present in the data"
        assert all(any(self.data.columns.str.contains(m)) for m in self.available_modalities), "One or more modalities not present in the data"
    
    def _get_modality(self, modality, patient_id):
        columns_to_subset = self.data.columns[self.data.columns.str.contains(modality)]
        subset = self.data.loc[patient_id,columns_to_subset]
        # return subset.to_numpy()
        if modality == 'clinical':
            # return torch.zeros(1)
            # TODO: add a transformation here for clinical -> tensor
            return subset.to_numpy()
        elif all(subset.isna()):
            print("error, found missing data")
            return self._set_missing_modality(subset)
        else:
            return torch.from_numpy(np.array(subset,dtype=np.float32))
    
    def _set_missing_modality(self,data,value:float=0.0):
        
        return torch.from_numpy(data.fillna(value).to_numpy())
    
    def _drop_data(self,data):
        
        # for clinical, multisurv only uses continous features for drop out

        # Drop data modality
        n_mod = len(self.available_modalities)
        modalities_to_drop = self.available_modalities
        modalities_to_drop.remove('clinical')
        if n_mod > 1:
            if random.random() < self.dropout:
                drop_modality = random.choice(modalities_to_drop)
                
                data[drop_modality] = torch.zeros_like(data[drop_modality])

        return data 
    
    def get_patient_dict(self,patient_id):
        time, event = self.labels.loc[patient_id]
        data = {}

        # Load selected patient's data
        for modality in self.available_modalities:
            data[modality] = self._get_modality(modality,patient_id)

        # Data dropout
        if self.dropout > 0:
            n_modalities = len([k for k in data])
            if n_modalities > 1:
                data = self._drop_data(data)

        return data, time, event

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        print(patient_id)
        data, time, event = self.get_patient_dict(patient_id)
        return data, (time,event)


In [225]:
dataset = MultimodalDataset(data_location)

patient_id not found in data, using index


In [226]:
multisurv_skorch = NeuralNet(
    module=MultiSurv,
    criterion=Loss,
    optimizer=torch.optim.Adam,
    module__data_modalities=['clinical'],
    module__n_output_intervals = torch.arange(0,30,1),
)

In [227]:
train = DataLoader(dataset)

In [228]:
for x,y in train:
    print(type(x))

TARGET-50-CAAAAQ
<class 'dict'>
TARGET-50-PAJNNR
<class 'dict'>
TARGET-50-PAKGZX
<class 'dict'>
TARGET-50-CAAAAM
<class 'dict'>
TARGET-50-CAAAAS
<class 'dict'>
TARGET-50-PAKFYV
<class 'dict'>
TARGET-50-PAKECR
<class 'dict'>
TARGET-50-PAJNBN
<class 'dict'>
TARGET-50-PAJMKN
<class 'dict'>
TARGET-50-PAKSCC
<class 'dict'>
TARGET-50-PAJNCZ
<class 'dict'>
TARGET-50-PAJPGY
<class 'dict'>
TARGET-50-PAJNZK
<class 'dict'>
TARGET-50-PAJLTI
<class 'dict'>
TARGET-50-PAJMKI
<class 'dict'>
TARGET-50-PAJPDN
<class 'dict'>
TARGET-50-PAKXWB
<class 'dict'>
TARGET-50-PAJNVX
<class 'dict'>
TARGET-50-PAJMFY
<class 'dict'>
TARGET-50-PAKKNS
<class 'dict'>
TARGET-50-PAJNYT
<class 'dict'>
TARGET-50-PAEBXA
<class 'dict'>
TARGET-50-PAKKSE
<class 'dict'>
TARGET-50-PAJPCM
<class 'dict'>
TARGET-50-PAEAFB
<class 'dict'>
TARGET-50-PAKGED
<class 'dict'>
TARGET-50-PAJNUP
<class 'dict'>
TARGET-50-PAKYFC
<class 'dict'>
TARGET-50-PAJPEW
<class 'dict'>
TARGET-50-CAAAAB
<class 'dict'>
TARGET-50-PAJMUF
<class 'dict'>
TARGET-5