In [None]:
%pylab inline
import pandas as pd
import numpy as np
import fastai
import torch
import torch.nn.functional as F
import torch.nn as nn
from pathlib import Path
import PIL
import tqdm
import os
import json
import shutil
# tqdm reads `monitor_interval` as a class attribute of the tqdm.tqdm class
# (subclasses such as tqdm_notebook inherit it); assigning it on the `tqdm`
# module object had no effect. 0 disables tqdm's background monitor thread.
tqdm.tqdm.monitor_interval = 0

from fastai.conv_learner import resnet34,resnext101, transforms_top_down, CropType, \
    tfms_from_model, ConvLearner, optim, T, Callback, RandomRotateZoom, to_gpu, \
    Learner, BasicModel
from fastai.dataset import Denormalize, ImageData, FilesNhotArrayDataset, \
    ImageClassifierData, csv_source, parse_csv_labels, split_by_idx, read_dir, \
    FilesIndexArrayDataset, dict_source, FilesArrayDataset, DataLoader, ModelData, ImageData
from fastai.layers import AdaptiveConcatPool2d,Flatten
from torch.nn import BatchNorm1d,Dropout,ReLU,Linear,Sequential,Hardtanh,Softmax
from fastai.column_data import PassthruDataset
from fastai.sgdr import TrainingPhase, DecayType
from lifelines.utils import concordance_index
from collections import defaultdict
from aixtras import *
from bio.liver import build_liver_csv_data
from sklearn.preprocessing import MinMaxScaler



In [None]:
# Let cuDNN benchmark conv algorithms once; input sizes are fixed in this notebook.
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)
torch.cuda.current_device()

# Device ids 0..3; not referenced in the visible cells.
devs = list(range(4))

In [None]:
build_liver_csv_data?

In [None]:
# Experiment directory layout. The data root can be overridden with the
# LIVER_PATH environment variable; it defaults to the original location.
LIVER_PATH = Path(os.environ.get('LIVER_PATH', '/DATA/BIO/GDC/liver'))
LIVER_SAMPLES = LIVER_PATH/"samples"
EXP_PATH = LIVER_PATH/"exp_deep_play"
EXP_MODEL_PATH = EXP_PATH/"models"
TRAIN_DIR = EXP_PATH/"train"
TEST_DIR = EXP_PATH/"test"
CSV_DATA = EXP_PATH/"records.csv"
TB_LOGS = EXP_PATH/'logs'

# Passed to build_liver_csv_data(force_rebuild=...).
force_rebuild = True
# Min-max scale the tabular feature columns (scaler is fitted on 'train' only).
scale_x_cols = False

# Create output dirs. exist_ok avoids the check-then-create race and keeps the
# cell idempotent; EXP_PATH comes first so its children have a parent.
for d in (EXP_PATH, EXP_MODEL_PATH, TB_LOGS):
    d.mkdir(exist_ok=True)

In [None]:
from tensorboardX import SummaryWriter

# TensorBoard writer. NOTE(review): logs go to a shared runs directory, not to
# TB_LOGS (created in the paths cell); `sw` is not used in the visible cells.
sw = SummaryWriter('/DATA/tblogs/runs/expplay')


In [None]:
csv_data = build_liver_csv_data(
    LIVER_PATH,
    EXP_PATH,
    TRAIN_DIR,
    TEST_DIR,
    progress=tqdm.tqdm_notebook,
    force_rebuild=force_rebuild,
    samples_per_patient=50,
    val_split=0.5,
)

# Largest survival time in the data; get_data() builds t_max + 1 time classes.
# Cast to a plain Python int: an np.int64 here breaks downstream torch code.
t_max = int(csv_data['event_time'].max())
print(t_max)

# One DataFrame per split, keyed by the 'dsname' column ('train', 'valid', 'test').
dframes = {dsname: df for dsname, df in csv_data.groupby('dsname')}

# Tabular feature columns fed to the model.
feat_cols = ['age_at_diagnosis']
n_cols = len(feat_cols)

if scale_x_cols:
    # Fit on the training split only, then apply the same scaling to every split.
    scaler = MinMaxScaler().fit(dframes['train'][feat_cols].values.astype(float))

    def _scaled(df):
        # Same float dtype as used for fitting; build a new frame instead of
        # .loc-assigning into a groupby slice (SettingWithCopy / int-column clashes).
        values = scaler.transform(df[feat_cols].values.astype(float))
        return df.assign(**{col: values[:, j] for j, col in enumerate(feat_cols)})

    dframes = {k: _scaled(df) for k, df in dframes.items()}

