In [None]:
%pylab inline
import pandas as pd
import numpy as np
import fastai
import torch
from pathlib import Path
import PIL
import tqdm
import os
import json
tqdm.monitor_interval = 0


from fastai.conv_learner import resnet34,resnext101, transforms_top_down, CropType, \
    tfms_from_model, ConvLearner, optim, T, Callback
from fastai.dataset import Denormalize, ImageData, FilesNhotArrayDataset, \
    ImageClassifierData, csv_source, parse_csv_labels, split_by_idx, read_dir, \
    FilesIndexArrayDataset, dict_source, FilesArrayDataset
from fastai.sgdr import TrainingPhase, DecayType
from lifelines.utils import concordance_index
from collections import defaultdict
from aixtras import *

In [None]:
# Enable cuDNN benchmarking for this session.
# NOTE(review): PyTorch's reproducibility notes recommend benchmark=False when
# run-to-run determinism is required; it is left enabled here as in the original
# configuration -- confirm which trade-off is wanted.
torch.backends.cudnn.benchmark=True
# Run everything on CUDA device index 1.
torch.cuda.set_device(1)
torch.cuda.current_device()

torch.backends.cudnn.deterministic = True
# Seed numpy as well as torch: the patient split and the tile sampling in this
# notebook draw from np.random, so seeding torch alone leaves them irreproducible.
np.random.seed(7)
torch.manual_seed(7)



In [None]:
# Filesystem layout of the experiment. Everything hangs off LIVER_PATH.
# NOTE(review): LIVER_PATH is a hard-coded absolute path; adjust it per machine.
LIVER_PATH = Path('/DATA/BIO/GDC/liver')
LIVER_SAMPLES = LIVER_PATH/"samples"      # <slide name>/<level>/ tile folders
EXP_PATH = LIVER_PATH/"exp_deep"          # root of this experiment
EXP_MODEL_PATH = EXP_PATH/"models"
TRAIN_DIR = EXP_PATH/"train"              # symlinks for train + valid tiles
TEST_DIR = EXP_PATH/"test"                # symlinks for test tiles
CSV_DATA = EXP_PATH/"records.csv"         # one row per linked tile

# Create the experiment folders. exist_ok=True makes the cell re-runnable and,
# unlike a separate exists() check, fails loudly if the path is a regular file.
# Parents are deliberately not created: the data root must already exist.
for d in [EXP_PATH, EXP_MODEL_PATH]:
    d.mkdir(exist_ok=True)

