### Notebook for the complete network, dataset creation, and testing. 

In [39]:
#Import packages
import torchio as tio
from torch.utils.data import dataloader
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm
import numpy as np
import torchtuples as tt
from pycox.models import LogisticHazard
from pycox.evaluation import EvalSurv

np.random.seed(1234)
_ = torch.manual_seed(1234)

In [78]:
# Load data into dataframes 

metadata =pd.read_csv("/home/anders/Phd_Interview_project_dataset/metadata.csv")
pat_ct = metadata.loc[(metadata['Modality'] == "CT")]
patients = pat_ct["Subject ID"].to_numpy()
pat_ct.set_index('Subject ID', inplace=True)

data = pd.read_excel("Colorectal-Liver-Metastases-Clinical-data-April-2023.xlsx")
surv_data = data[["Patient-ID", "vital_status", "overall_survival_months"]].copy()
data.drop(columns=["vital_status", "overall_survival_months", "De-identify Scout Name", "months_to_DFS_progression", "vital_status_DFS", "months_to_liver_DFS_progression", "vital_status_liver_DFS", "relevant_notes", "progression_or_recurrence", "progression_or_recurrence_liveronly"], inplace=True)


# Helper functions 

def Get_pat_imgs(df, patientid):
    loc = "/home/anders/Phd_Interview_project_dataset/" + df.loc[patientid]["File Location"][2:] + "/"
    return loc

def get_pat_surv_dat(df, patientid):
    data = torch.asarray(df.loc[df["Patient-ID"] == patientid].iloc[:,1:].values[0])
    return data[0], data[1]

def Get_patient_covariate_data(df, patient_ID):
    return torch.asarray(df.loc[df["Patient-ID"] == patient_ID].iloc[:,1:].values[0]).to(torch.float32)

  warn(msg)


In [79]:
#Partition into train and val set
train_data = data
test_data = data.sample(frac=0.2)
train_data = train_data.drop(test_data.index)

surv_train_targ = surv_data.drop(test_data["Patient-ID"].index)
surv_test_targ = surv_data.drop(train_data["Patient-ID"].index)

#Perform label transform and discretization of the datapoints
labtrans = LogisticHazard.label_transform(20)
target_train = labtrans.fit_transform(surv_train_targ["overall_survival_months"].values, surv_train_targ["vital_status"].values)
target_test = labtrans.fit_transform(surv_test_targ["overall_survival_months"].values, surv_test_targ["vital_status"].values)

In [80]:
#Create a torchio dataset that is also compatible with the Pycox environment
# Dataloader needs to output this:  input, target = data

class FusionSurvDataset():
    def __init__(self, tiodataset, time, event):
        self.tiodataset = tiodataset
        self.time, self.event = tt.tuplefy(time, event).to_tensor()
    
    def __len__(self):
        return len(self.time)
    
    def __getitem__(self, indx):
        data = self.tiodataset[indx]
        img = data.image.data
        covariates = data.covariates
        return tt.tuplefy((img, covariates), (self.time[indx], self.event[indx]))

In [81]:
target_train