In [None]:
#dframes['train']

In [None]:
class FilesAndColumnsSurvivalDataset(FilesArrayDataset):
    """Survival dataset pairing image tiles with a row of tabular features.

    Each item combines an image file (loaded through ``FilesArrayDataset``) with
    the matching row of ``x_cols``; the target is the matching row of
    ``y_times_and_observed`` (built in get_data as ``[event_time, event_type]``).

    Args:
        fnames: image file names, relative to ``path``.
        x_cols: 2-D array of tabular features, one row per file.
        y_times_and_observed: per-item targets, one row per file.
        t_max: largest event time; returned by ``get_c``.
        transform: fastai transform applied via ``self.get`` to image and target.
        path: root directory of ``fnames``.
        jitter: if truthy, std-dev of multiplicative Gaussian noise applied to
            the feature row on every access.
    """

    # Constants of the log transform applied to feature rows in get1item.
    # Their origin is not documented here — TODO confirm.
    LOG_EPS = 0.01
    LOG_SHIFT = 6.0

    def __init__(self, fnames, x_cols, y_times_and_observed, t_max, transform, path, jitter=None):
        self.y = y_times_and_observed
        self.t_max = t_max
        self.x_cols = x_cols
        # Number of feature columns (axis 1); shape[0] would be the number of rows.
        self.n_cols = x_cols.shape[1] if x_cols.ndim > 1 else 1
        self.fnames = fnames
        self.jitter = jitter
        # t_max must be set before this call: get_c() reads it.
        super().__init__(fnames, self.y, transform, path)
        assert len(self.fnames) == len(self.y) == len(self.x_cols)
        self.c = self.get_c()

    def get_y(self, i): return self.y[i]

    def get_c(self):
        # NOTE(review): get_data() builds t_max + 1 classes but this reports t_max —
        # confirm which count fastai should see.
        return self.t_max

    def get_x(self, i):
        """Return ``[feature_row, image]`` for item ``i``."""
        x_img = super().get_x(i)
        return [self.x_cols[i], x_img]

    def get1item(self, idx):
        """Return ``(features, target)`` for item ``idx``.

        The image is still loaded and passed through ``self.transform`` (which
        also receives the target), but only the transformed feature row is returned.
        """
        x_cols, x_img = self.get_x(idx)
        y = self.get_y(idx)
        # astype() always copies. self.x_cols[idx] is a view into the stored array,
        # so an in-place `*=` would permanently corrupt the dataset (noise compounding
        # on every access) and fail outright for integer feature columns.
        x_cols = x_cols.astype(float)
        if self.jitter:
            x_cols = x_cols * (np.random.randn(*x_cols.shape) * self.jitter + 1.0)
        x_cols = np.log(x_cols + self.LOG_EPS) - self.LOG_SHIFT
        x_img, y = self.get(self.transform, x_img, y)
        # NOTE(review): the transformed image is currently unused; the model
        # consumes only the feature columns.
        return x_cols, y


In [None]:
def get_data(sz, bs, tfms, num_workers=8, aug_jitter=None):
    """Build the fastai ``ImageData`` for the survival experiment.

    Args:
        sz: image size; not used here (kept for call compatibility).
        bs: batch size.
        tfms: pair whose [0] is the augmenting transform and [1] the fixed one
            (as returned by ``tfms_from_model``).
        num_workers: number of data-loading workers.
        aug_jitter: optional feature jitter for the augmented datasets (train,
            augmented valid, augmented test); ``None`` disables it.

    Returns:
        ``ImageData`` with classes ``range(t_max + 1)``.
    """
    aug_tfm, fix_tfm = tfms[0], tfms[1]

    def df_to_ds(key, tfm, jitter=None):
        df = dframes[key]
        return FilesAndColumnsSurvivalDataset(
            df.dest_tile.values,
            df[feat_cols].values,
            df[['event_time', 'event_type']].values,
            t_max,
            tfm,
            EXP_PATH,
            jitter=jitter,
        )

    # fastai 0.7 ImageData unpacks datasets as (trn, val, fix, aug, test, test_aug).
    datasets = [
        df_to_ds('train', aug_tfm, jitter=aug_jitter),
        df_to_ds('valid', fix_tfm),
        df_to_ds('train', fix_tfm),
        df_to_ds('valid', aug_tfm, jitter=aug_jitter),
        df_to_ds('test', fix_tfm),
        df_to_ds('test', aug_tfm, jitter=aug_jitter),
    ]

    classes = range(t_max + 1)
    return ImageData(EXP_PATH, datasets, bs, num_workers, classes)


