### Import libraries

In [2]:
import os, sys
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [3]:
import pandas as pd
import numpy as np
import networkx as nx

In [4]:
from scipy.stats import wasserstein_distance

In [5]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [6]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [7]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [8]:
from ipywidgets import interact, interact_manual, FloatSlider

In [9]:
# Make the parent directory and its src/ folder importable, so the project
# modules below can be resolved from this notebook.
for relative_path in ('..', '../src/'):
    sys.path.append(os.path.abspath(relative_path))

from src.utils import (
    load_dataset, fetch_data, preprocess, create_dataset,
    DATA_FOLDERS, FILES, standardise_column_names,
)
from src.models import DiffPool, BaselineGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

# Base directory of the local data export. It is string-concatenated with
# DATA_FOLDERS entries when the CSV files are read, so the trailing slash matters.
CONNECTION_DIR = '/Users/adhaene/Downloads/'

### Fetch data

In [16]:
# Load the blood table; folder and file names come from the project-level
# DATA_FOLDERS / FILES lookups.
blood_folder = DATA_FOLDERS[4]
blood = pd.read_csv(os.path.join(CONNECTION_DIR + blood_folder, FILES[blood_folder]['blood']))
# Replace hyphens in column names with underscores.
blood.rename(columns={feature: feature.replace('-', '_') for feature in blood.columns}, inplace=True)
# Listify immunotherapy type to create multi-feature encoding:
# the combined 'ipinivo' label becomes ['ipi', 'nivo'], any other label a one-element list.
blood['immuno_therapy_type'] = blood.immuno_therapy_type \
    .apply(lambda t: ['ipi', 'nivo'] if t == 'ipinivo' else [t])

# Filter in the patient information that we want access to
blood_features = ['sex', 'bmi', 'performance_score_ecog', 'ldh_sang_ul', 'neutro_absolus_gl',
                    'eosini_absolus_gl', 'leucocytes_sang_gl', 'NRAS_MUTATION', 'BRAF_MUTATION',
                    'immuno_therapy_type', 'lympho_absolus_gl', 'concomittant_tvec',
                    'prior_targeted_therapy', 'prior_treatment', 'nivo_maintenance']

# Transform all one-hot encoded (0/1) features into True/False to avoid scaler
for feature in blood_features:
    # value_counts() drops NaN and orders its keys by descending frequency, so
    # the 0/1 test must not depend on which of the two values is more common.
    # A membership test (rather than building a set) also stays safe for
    # columns whose cells are unhashable lists, e.g. immuno_therapy_type.
    values = blood[feature].value_counts().keys()
    if len(values) == 2 and all(value in (0, 1) for value in values):
        # NOTE(review): astype(bool) maps NaN to True -- confirm these flag
        # columns contain no missing values.
        blood[feature] = blood[feature].astype(bool)

# Load the progression table and binarise the label:
# 1 where pseudorecist equals 'NPD', 0 otherwise.
progression_folder = DATA_FOLDERS[1]
progression = pd.read_csv(os.path.join(CONNECTION_DIR + progression_folder,
                                        FILES[progression_folder]['progression']))
progression['pseudorecist'] = progression.pseudorecist.eq('NPD').mul(1)

In [19]:
# Preview the patient identifier, the day offset and the selected features.
blood.loc[:, ['gpcr_id', 'n_days_to_treatment_start'] + blood_features]

Unnamed: 0,gpcr_id,n_days_to_treatment_start,sex,bmi,performance_score_ecog,ldh_sang_ul,neutro_absolus_gl,eosini_absolus_gl,leucocytes_sang_gl,NRAS_MUTATION,BRAF_MUTATION,immuno_therapy_type,lympho_absolus_gl,concomittant_tvec,prior_targeted_therapy,prior_treatment,nivo_maintenance
0,34610001,0.0,female,28.4,1,520.0,10.61,0.00,18.3,,n,"[ipi, nivo]",7.32,False,False,False,True
1,34610001,21.0,female,28.4,0,235.0,11.32,0.16,16.4,,n,"[ipi, nivo]",4.10,False,False,False,True
2,34610001,42.0,female,28.4,0,218.0,10.03,0.30,15.2,,n,"[ipi, nivo]",4.10,False,False,False,True
3,34610001,63.0,female,28.4,0,193.0,8.35,0.00,11.6,,n,"[ipi, nivo]",2.32,False,False,False,True
4,34610001,91.0,female,27.1,1,259.0,8.97,0.00,16.3,,n,[nivo],6.68,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1364,34610095,357.0,female,,0,,,,,n,n,[nivo],,False,False,False,True
1365,34610095,371.0,female,,0,184.0,10.88,0.29,14.7,n,n,[nivo],2.65,False,False,False,True
1366,34610095,385.0,female,,0,182.0,9.59,0.27,13.5,n,n,[nivo],2.70,False,False,False,True
1367,34610095,401.0,female,,0,188.0,,,,n,n,[nivo],,False,False,False,True