(array([16,  4, 13, 13,  3, 14,  4, 10, 12, 15, 12, 13, 11, 18, 16, 16,  4,
         5, 17, 12,  3,  9, 10,  8, 11, 11, 13,  3, 11,  5, 13, 10,  5,  7,
        13,  3,  5,  1, 13,  3, 10, 14,  2,  7,  6, 15, 15,  4,  6, 13, 12,
         9, 17, 10,  9, 17, 11,  5,  8,  2, 14, 15,  5, 14,  5, 10, 17, 13,
         1,  8, 17,  4, 10, 11, 12, 17,  6, 18, 14,  6,  5,  8,  3,  9,  5,
         5, 10, 18, 15, 11, 16, 13,  4,  5, 14,  8,  5,  6, 13,  7, 13, 14,
        12, 15, 14,  3,  2, 11,  9,  3,  8,  8,  8,  5, 16,  7, 11, 17,  8,
         7,  3,  7,  5,  8,  9, 16,  9,  7,  8, 15, 19, 12, 13,  8,  8, 13,
         2, 10, 10,  3, 14,  9,  6,  4,  5, 17,  9,  3, 17,  4, 15,  6,  2,
        12,  4, 14, 15, 15]),
 array([0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1.,
        1., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1.,
        1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0.

In [82]:
Transforms = tio.Compose([
    tio.Resample(1),
    tio.ZNormalization(),
    tio.Resize((128,128,240))
])

def Get_Subjects(Idlist):
    Subject_list = []
    for i,subject in enumerate(Idlist):
        dir = Get_pat_imgs(pat_ct, subject)
        covariates = Get_patient_covariate_data(data, subject)
        event, time = get_pat_surv_dat(surv_data, subject)
        subject = tio.Subject(
        image = tio.ScalarImage(dir),
        covariates = covariates,
        )
        Subject_list.append(subject)
    return Subject_list

#Training split
Train_Subjects = Get_Subjects(surv_train_targ["Patient-ID"].to_list())
Train_tio_dataset = tio.SubjectsDataset(Train_Subjects, transform=Transforms)
Train_Fusion_dataset = FusionSurvDataset(Train_tio_dataset, target_train[0], target_train[1])

#Validation / Testing split 
Test_Subjects = Get_Subjects(surv_test_targ["Patient-ID"].to_list())
Test_tio_dataset = tio.SubjectsDataset(Test_Subjects, transform=Transforms)
Test_Fusion_dataset = FusionSurvDataset(Test_tio_dataset, target_test[0], target_test[1])

In [83]:
#Need the Pycox-collate 
def collate_fn(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()

In [96]:
dl_train = DataLoader(Train_Fusion_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
dl_test = DataLoader(Test_Fusion_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)


In [86]:
batch = next(iter(dl_train))
covdata = batch[0][1]
covdata.dtype

torch.float32

In [100]:
#Defining the network
class FusionSurv(nn.Module):
    
    def __init__(self, clinical_inputs=25, ct_cov=12, clin_cov=12, out_haz=32):
        super().__init__()

        self.CT_net = nn.Sequential(
            # 
            nn.Conv3d(1, 2, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm3d(2),
            nn.Dropout(p=0.2),

            nn.Conv3d(2, 4, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm3d(4),
            nn.Dropout(p=0.2),

            nn.Conv3d(4, 8, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm3d(8),
            nn.Dropout(p=0.2),

            nn.Conv3d(8, 1, kernel_size=2, stride=2),
            nn.BatchNorm3d(1), 
            nn.Flatten(start_dim=2),
            nn.Linear(960, ct_cov),
            nn.ReLU(), # some of the outputs are 0, but this should be fine?
            #nn.Sigmoid(), # create logits # I shouldn't need to apply sigmoid yet, and risk saturating the gradients.
        )

        self.Clin_net = nn.Sequential(
            nn.Linear(clinical_inputs, 32),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(16, clin_cov),
            nn.ReLU(),
        )

        self.Surv_net = nn.Sequential(
            nn.Linear(ct_cov + clin_cov, 64),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(32, out_haz),
            nn.Sigmoid(), # Sigmoid to create logistic hazard outputs.
        )

    
    def forward(self, im, clin):
        imcov   = self.CT_net(im)
        clincov = self.Clin_net(clin.unsqueeze(dim=1))
        return self.Surv_net(torch.cat((imcov, clincov), axis=2)).squeeze(1)
    
    def predict(self, im, clin):
        # Pycox uses predict for the survival functions, but since this is a fusion net there really isn't a part of the network that should work independently.
        return self.forward(im, clin)

# takes input [Batch,1,240,128,128], and [Batch,1,25] --> returns [batch,1,haz_out] where hazard out are the hazards for each discrete step. 


In [101]:
net = FusionSurv(out_haz=labtrans.out_features)

model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts, device="cuda")

In [102]:
callbacks = [tt.cb.EarlyStopping(patience=5)]
epochs = 5
verbose = True
log = model.fit_dataloader(dl_train, epochs, callbacks, verbose, val_dataloader=dl_test)

0:	[3m:30s / 3m:30s],		train_loss: 7.8100,	val_loss: 7.3634