In [None]:
# Build the train/valid/test tile table once and cache it in CSV_DATA.
# When the CSV already exists nothing below is executed.
if not CSV_DATA.exists():
    print("build traing/val/test csv data")

    slides = pd.read_csv(LIVER_PATH/'slides.csv')
    slides = slides.loc[slides.sample_type_id.isin([1,11])]
    slide_level = 'level_1'       # sub-folder of each slide that holds the tiles
    samples_per_patient = 100     # upper bound on tiles sampled per patient
    split = 0.8                   # fraction of patients used for train+valid
    val_split = 0.8               # fraction of the train+valid patients used for train

    # Not used in this cell; kept because later cells may rely on the name.
    slide_info = defaultdict(dict)

    def pull_tiles(slides, patient_id, num_tiles, slide_level):
        """Randomly pick up to `num_tiles` tile files belonging to one patient.

        Tiles are gathered from *every* slide of the patient (previously each
        slide overwrote the candidate list, so only the last slide was used).
        Candidates are sorted so that a seeded np.random gives the same draw
        regardless of directory iteration order.

        Returns a list of tile paths; empty if the patient has no tiles.
        """
        slide_fns = []
        patient_slides = slides.loc[slides.submitter_id == patient_id]
        for _, row in patient_slides.iterrows():
            sfp = LIVER_SAMPLES/row.slide_file_name.upper()/slide_level
            slide_fns.extend(sorted(sfp.iterdir()))

        num_samples = len(slide_fns)
        if num_samples == 0:
            return []
        # Sample indices instead of the paths themselves, which avoids turning
        # the Path objects into a numpy array.
        picked = np.random.choice(num_samples, size=min(num_tiles, num_samples), replace=False)
        return [slide_fns[i] for i in picked]

    def build_tiles(patients, dsname, folder):
        """Symlink sampled tiles of `patients` into `folder` and describe them.

        Returns one record (dict) per linked tile with the patient id, the
        data-set name (`dsname`), the survival time / event flag and the source
        and destination paths.
        """
        records = []
        folder.mkdir(parents=True, exist_ok=True)
        for p in tqdm.tqdm_notebook(patients):
            # Survival labels are per patient: look them up once, not per tile.
            patient_rows = slides.loc[slides.submitter_id == p]
            tiles = pull_tiles(slides, p, samples_per_patient, slide_level)
            if not tiles:
                continue
            event_time = patient_rows['days_proxy'].iloc[0]
            event_type = patient_rows['event_observed'].iloc[0]
            for i, tile_fn in enumerate(tiles):
                base_name = '%s_%04d.tiff' % (p, i)
                dest_tile = folder/base_name
                # A previous, interrupted build may have left a link behind;
                # replace it so the link always matches the recorded src_tile.
                if dest_tile.is_symlink():
                    dest_tile.unlink()
                os.symlink(tile_fn, dest_tile)
                records.append({
                    'patient': p,
                    'dsname': dsname,
                    'event_time': event_time,
                    'event_type': event_type,
                    'src_tile': tile_fn,
                    'dest_tile': dest_tile
                })
        return records

    # Event time: days to death when known, otherwise days to last follow-up.
    # Rows with neither value are dropped.
    slides = slides.assign(
        days_proxy=slides.days_to_death.fillna(slides.days_to_last_follow_up))
    slides = slides.loc[slides.days_proxy.notnull()]

    # filter tumor only; copy so the column assignments below do not write
    # into a view of the original frame (SettingWithCopyWarning).
    slides = slides.loc[slides.sample_type_id == 1].copy()

    # The event is observed exactly when the time above came from days_to_death.
    # (Previously a row that had both columns filled was marked censored even
    # though its time was the death time.)
    slides['event_observed'] = slides.days_to_death.notnull().astype(int)

    # convert days_proxy to int for softmax
    slides['days_proxy'] = slides.days_proxy.astype(int)

    # Split by patient so that tiles of one patient never span two sets.
    # Sorting gives a stable starting order before the random permutation.
    patients = sorted(slides.submitter_id.dropna().unique())
    num_patients = len(patients)
    train_val_split = int(split * num_patients)
    train_split = int(val_split * train_val_split)

    random_patients = np.random.permutation(patients)
    train_patients = random_patients[0:train_split]
    valid_patients = random_patients[train_split:train_val_split]
    test_patients = random_patients[train_val_split:]

    # arrange the sample data (train and valid tiles share TRAIN_DIR)
    train_records = build_tiles(train_patients, 'train', TRAIN_DIR)
    valid_records = build_tiles(valid_patients, 'valid', TRAIN_DIR)
    test_records = build_tiles(test_patients, 'test', TEST_DIR)

    csv_data = pd.DataFrame(train_records + valid_records + test_records)
    csv_data.to_csv(CSV_DATA, index=False)
else:
    print("csv data already built")

# Load the cached record table and coarsen event_time by integer division by 10.
# NOTE(review): the coarsening only affects this in-memory frame;
# ImageSurvivalData.from_suvival_csv re-reads the CSV on its own.
csv_data = pd.read_csv(CSV_DATA)
csv_data['event_time'] = csv_data['event_time'].floordiv(10)
# Largest (coarsened) event time, cast to a plain Python int because a numpy
# int64 caused problems when handed to torch.
t_max = int(csv_data['event_time'].max())
print(t_max)