# Number of event types passed to deephit_loss.
num_evt_types = 1


def custom_loss(preds, target, b1=0.50, b2=0.50):
    """Weighted sum of the two terms returned by ``deephit_loss``.

    Args:
        preds: model output over the ``t_max + 1`` time classes.
        target: per-row ``[event_time, event_type]``.
        b1: weight of the ``pairwise_loss`` term.
        b2: weight of the ``l1_loss`` term.

    Returns:
        ``b1 * pairwise_loss + b2 * l1_loss``.
    """
    # (A stray `step += 1` used to raise UnboundLocalError on every call.)
    evt_times = target[:, 0]
    evt_types = target[:, 1]
    l1_loss, pairwise_loss = deephit_loss(preds, evt_times, evt_types, t_max + 1, num_evt_types)
    return b1 * pairwise_loss + b2 * l1_loss


class ConcordanceIndex(Callback):
    """Collect validation predictions and print the concordance index each epoch.

    ``concordance_metric`` is registered as a learner metric so it sees every
    validation batch. It only accumulates data and returns a 0.0 placeholder.
    ``on_epoch_end`` computes the C-index with lifelines and clears the buffers.
    """

    def __init__(self):
        self.reset()

    def on_epoch_begin(self, metrics):
        self.reset()

    def on_epoch_end(self, metrics):
        if not self.preds:
            # lifelines cannot score an empty set; nothing was collected.
            print('ci: n/a (no predictions collected)')
        else:
            ci = concordance_index(
                np.array(self.evt_times),
                np.array(self.preds),
                np.array(self.evt_types),
            )
            print('ci: ', ci, len(self.preds), len(self.evt_times))
        self.reset()

    def reset(self):
        self.preds = []
        self.evt_times = []
        self.evt_types = []
        self.mcount = 0

    @staticmethod
    def _to_numpy(x):
        # Metrics may receive torch tensors (possibly on the GPU); np.argmax and
        # np.array cannot consume CUDA tensors, so move them to host numpy first.
        if hasattr(x, 'detach'):
            x = x.detach()
        if hasattr(x, 'cpu'):
            x = x.cpu().numpy()
        return np.asarray(x)

    def concordance_metric(self, preds, target):
        """Accumulate one batch; ``target`` rows are ``[event_time, event_type]``."""
        preds = self._to_numpy(preds)
        target = self._to_numpy(target)
        self.evt_times.extend(target[:, 0])
        self.evt_types.extend(target[:, 1])
        # Predicted time = index of the highest-scoring time class.
        self.preds.extend(np.argmax(preds, axis=1))
        self.mcount += 1
        return 0.0

In [None]:
# Backbone and input pipeline: 256px tiles, batches of 32, top-down augmentations.
f_model = resnet34
sz, bs = 256, 32
tfms = tfms_from_model(f_model, sz, aug_tfms=transforms_top_down)
md = get_data(sz=sz, bs=bs, tfms=tfms)


In [None]:
#md.trn_ds[0:3]

In [None]:
class SplitCols(nn.Module):
    """Wrapper that feeds the leading tabular columns of each input row to ``cols_model``.

    Only the tabular branch is active: ``forward`` returns ``cols_model`` applied
    to the first ``n_cols`` entries of each row. ``img_model``, ``sz`` and
    ``n_chan`` are stored for the (currently disabled) image branch.

    Args:
        img_model: image model (stored, unused in ``forward``).
        cols_model: module applied to the tabular columns.
        sz: image side length.
        n_chan: number of image channels.
        n_cols: number of leading tabular columns in each input row.
    """

    def __init__(self, img_model, cols_model, sz, n_chan, n_cols):
        super().__init__()
        self.sz = sz
        self.n_cols = n_cols
        self.n_chan = n_chan
        self.img_model = img_model
        self.cols_model = cols_model
        # Explicit dim avoids the implicit-dim deprecation warning (dim=1 is what
        # PyTorch picks for 2-D input). Kept for a future combined head.
        self.final = nn.Sequential(nn.Softmax(dim=1))

    def forward(self, x):
        # Torch tensors have no .astype(); .float() matches cols_model's float32 weights.
        x_cols = x[:, :self.n_cols].float()
        # Use the registered submodule, not the notebook-global `cols_model`.
        return self.cols_model(x_cols)