In [None]:
class ImageSurvivalData(ImageClassifierData):
    """ImageClassifierData variant whose labels are (event time, event type) pairs."""

    @classmethod
    def from_suvival_csv(cls, path, folder, csv_fname, bs=64, tfms=(None,None),
                         test_name=None, skip_header=True, num_workers=8, 
                         fname_col='fname', time_col='event_time', type_col='event_type', suffix=None):
        """Build the data object from a CSV with one row per image.

        The CSV must contain a `dsname` column with the values 'train',
        'valid' or 'test', plus the columns named by `fname_col`, `time_col`
        and `type_col`.

        Arguments:
            path: root path handed to the datasets and to the data object.
            folder: must be a relative path (only validated, otherwise unused).
            csv_fname: CSV file to read.
            bs: batch size.
            tfms: (train, validation) transforms; both are required.
            num_workers: number of data-loader workers.
            fname_col / time_col / type_col: column names in the CSV.
            test_name, skip_header, suffix: accepted but not used by this
                implementation.

        Returns:
            An instance of `cls` whose `classes` are 0..max(event time); the
            test datasets carry the real test labels.
        """
        assert not (tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
        assert not (os.path.isabs(folder)), "folder needs to be a relative path"
        
        csv_data = pd.read_csv(csv_fname)
        # One class per integer time step from 0 up to the largest event time.
        # Cast to a plain int so no numpy scalar leaks into the model set-up.
        t_max = int(csv_data[time_col].max())
        classes = list(range(t_max + 1))
 
        train_val_data = csv_data.loc[csv_data.dsname.isin(['train','valid'])]
        test_data = csv_data.loc[csv_data.dsname == 'test']
        
        fnames = train_val_data[fname_col]
        test_fnames = test_data[fname_col]
        
        # Labels are two columns per sample: [event time, event type].
        y = train_val_data[[time_col, type_col]].values
        y_test = test_data[[time_col, type_col]].values
        
        # split_by_idx works on positions, so turn the boolean 'valid' marker
        # into explicit integer positions within train_val_data.
        val_idxs = np.flatnonzero((train_val_data.dsname == 'valid').values)
        
        ((val_fnames,trn_fnames),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(fnames), y)
         
        class FilesSurvivalArrayDataset(FilesArrayDataset):
            """File dataset reporting t_max + 1 outputs and multi-label targets."""
            def get_c(self): return t_max + 1

            @property
            def is_multi(self): return True
    
        datasets = cls.get_ds(FilesSurvivalArrayDataset, (trn_fnames,trn_y), (val_fnames,val_y), tfms,
                               path=path, test=np.array(test_fnames))
        
        # Overwrite the labels of the two test datasets (positions 4 and 5 of
        # the list returned by get_ds) with the real test labels.
        datasets[4].y = y_test
        datasets[5].y = y_test
        
        return cls(path, datasets, bs, num_workers, classes=classes)
    

In [None]:
# Backbone architecture used for the transforms here and for the learner later.
f_model = resnext101

def get_data(sz, bs):
    """Return the survival data object for image size `sz` and batch size `bs`.

    Transforms are derived from `f_model` with the top-down augmentation set;
    file names are read from the `dest_tile` column of CSV_DATA.
    """
    train_val_tfms = tfms_from_model(f_model, sz, aug_tfms=transforms_top_down)
    data = ImageSurvivalData.from_suvival_csv(
        EXP_PATH, 'train', CSV_DATA,
        test_name='test',
        tfms=train_val_tfms,
        bs=bs,
        fname_col='dest_tile',
    )
    return data

# Image size and batch size for this run, and the data object built from them.
sz, bs = 256, 8
md = get_data(sz, bs)

In [None]:
from fastai.layers import AdaptiveConcatPool2d,Flatten
from torch.nn import BatchNorm1d,Dropout,ReLU,Linear,Sequential,Hardtanh,Softmax

# Number of features entering the head.
# NOTE(review): presumably the backbone's channel count doubled by
# AdaptiveConcatPool2d -- confirm if f_model changes.
feat = 4096

# Custom head: pooling + flatten, then two BatchNorm/Dropout/Linear stages that
# end in one output per class of `md`.
layers = [AdaptiveConcatPool2d(), Flatten()]
layers += [BatchNorm1d(feat),
            Dropout(p=0.5), 
            Linear(in_features=feat, out_features=256), 
            ReLU(), 
            BatchNorm1d(256),
            Dropout(p=0.5), 
            Linear(in_features=256, out_features=len(md.classes)),
            # Normalise over the class dimension explicitly. For the 2-D
            # (batch, classes) input produced by the Linear layer this is the
            # dimension the implicit form picked, minus the deprecation warning.
            Softmax(dim=1)]
head_relu = Sequential(*layers)

In [None]:
# Build the learner from the pretrained backbone, the survival data and the
# custom head defined above.
learn = ConvLearner.pretrained(f_model, md,custom_head=head_relu)
# Alternative kept for reference: same learner without the custom head.
#learn = ConvLearner.pretrained(f_model, md)

# Number of event types handed to deephit_loss.
num_evt_types = 1
def custom_loss(preds, target):
    """Equal-weight sum of the two terms returned by `deephit_loss`.

    Arguments:
        preds: model output, one column per time bin.
        target: labels with the event time in column 0 and the event type in
            column 1.

    The number of time bins passed to `deephit_loss` is taken from the width
    of `preds` instead of a notebook-level variable, so it always agrees with
    what the model head actually outputs.
    NOTE(review): `deephit_loss` comes from `aixtras`; its fourth argument is
    assumed to be the number of time bins -- confirm.
    """
    evt_times = target[:,0]
    evt_types = target[:,1]
    num_time_bins = preds.size(1)
    l1_loss, pairwise_loss = deephit_loss(preds, evt_times, evt_types, num_time_bins, num_evt_types)
    # Weights of the two loss terms.
    b1 = 0.50
    b2 = 0.50
    return b1 * pairwise_loss + b2 * l1_loss 


class ConcordanceIndex(Callback):
    """Collects predictions and labels batch by batch and prints the
    concordance index over everything collected when an epoch ends.

    `concordance_metric` is the per-batch collector; the epoch hooks clear the
    collected state.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything collected so far."""
        self.preds = []
        self.evt_times = []
        self.evt_types = []
        self.mcount = 0

    def on_epoch_begin(self, metrics):
        self.reset()

    def on_epoch_end(self, metrics):
        """Print the concordance index of the collected batches, then reset."""
        times = np.array(self.evt_times)
        predicted = np.array(self.preds)
        observed = np.array(self.evt_types)
        ci = concordance_index(times, predicted, observed)
        print('ci: ', ci, len(self.preds), len(self.evt_times))
        self.reset()

    def concordance_metric(self, preds, target):
        """Record one batch; the returned value is a constant placeholder.

        Column 0 of `target` is stored as event time, column 1 as event type,
        and the arg-max over axis 1 of `preds` as the prediction.
        """
        self.evt_times.extend(target[:,0])
        self.evt_types.extend(target[:,1])
        self.preds.extend(np.argmax(preds, axis=1))
        self.mcount += 1
        return 0.0


# Train with the custom survival loss.
learn.crit = custom_loss

# A single ConcordanceIndex instance acts both as the metric (collecting each
# batch) and as the callback (printing the index at epoch end). The default
# accuracy metric is replaced because the labels carry an extra event-type column.
cindex = ConcordanceIndex()
learn.metrics = [cindex.concordance_metric]
callbacks = [cindex]
# Optional optimiser override, currently disabled:
#learn.opt_fn = optim.Adam


In [None]:
# Show the last child module of the learner (presumably the custom head) as a sanity check.
learn.children[-1:]

In [None]:
# Learning-rate finder. The weights are saved before and restored after the
# sweep, and the concordance collector is cleared of what the sweep recorded.
learn.save('tmp')
lrf = learn.lr_find()
learn.sched.plot(0)
learn.load('tmp')
cindex.reset()

In [None]:
# Candidate learning rate. Note: the next cell assigns `lr` again before it is used.
lr = 0.001

In [None]:
# Stage 1: five cycles of one epoch each, before any call to unfreeze().
# The best weights are written under 'liver_deephit_1_best' during training and
# the final weights under 'liver_deephit_1' afterwards.
lr = 0.005
learn.fit(
    lr, 5,
    cycle_len=1,
    use_clr_beta=(40, 20, 0.95, 0.85),
    best_save_name='liver_deephit_1_best',
    callbacks=callbacks,
)

learn.save('liver_deephit_1')

In [None]:
# Restore the stage-1 weights and clear anything the concordance collector still holds.
learn.load('liver_deephit_1')
cindex.reset()

In [None]:
# Stage 2: unfreeze the learner and run one cycle of ten epochs at a much
# smaller learning rate; weights are saved as 'liver_deephit_2'.
learn.unfreeze()
lr = 0.00003
learn.fit(
    lr, 1,
    cycle_len=10,
    use_clr_beta=(40, 20, 0.95, 0.85),
    best_save_name='liver_deephit_2_best',
    callbacks=callbacks,
)
learn.save('liver_deephit_2')

In [None]:
# Marker cell: its output shows that execution has reached this point.
print('done')

In [None]:
# Restore the stage-2 weights and clear anything the concordance collector still holds.
learn.load('liver_deephit_2')
cindex.reset()

In [None]:
# Concordance index with test-time augmentation (default data set of TTA):
# average the augmented predictions, take the most likely class per sample as
# the predicted time, and compare against the (time, event type) labels.
multi_preds, targs = learn.TTA()
preds = np.mean(multi_preds, axis=0)
y_pred = np.argmax(preds, axis=1)
evt_time, evt_type = targs[:, 0], targs[:, 1]
concordance_index(evt_time, y_pred, evt_type)

In [None]:
# Same evaluation as for the default set, but on the test set (is_test=True).
multi_preds, targs = learn.TTA(is_test=True)
preds = np.mean(multi_preds, axis=0)
y_pred = np.argmax(preds, axis=1)
evt_time, evt_type = targs[:, 0], targs[:, 1]
concordance_index(evt_time, y_pred, evt_type)

In [None]:
# Stage 3: one long cycle of 300 epochs on the unfrozen learner; weights are
# saved as 'liver_deephit_3'.
learn.unfreeze()
lr = 0.0001
learn.fit(
    lr, 1,
    cycle_len=300,
    use_clr_beta=(40, 20, 0.95, 0.85),
    best_save_name='liver_deephit_3_best',
    callbacks=callbacks,
)
learn.save('liver_deephit_3')

In [None]:
# Final evaluation of the stage-3 weights: print the concordance index with
# test-time augmentation, first on the default TTA set and then on the test set.
learn.load('liver_deephit_3')

# Default TTA data set.
multi_preds, targs = learn.TTA()
preds = np.mean(multi_preds, axis=0)
y_pred = np.argmax(preds, axis=1)
evt_time, evt_type = targs[:, 0], targs[:, 1]
print(concordance_index(evt_time, y_pred, evt_type))

# Test set.
multi_preds, targs = learn.TTA(is_test=True)
preds = np.mean(multi_preds, axis=0)
y_pred = np.argmax(preds, axis=1)
evt_time, evt_type = targs[:, 0], targs[:, 1]
print(concordance_index(evt_time, y_pred, evt_type))