# Input width of the head. Presumably 2 x 512: AdaptiveConcatPool2d joins avg- and
# max-pooled resnet34 features — confirm if f_model changes.
feat = 1024

layers = [
    AdaptiveConcatPool2d(),
    Flatten(),
    BatchNorm1d(feat),
    Dropout(p=0.5),
    Linear(in_features=feat, out_features=256),
    ReLU(),
    BatchNorm1d(256),
    Dropout(p=0.5),
    Linear(in_features=256, out_features=len(md.classes)),
]
custom_head = Sequential(*layers)

# Tabular-only model: n_cols features -> hidden layer -> distribution over the
# len(md.classes) time classes.
hidden = 256
cols_model = nn.Sequential(
    Linear(in_features=n_cols, out_features=hidden),
    ReLU(),
    BatchNorm1d(hidden),
    Linear(in_features=hidden, out_features=len(md.classes)),
    # Explicit dim=1 (per row) instead of the deprecated implicit-dim Softmax(),
    # which picks dim=1 for 2-D input anyway.
    Softmax(dim=1),
)

In [None]:
m = BasicModel(to_gpu(cols_model), name='colnet')
learn = Learner(md, m)

# Survival loss, plus the C-index collector registered both as a metric and a callback.
cindex = ConcordanceIndex()
learn.crit = custom_loss
learn.metrics = [cindex.concordance_metric]
callbacks = [cindex]


In [None]:
# Sanity check: run 32 training items through the model and show the output shape.
x, y = learn.data.trn_ds[:32]
learn.models.model(T(np.asarray(x, dtype=float))).shape

In [None]:
# LR range test; plot the whole curve without skipping leading points.
learn.lr_find()
learn.sched.plot(n_skip=0)

In [None]:
cindex.reset()
learn.fit(
    lrs=0.01,
    n_cycle=1,
    cycle_len=5,
    use_clr_beta=(40, 20, 0.95, 0.85),
    callbacks=callbacks,
)

In [None]:
# `preds` was never defined in this notebook (NameError). Recompute it as the
# predicted time class per validation row. Assumes fastai 0.7 Learner.predict()
# returns validation-set outputs as an (n_rows, n_classes) array — confirm.
preds = np.argmax(learn.predict(), axis=1)
pd.Series(preds).value_counts()

In [1]:
# Pretrained image learner with the custom head; its model is then wrapped so
# that forward() routes the tabular columns through cols_model.
learn = ConvLearner.pretrained(f_model, md, custom_head=custom_head)

cindex = ConcordanceIndex()
callbacks = [cindex]
learn.crit = custom_loss
learn.metrics = [cindex.concordance_metric]

learn.unfreeze()
# NOTE(review): the optimizer's layer groups presumably still come from the original
# conv model, so cols_model's parameters may not be trained after this swap — confirm.
learn.models.model = to_gpu(
    SplitCols(learn.models.model, cols_model, sz=sz, n_chan=3, n_cols=n_cols)
)

NameError: name 'ConvLearner' is not defined

In [None]:
# Sanity check: a 32-item training batch through the wrapped model.
x, y = learn.data.trn_ds[:32]
learn.models.model(T(x)).size()

In [None]:
x, y = learn.data.trn_ds[:5]  # five training items for a quick forward pass



In [None]:
learn.models.model(T(x)).size()  # output shape for the small batch above

In [None]:
# LR range test for the wrapped model; plot without skipping leading points.
learn.lr_find()
learn.sched.plot(n_skip=0)

In [None]:
cindex.reset()
learn.fit(
    lrs=0.1,
    n_cycle=1,
    cycle_len=10,
    use_clr_beta=(40, 20, 0.95, 0.85),
    callbacks=callbacks,
)