# Constants


In [1]:

import pandas as pd
import os
import shutil
from os import path as osp

import torch
from tiatoolbox.utils.misc import select_device
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import joblib
import json
import glob

# from src.intensity import add_features_and_create_new_dicts

from src.featureextraction import get_cell_features, add_features_and_create_new_dicts
from src.train import stratified_split, recur_find_ext, run_once, rm_n_mkdir ,reset_logging
from src.graph_construct import create_graph_with_pooled_patch_nodes, get_pids_labels_for_key, create_graph_with_pooled_patch_nodes_with_survival_data


# Compute device selection (CPU unless ON_GPU is flipped).
ON_GPU = False
device = select_device(on_gpu=ON_GPU)

# Seed every random source used by this notebook.
SEED = 5
random.seed(SEED)
# Legacy global NumPy RNG: helpers further down call np.random.randn and
# scikit-learn splitters without an explicit random_state, both of which draw
# from this global state, so it has to be seeded for reproducible runs.
np.random.seed(SEED)
# Independent Generator for code that wants an explicit RNG object.
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


# Root directory of the project data (machine-specific absolute path).
BASEDIR = '/home/amrit/data/proj_data/MLG_project/DLBCL-Morph'


# Stain to process; keep exactly one of the alternatives below active.
STAIN = 'MYC'
# STAIN = 'BCL2'
# STAIN = 'HE'

# Input/output locations derived from BASEDIR and STAIN.
FIDIR = f'{BASEDIR}/outputs'
CLINPATH = f'{BASEDIR}/clinical_data_cleaned.csv'
ANNPATH = f'{BASEDIR}/annotations_clean.csv'
FEATSDIR = f'{BASEDIR}/outputs/files/{STAIN}'
# .npz holding the per-feature mean/var used to normalise cell features.
FEATSCALERPATH = f"{FEATSDIR}/0_feat_scaler.npz"
# NOTE(review): PATCH_SIZE is reassigned to 300 in the graph-preparation cell;
# the value here only feeds OUTPUT_SIZE.
PATCH_SIZE = 224
# Appears in the image file names built in the feature-extraction cell.
OUTPUT_SIZE = PATCH_SIZE*8

WORKSPACE_DIR = Path(BASEDIR)
# GRAPH_DIR = WORKSPACE_DIR / f"graphs{STAIN}" 
# LABELS_PATH = WORKSPACE_DIR / "graphs/0_labels.txt"


# Graph construction
# PATCH_SIZE = 300
# Scale of the random jitter added to features before computing skewness.
SKEW_NOISE = 0.0001
# Patches holding this many cells or fewer are not turned into graph nodes.
MIN_CELLS_PER_PATCH = 10
# Distance threshold handed to the Delaunay adjacency builder.
CONNECTIVITY_DISTANCE = 500

# The second assignment wins: the notebook currently runs in 'OS' mode.
LABEL_TYPE = 'multilabel' #'OS' #
LABEL_TYPE = 'OS' #'OS' #


# Output directory for per-region graph JSON files and the label lookup file.
GRAPHSDIR = Path(f'{BASEDIR}/graphs/{STAIN}')
LABELSPATH = f'{BASEDIR}/graphs/{STAIN}_labels.json'

# Training hyper-parameters; not referenced in the cells visible here.
# NOTE(review): NCLASSES is 3 but the label helpers below are called with
# nclasses=2 — confirm which one the training code relies on.
NUM_EPOCHS = 100
NUM_NODE_FEATURES = 128
NCLASSES = 3

# Where the train/valid/test split is persisted (one file per stain and label type).
TRAIN_DIR = WORKSPACE_DIR / "training"
SPLIT_PATH = TRAIN_DIR / f"splits_{STAIN}_{LABEL_TYPE}.dat"
# RUN_OUTPUT_DIR = TRAIN_DIR / f"session_{STAIN}_{datetime.now().strftime('%m_%d_%H_%M')}"


# Extract features: 20 + 4 minutes

In [None]:
# Read the annotation CSV, keep only usable regions of the selected stain, and
# build the per-region input/output file paths for intensity-feature extraction.
df = pd.read_csv(ANNPATH)

df = df[df['stain'] == STAIN]
# Bounding-box area scaled by 1/10000 (units follow the annotation coordinates — TODO confirm).
df['area'] = (df['xe'] - df['xs']) *  (df['ye'] - df['ys'])/10000
# Discard regions whose scaled area is below 150.
df = df[df['area'] >= 150]  
# Discard annotations with any negative box coordinate.
df = df[df['xs']  >=0 ]
df = df[df['ys']  >=0 ]
df = df[df['xe']  >=0 ]
df = df[df['ye']  >=0 ]

# reset_index() keeps the original CSV row label in a new 'index' column,
# which is used below as the region identifier in file names.
df = df.reset_index()
##########

###############
# add intensity features
start_index = 0
end_index = len(df.index)

# Parallel lists, one entry per annotated region:
datpaths = []    # existing per-region "0.dat" files
imgpaths = []    # region images
updatpaths = []  # destination .dat files named after the region's unique id

# for index in range(start_index, 1):
for index in range(start_index, end_index):
    df_index = df['index'][index]
    patient_id = df['patient_id'][index]
    stain = df['stain'][index]
    tma_id = df['tma_id'][index]
    # <patient>_<stain>_<original row label>; the graph files later reuse this stem.
    unique_id = str(patient_id) + '_' + stain + '_' + str(df_index)

    img_file_name = f"{FIDIR}/images/{patient_id}/{patient_id}_{stain}_{tma_id}_{OUTPUT_SIZE}_{df_index}.png"
    dat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/0.dat"
    updat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/{unique_id}.dat"

    datpaths.append(dat_file_name)
    imgpaths.append(img_file_name)
    updatpaths.append(updat_file_name)
    


In [None]:
updatpaths[:3]

In [None]:
# Runtime: roughly 16 minutes.
# Optional clean start (disabled): wipe and recreate the feature directory.
# shutil.rmtree(FEATSDIR)
# if not osp.exists(FEATSDIR):
#     os.makedirs(FEATSDIR)
# Presumably reads each 0.dat + image pair and writes the enriched cell
# dictionary to the matching entry of updatpaths — the helper lives in
# src.featureextraction; verify there.
add_features_and_create_new_dicts(datpaths, imgpaths, updatpaths)

In [None]:
# Runtime: roughly 4 minutes.
# Incrementally fit a StandardScaler over the intensity features of every cell
# in every region, then persist the resulting mean/variance for later use.
gns = StandardScaler()
for featpath in tqdm(updatpaths):
    try:
        # NOTE: joblib.load unpickles the file; only load files produced by this pipeline.
        celldatadict = joblib.load(featpath)
        cellsfeats = np.array([v['intensity_feats'] for v in celldatadict.values()])
        gns.partial_fit(cellsfeats)
    except Exception as exc:
        # Best effort: report the file and the reason, then keep going.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit stop the loop.
        print(featpath, repr(exc))

np.savez(FEATSCALERPATH, mean=gns.mean_, var=gns.var_)

# dd = np.load(FEATSCALERPATH)
# print(dd['mean'], gns.mean_)
# print(dd['var'], gns.var_)

In [None]:
FEATSCALERPATH

# Prepare graphs : 4 minutes

In [163]:
from src.graph_construct import get_pids_multilabels_for_key

# Load the per-patient clinical table that supplies the labels.
df = pd.read_csv(CLINPATH)
# Missing values become 0 in every column.
# NOTE(review): this also turns a missing 'OS' or IHC value into 0, which is
# then binned like a real measurement — confirm this is intended.
df =df.fillna(0)
# display(df.info())


import numpy as np
import joblib
import pandas as pd
from tqdm import tqdm
import glob
import os
import shutil
from os import path as osp
import json
import cv2
from scipy.stats import skew
import colorsys
import random
from typing import Dict, List, Tuple, Union

from tiatoolbox.tools.graph import delaunay_adjacency, affinity_to_edge_index
from matplotlib import pyplot as plt


SKEW_NOISE  = 0.0001


def get_pids_labels_for_key(orig_df, key='OS', nclasses=3, idkey='patient_id'):
    """Bin a numeric column into quantile classes, one row per patient.

    Args:
        orig_df: DataFrame containing `idkey`, `key`, 'Follow-up Status'
            and 'OS' columns. It is not modified.
        key: Name of the numeric column to discretise.
        nclasses: Number of quantile bins.
        idkey: Name of the patient identifier column.

    Returns:
        A new DataFrame with `idkey`, `key`, `<key>_class`,
        'Follow-up Status' and 'OS'. A row gets class ``i`` when its value is
        strictly greater than the ``i / nclasses`` quantile; rows equal to the
        column minimum exceed no quantile and keep the initial class ``-1``.
    """
    ckey = key + '_class'
    # Explicit copy: the column subset would otherwise be a slice of `orig_df`,
    # and the writes below would trigger SettingWithCopy problems.
    df = orig_df[[idkey, key]].copy()

    df[ckey] = int(-1)

    # Lower quantile edge of each bin: 0, 1/n, ..., (n-1)/n.
    separators = np.linspace(0, 1, nclasses + 1)[0:-1]
    for isep, sep in enumerate(separators):
        sepval = df[key].quantile(sep)
        # Later (higher) separators overwrite earlier ones.
        df.loc[df[key] > sepval, ckey] = isep

    # Carry the survival event/time columns along for downstream use.
    df['Follow-up Status'] = orig_df['Follow-up Status']
    df['OS'] = orig_df['OS']
    return df


def get_pids_multilabels_for_key(orig_df, key_list, nclasses=3, idkey='patient_id'):
    """Derive one `<key>_class` column per requested key.

    Class rules per key:
      * 'OS': quantile binning into `nclasses` bins. A row gets class ``i``
        when its value is strictly greater than the ``i / nclasses`` quantile;
        the initial class is 0.
      * 'MYC IHC', 'BCL2 IHC', 'BCL6 IHC': 1 when the value is >= 40, else 0.
      * any other key: the column is copied unchanged.

    Args:
        orig_df: DataFrame containing `idkey`, every key in `key_list`, and
            an 'OS' column. It is not modified.
        key_list: List of column names to turn into labels.
        nclasses: Number of quantile bins used for 'OS'.
        idkey: Name of the patient identifier column.

    Returns:
        A new DataFrame with `idkey`, the requested keys, their `_class`
        columns, and 'OS'.
    """
    # Explicit copy so the writes below never target a slice of `orig_df`.
    df = orig_df[[idkey] + key_list].copy()

    for key in key_list:
        ckey = key + '_class'
        if key == "OS":
            df[ckey] = int(0)
            separators = np.linspace(0, 1, nclasses + 1)[0:-1]
            for isep, sep in enumerate(separators):
                sepval = df[key].quantile(sep)
                df.loc[df[key] > sepval, ckey] = isep
        elif key in ['MYC IHC', 'BCL2 IHC', 'BCL6 IHC']:
            # Single vectorised assignment instead of chained indexing
            # (`df[ckey][mask] = ...`), which is a silent no-op under
            # pandas copy-on-write.
            df[ckey] = (df[key] >= 40).astype(int)
        else:
            df[ckey] = df[key]
    df['OS'] = orig_df['OS']
    return df

# Build the per-patient label table for the active label mode.
if LABEL_TYPE == 'multilabel':
    df_labels = get_pids_multilabels_for_key(df, key_list=['OS','MYC IHC', 'BCL2 IHC', 'BCL6 IHC', 'CD10 IHC', 'MUM1 IHC', 'Follow-up Status'], nclasses=2)
else:
    df_labels = get_pids_labels_for_key(df, key ='OS', nclasses=2)
    # survival_event_data = df[['Follow-up Status']].to_numpy().tolist()
    # survival_time_data = df[['OS']].to_numpy().tolist()


len(df_labels)

# Graph construction
# NOTE: these override the values set in the constants cell (PATCH_SIZE was 224 there).
PATCH_SIZE = 300
SKEW_NOISE = 0.0001
MIN_CELLS_PER_PATCH = 10
CONNECTIVITY_DISTANCE = 500

# # save paths
# featpaths = np.sort(glob.glob(f'{FEATSDIR}/**/*.dat', recursive=True)) #np.sort(glob.glob(f'{FEATSDIR}/*.dat'))
# featpaths = [x if "/0.dat" not in x for x in featpaths]
# Collect every enriched feature file, skipping the raw "0.dat" inputs and "file_map.dat".
featpaths = np.sort(glob.glob(f'{FEATSDIR}/**/*.dat', recursive=True)) #np.sort(glob.glob(f'{FEATSDIR}/*.dat'))
featpaths = [x for x in featpaths if ("/0.dat" not in x) and ("/file_map.dat" not in x)]
featpaths
# File names start with "<patient_id>_", so the id is the first '_'-separated token.
pids = [int(osp.basename(featpath).split('_')[0]) for featpath in featpaths]
df_featpaths = pd.DataFrame(zip(pids, featpaths), columns=['patient_id', 'featpath'])

# merge to find datapoints with graph data and labels
# (default inner join: regions without a labelled patient are dropped)
df_data = df_featpaths.merge(df_labels, on='patient_id')
# df_data = df_data[:12]

featpaths_data = df_data['featpath'].to_list()
# labels_data = df_data['OS_class'].to_list()


# Per-region label, survival event and survival time lists, aligned with featpaths_data.
# The survival lists come from single-column frames, so each entry is a one-element list.
if LABEL_TYPE == 'multilabel':
    labels_data = df_data[['OS_class','MYC IHC_class', 'BCL2 IHC_class', 'BCL6 IHC_class', 'CD10 IHC_class', 'MUM1 IHC_class']].to_numpy().tolist()
    survival_event_data = df_data[['Follow-up Status_class']].to_numpy().tolist()
    survival_time_data = df_data[['OS']].to_numpy().tolist()

else:
    labels_data = df_data['OS_class'].to_list()
    survival_event_data = df_data[['Follow-up Status']].to_numpy().tolist()
    survival_time_data = df_data[['OS']].to_numpy().tolist()


display(labels_data[:5])

[0, 0, 1, 1, 1]

In [164]:
df_data.shape, len(survival_time_data )

((319, 5), 319)

In [165]:

import numpy as np
import joblib
import pandas as pd
from tqdm import tqdm
import glob
import os
import shutil
from os import path as osp
import json
import cv2
from scipy.stats import skew
import colorsys
import random
from typing import Dict, List, Tuple, Union

from tiatoolbox.tools.graph import delaunay_adjacency, affinity_to_edge_index
from matplotlib import pyplot as plt


SKEW_NOISE  = 0.0001


def get_pids_multilabels_for_key(df, key_list, nclasses=3, idkey='patient_id'):
    """Quantile-bin every column in `key_list` into `nclasses` classes.

    NOTE: this redefinition replaces the earlier function of the same name in
    this notebook; unlike that one it bins *every* key by quantile and has no
    special handling for the IHC columns.

    Args:
        df: DataFrame containing `idkey` and every key in `key_list`.
            It is not modified.
        key_list: List of numeric column names to discretise.
        nclasses: Number of quantile bins.
        idkey: Name of the patient identifier column.

    Returns:
        A new DataFrame with `idkey`, the requested keys and one
        `<key>_class` column per key. A row gets class ``i`` when its value is
        strictly greater than the ``i / nclasses`` quantile; the initial
        class is 0.
    """
    # Explicit copy so the writes below never target a slice of the caller's frame.
    df = df[[idkey] + key_list].copy()

    # Lower quantile edge of each bin; identical for every key, so computed once.
    separators = np.linspace(0, 1, nclasses + 1)[0:-1]
    for key in key_list:
        ckey = key + '_class'
        df[ckey] = int(0)
        for isep, sep in enumerate(separators):
            sepval = df[key].quantile(sep)
            df.loc[df[key] > sepval, ckey] = isep
    return df


def get_overall_statistics(features):
    """Pool a feature array into one flat list of summary statistics.

    All statistics are taken along axis 0 and concatenated in this order:
    mean, standard deviation, variance, skewness, then the 10th, 25th, 75th
    and 90th percentiles.

    Args:
        features: NumPy array; each statistic is reduced over its first axis.

    Returns:
        A flat Python list holding the concatenated statistics.
    """
    # Skewness is computed on a jittered copy: Gaussian noise scaled by
    # SKEW_NOISE is added first (presumably to avoid degenerate skew on
    # constant columns — TODO confirm). Uses the global NumPy RNG.
    jittered = features + np.random.randn(*features.shape)*SKEW_NOISE

    pooled = []
    pooled.extend(np.mean(features, axis=0).tolist())
    pooled.extend(np.std(features, axis=0).tolist())
    pooled.extend(np.var(features, axis=0).tolist())
    pooled.extend(skew(jittered, axis=0).tolist())
    for pct in (10, 25, 75, 90):
        pooled.extend(np.percentile(features, pct, axis=0).tolist())
    return pooled

# Running record of how many cells each kept patch contained.
global_patch_stats = []

def get_patch_pooled_positions_features(celldatadict, patch_size, cell_feat_norm_stats, MIN_CELLS_PER_PATCH):
    """Group cells into square patches and pool their features per patch.

    Args:
        celldatadict: Mapping of cell id to a dict providing 'centroid' and
            'intensity_feats'.
        patch_size: Edge length of a patch, in the same units as the centroids.
        cell_feat_norm_stats: `(mean, var)` pair used to standardise the
            cell features before pooling.
        MIN_CELLS_PER_PATCH: Patches with this many cells or fewer are skipped.

    Returns:
        Tuple `(positions, features)` of NumPy arrays with one entry per kept
        patch: the mean centroid of its cells and its pooled feature vector.
    """
    global global_patch_stats

    feat_mean, feat_var = cell_feat_norm_stats
    mean_row = np.expand_dims(feat_mean, axis=0)
    std_row = np.expand_dims(np.sqrt(feat_var), axis=0)

    cell_keys = list(celldatadict.keys())
    centroids = np.array([entry['centroid'] for entry in celldatadict.values()])
    # Number of patches needed along each axis to cover the largest centroid.
    grid_shape = (centroids.max(axis=0) // patch_size + 1).astype(int).tolist()

    # grid[a][b] lists the ids of the cells that fall into patch (a, b).
    grid = [[[] for _ in range(grid_shape[1])] for _ in range(grid_shape[0])]
    for cell_key, centroid in zip(cell_keys, centroids):
        cell_bin = (centroid // patch_size).astype(int).tolist()
        grid[cell_bin[0]][cell_bin[1]].append(cell_key)

    pooled_positions = []
    pooled_features = []

    # Walk the grid row by row so the output order is deterministic.
    for grid_row in grid:
        for members in grid_row:
            # Sparse patches do not become graph nodes.
            if len(members) <= MIN_CELLS_PER_PATCH:
                continue

            global_patch_stats.append(len(members))

            member_centroids = np.array([celldatadict[cid]['centroid'] for cid in members])
            member_feats_raw = np.array([celldatadict[cid]['intensity_feats'] for cid in members])
            # Standardise with the dataset-wide statistics.
            member_feats = (member_feats_raw - mean_row) / std_row

            pooled_positions.append(np.mean(member_centroids, axis=0))
            pooled_features.append(np.array(get_overall_statistics(member_feats)))

    return np.array(pooled_positions), np.array(pooled_features)

def simple_delaunay(point_centroids, feature_centroids, connectivity_distance=4000):
    """Build a graph dictionary from node positions via Delaunay adjacency.

    Args:
        point_centroids: Node positions passed to `delaunay_adjacency`.
        feature_centroids: Per-node features, stored unchanged under "x".
        connectivity_distance: Distance threshold passed as `dthresh`.

    Returns:
        Dict with "x" (features), "edge_index" (int64 edge index array) and
        "coordinates" (the input positions).
    """
    adjacency = delaunay_adjacency(points=point_centroids, dthresh=connectivity_distance)
    edges = affinity_to_edge_index(adjacency)

    graph = {}
    graph["x"] = feature_centroids
    graph["edge_index"] = edges.astype(np.int64)
    graph["coordinates"] = point_centroids
    return graph

def create_graph_with_pooled_patch_nodes(featpaths, labels, outgraphpaths, patch_size, cell_feat_norm_stats , MIN_CELLS_PER_PATCH, CONNECTIVITY_DISTANCE):
    """Turn every feature file into a patch-level graph saved as JSON.

    The three list arguments are index-aligned: entry `i` of `featpaths` is
    loaded, pooled into patch nodes, connected via `simple_delaunay`, and
    written to `outgraphpaths[i]` with `labels[i]` stored under "y".
    Files whose graph cannot be built are reported and skipped. Work is
    distributed over 4 joblib workers.

    Args:
        featpaths: Paths of joblib-serialised cell dictionaries.
        labels: One JSON-serialisable label per feature file.
        outgraphpaths: Destination JSON path per feature file.
        patch_size: Patch edge length used for pooling.
        cell_feat_norm_stats: `(mean, var)` pair for feature standardisation.
        MIN_CELLS_PER_PATCH: Patches with this many cells or fewer are dropped.
        CONNECTIVITY_DISTANCE: Distance threshold for graph edges.
    """

    def _build_one(position):
        source_path = featpaths[position]
        target_path = outgraphpaths[position]
        graph_label = labels[position]

        cells = joblib.load(source_path)
        positions, features = get_patch_pooled_positions_features(cells, patch_size, cell_feat_norm_stats,MIN_CELLS_PER_PATCH)

        # Graph construction can fail when too few patches survive pooling.
        try:
            graph = simple_delaunay(
                positions[:, :2],
                features,
                connectivity_distance=CONNECTIVITY_DISTANCE,
            )
        except Exception as e:
            print('Skipping', source_path, 'due to', e)
            return

        # Serialise the arrays as nested lists and attach the label.
        with open(target_path, 'w+') as handle:
            payload = {name: value.tolist() for name, value in graph.items()}
            payload['y'] = graph_label
            json.dump(payload, handle)

    jobs = (
        joblib.delayed(_build_one)(fidx)
        for fidx in tqdm(range(len(featpaths)), disable=False)
    )
    joblib.Parallel(4)(jobs)
    # for fidx in tqdm(range(len(featpaths)), disable=False):
    #     _build_one(fidx)


def create_graph_with_pooled_patch_nodes_with_survival_data(featpaths, labels, survival_events, survival_time, outgraphpaths, patch_size, cell_feat_norm_stats , MIN_CELLS_PER_PATCH, CONNECTIVITY_DISTANCE):
    """Turn every feature file into a patch-level graph JSON with survival data.

    All list arguments are index-aligned. For entry `i`, the feature file is
    loaded, pooled into patch nodes, connected via `simple_delaunay`, and
    written to `outgraphpaths[i]` with "y", "surv_event" and "surv_time"
    added. Files whose graph cannot be built are reported and skipped.
    Work is distributed over 4 joblib workers.

    Args:
        featpaths: Paths of joblib-serialised cell dictionaries.
        labels: One JSON-serialisable label per feature file.
        survival_events: Survival event entry per feature file.
        survival_time: Survival time entry per feature file.
        outgraphpaths: Destination JSON path per feature file.
        patch_size: Patch edge length used for pooling.
        cell_feat_norm_stats: `(mean, var)` pair for feature standardisation.
        MIN_CELLS_PER_PATCH: Patches with this many cells or fewer are dropped.
        CONNECTIVITY_DISTANCE: Distance threshold for graph edges.
    """

    def _build_one(position):
        source_path = featpaths[position]
        graph_label = labels[position]
        event = survival_events[position]
        duration = survival_time[position]
        target_path = outgraphpaths[position]

        cells = joblib.load(source_path)
        positions, features = get_patch_pooled_positions_features(cells, patch_size, cell_feat_norm_stats,MIN_CELLS_PER_PATCH)

        # Graph construction can fail when too few patches survive pooling.
        try:
            graph = simple_delaunay(
                positions[:, :2],
                features,
                connectivity_distance=CONNECTIVITY_DISTANCE,
            )
        except Exception as e:
            print('Skipping', source_path, 'due to', e)
            return

        # Serialise the arrays as nested lists and attach label + survival fields.
        with open(target_path, 'w+') as handle:
            payload = {name: value.tolist() for name, value in graph.items()}
            payload['y'] = graph_label
            payload['surv_event'] = event
            payload['surv_time'] = duration
            json.dump(payload, handle)

    jobs = (
        joblib.delayed(_build_one)(fidx)
        for fidx in tqdm(range(len(featpaths)), disable=False)
    )
    joblib.Parallel(4)(jobs)
    # for fidx in tqdm(range(len(featpaths)), disable=False):
    #     _build_one(fidx)


In [166]:
# One output JSON per feature file, named after the feature file's stem.
outgraphpaths_data = [f"{GRAPHSDIR}/{osp.basename(featpath).split('.')[0]}.json" for featpath in featpaths_data]

# save labels (graph file name -> label); make sure the target folder exists first
os.makedirs(osp.dirname(LABELSPATH), exist_ok=True)
labels_dict = {osp.basename(graphpath): label for graphpath, label in zip(outgraphpaths_data, labels_data)}
with open(LABELSPATH, 'w') as f:
    json.dump(labels_dict, f)

# read normalizer stats from file and pass to fn
dd = np.load(FEATSCALERPATH)
cell_feat_norm_stats = (dd['mean'], dd['var'])

# create final graphs data
# Start from an empty graph directory. Only remove it when it exists, so the
# very first run does not fail with FileNotFoundError.
if osp.exists(GRAPHSDIR):
    shutil.rmtree(GRAPHSDIR)
os.makedirs(GRAPHSDIR)
create_graph_with_pooled_patch_nodes_with_survival_data(
    featpaths_data,
    labels_data,
    survival_event_data,
    survival_time_data,
    outgraphpaths_data,
    PATCH_SIZE,
    cell_feat_norm_stats=cell_feat_norm_stats,
    MIN_CELLS_PER_PATCH= MIN_CELLS_PER_PATCH,
    CONNECTIVITY_DISTANCE = CONNECTIVITY_DISTANCE
)
# wont work in parallel mode
# print(np.mean(global_patch_stats), np.std(global_patch_stats))

  0%|          | 0/319 [00:00<?, ?it/s]

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()



|2023-12-11|21:44:03.616| [INFO] Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
|2023-12-11|21:44:03.683| [INFO] Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
|2023-12-11|21:44:03.696| [INFO] Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
|2023-12-11|21:44:03.840| [INFO] Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


100%|██████████| 319/319 [01:19<00:00,  4.03it/s]


In [None]:
# Variant of the graph-building cell that stores labels only (no survival fields).
# NOTE: it wipes GRAPHSDIR, so running it replaces any graphs written with survival data.
outgraphpaths_data = [f"{GRAPHSDIR}/{osp.basename(featpath).split('.')[0]}.json" for featpath in featpaths_data]

# save labels (graph file name -> label); make sure the target folder exists first
os.makedirs(osp.dirname(LABELSPATH), exist_ok=True)
labels_dict = {osp.basename(graphpath): label for graphpath, label in zip(outgraphpaths_data, labels_data)}
with open(LABELSPATH, 'w') as f:
    json.dump(labels_dict, f)

# read normalizer stats from file and pass to fn
dd = np.load(FEATSCALERPATH)
cell_feat_norm_stats = (dd['mean'], dd['var'])

# create final graphs data
# Start from an empty graph directory. Only remove it when it exists, so the
# very first run does not fail with FileNotFoundError.
if osp.exists(GRAPHSDIR):
    shutil.rmtree(GRAPHSDIR)
os.makedirs(GRAPHSDIR)
create_graph_with_pooled_patch_nodes(
    featpaths_data,
    labels_data,
    outgraphpaths_data,
    PATCH_SIZE,
    cell_feat_norm_stats=cell_feat_norm_stats,
    MIN_CELLS_PER_PATCH= MIN_CELLS_PER_PATCH,
    CONNECTIVITY_DISTANCE = CONNECTIVITY_DISTANCE
)
# wont work in parallel mode
# print(np.mean(global_patch_stats), np.std(global_patch_stats))

# Create data splits

In [167]:
# Gather every graph JSON under GRAPHSDIR; the file stem is the sample id used
# as the key (plus ".json") into the saved label dictionary.
wsi_paths = recur_find_ext(GRAPHSDIR, [".json"])
wsi_names = [Path(v).stem for v in wsi_paths]
# Fail fast if graph construction produced nothing.
assert len(wsi_paths) > 0, "No files found."  # noqa: S101

# Quick sanity display: counts and the first few sample ids.
len(wsi_paths) , len(wsi_names) , wsi_names[:5]

(319,
 319,
 ['13901_MYC_532',
  '13901_MYC_557',
  '13902_MYC_513',
  '13902_MYC_514',
  '13903_MYC_535'])

In [168]:

from __future__ import annotations

from typing import Callable
import ujson as json
import numpy as np
import pathlib
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold

import torch
from torch.utils.data import Sampler
from torch_geometric.data import Data, Dataset


class SlideGraphDataset(Dataset):
    """Dataset that reads one graph JSON file from disk per sample.

    Args:
        info_list (list): When `mode` contains "train" or "valid", a list of
            `[uid, label]` pairs; otherwise a list of `uid`. The file
            `graph_dir / f"{uid}.json"` is loaded for each sample and must
            hold the graph structure (the format produced via
            `tiatoolbox.tools.graph`).
        mode (str): Data mode the `info_list` is in.
        graph_dir (pathlib.Path): Directory holding the graph JSON files.
        preproc (callable): Optional preprocessing function applied to the
            node features stored under "x".

    """

    def __init__(
        self: Dataset,
        info_list: list,
        mode: str = "train",
        graph_dir: pathlib.Path = None,
        preproc: Callable | None = None,
    ) -> None:
        """Store the sample list and loading options."""
        self.info_list = info_list
        self.mode = mode
        self.graph_dir = graph_dir
        self.preproc = preproc

    def __getitem__(self: Dataset, idx: int) -> Dataset:
        """Load sample `idx` and return it as a dict with "graph" (and "label")."""
        entry = self.info_list[idx]
        # Labelled modes carry (uid, label) pairs; any other mode carries bare uids.
        labelled = any(tag in self.mode for tag in ["train", "valid"])
        if labelled:
            wsi_code, label = entry
            # torch.Tensor will create 1-d vector not scalar
            label = torch.tensor(label)
        else:
            wsi_code = entry

        json_path = self.graph_dir / f"{wsi_code}.json"
        with json_path.open() as fptr:
            raw = json.load(fptr)

        # Convert every stored field to an array so preprocessing sees NumPy data.
        arrays = {}
        for field, value in raw.items():
            arrays[field] = np.array(value)

        if self.preproc is not None:
            arrays["x"] = self.preproc(arrays["x"])

        tensors = {field: torch.tensor(value) for field, value in arrays.items()}
        graph = Data(**tensors)

        sample = {"graph": graph}
        if labelled:
            sample["label"] = label
        return sample

    def __len__(self: Dataset) -> int:
        """Number of samples in `info_list`."""
        return len(self.info_list)

    def len(self):
        """Alias of `__len__`."""
        return self.__len__()

    def get(self, idx):
        """Alias of `__getitem__`."""
        return self.__getitem__(idx)


# def stratified_split(
#     x: list,
#     y: list,
#     train: float,
#     valid: float,
#     test: float,
#     num_folds: int,
#     seed: int = 5,
# ) -> list:
#     """Helper to generate stratified splits.

#     Split `x` and `y` in to N number of `num_folds` sets
#     of `train`, `valid`, and `test` set in stratified manner.
#     `train`, `valid`, and `test` are guaranteed to be mutually
#     exclusive.

#     Args:
#         x (list, np.ndarray):
#             List of samples.
#         y (list, np.ndarray):
#             List of labels, each value is the value
#             of the sample at the same index in `x`.
#         train (float):
#             Percentage to be used for training set.
#         valid (float):
#             Percentage to be used for validation set.
#         test (float):
#             Percentage to be used for testing set.
#         num_folds (int):
#             Number of split generated.
#         seed (int):
#             Random seed. Default=5.

#     Returns:
#         A list of splits where each is a dictionary of
#         {
#             'train': [(sample_A, label_A), (sample_B, label_B), ...],
#             'valid': [(sample_C, label_C), (sample_D, label_D), ...],
#             'test' : [(sample_E, label_E), (sample_E, label_E), ...],
#         }

#     """
#     assert (  # noqa: S101
#         train + valid + test - 1.0 < 1.0e-10  # noqa: PLR2004
#     ), "Ratios must sum to 1.0 ."

#     outer_splitter = StratifiedShuffleSplit(
#         n_splits=num_folds,
#         train_size=train + valid,
#         random_state=seed,
#     )
#     inner_splitter = StratifiedShuffleSplit(
#         n_splits=1,
#         train_size=train / (train + valid),
#         random_state=seed,
#     )

#     x = np.array(x)
#     y = np.array(y)
#     splits = []
#     for train_valid_idx, test_idx in outer_splitter.split(x, y):
#         test_x = x[test_idx]
#         test_y = y[test_idx]

#         # Holder for train_valid set
#         x_ = x[train_valid_idx]
#         y_ = y[train_valid_idx]

#         # Split train_valid into train and valid set
#         train_idx, valid_idx = next(iter(inner_splitter.split(x_, y_)))
#         valid_x = x_[valid_idx]
#         valid_y = y_[valid_idx]

#         train_x = x_[train_idx]
#         train_y = y_[train_idx]

#         # Integrity check
#         assert len(set(train_x).intersection(set(valid_x))) == 0  # noqa: S101
#         assert len(set(valid_x).intersection(set(test_x))) == 0  # noqa: S101
#         assert len(set(train_x).intersection(set(test_x))) == 0  # noqa: S101

#         splits.append(
#             {
#                 "train": list(zip(train_x, train_y)),
#                 "valid": list(zip(valid_x, valid_y)),
#                 "test": list(zip(test_x, test_y)),
#             },
#         )
#     return splits



def stratified_split(
    x: list,
    y: list,
    train: float,
    valid: float,
    test: float,
    num_folds: int,
    seed: int = 5,
) -> list:
    """Helper to generate stratified splits.

    Split `x` and `y` in to N number of `num_folds` sets
    of `train`, `valid`, and `test` set in stratified manner.
    `train`, `valid`, and `test` are guaranteed to be mutually
    exclusive.

    Args:
        x (list, np.ndarray):
            List of samples.
        y (list, np.ndarray):
            List of labels, each value is the value
            of the sample at the same index in `x`.
        train (float):
            Percentage to be used for training set.
        valid (float):
            Percentage to be used for validation set.
        test (float):
            Percentage to be used for testing set.
        num_folds (int):
            Number of split generated.
        seed (int):
            Random seed. Default=5.

    Returns:
        A list of splits where each is a dictionary of
        {
            'train': [(sample_A, label_A), (sample_B, label_B), ...],
            'valid': [(sample_C, label_C), (sample_D, label_D), ...],
            'test' : [(sample_E, label_E), (sample_E, label_E), ...],
        }

    Raises:
        AssertionError: If the three ratios do not sum to 1.0, or if the
            generated subsets overlap.

    """
    # abs(): without it any ratios summing to *less* than 1.0 pass the check.
    assert (  # noqa: S101
        abs(train + valid + test - 1.0) < 1.0e-10  # noqa: PLR2004
    ), "Ratios must sum to 1.0 ."

    # Outer split separates test from train+valid; the inner split then
    # divides train+valid in the requested proportion.
    outer_splitter = StratifiedShuffleSplit(
        n_splits=num_folds,
        train_size=train + valid,
        random_state=seed,
    )
    inner_splitter = StratifiedShuffleSplit(
        n_splits=1,
        train_size=train / (train + valid),
        random_state=seed,
    )

    x = np.array(x)
    y = np.array(y)

    splits = []
    for train_valid_idx, test_idx in outer_splitter.split(x, y):
        test_x = x[test_idx]
        test_y = y[test_idx]

        # Holder for train_valid set
        x_ = x[train_valid_idx]
        y_ = y[train_valid_idx]

        # Split train_valid into train and valid set
        train_idx, valid_idx = next(iter(inner_splitter.split(x_, y_)))
        valid_x = x_[valid_idx]
        valid_y = y_[valid_idx]

        train_x = x_[train_idx]
        train_y = y_[train_idx]

        # Integrity check
        assert len(set(train_x).intersection(set(valid_x))) == 0  # noqa: S101
        assert len(set(valid_x).intersection(set(test_x))) == 0  # noqa: S101
        assert len(set(train_x).intersection(set(test_x))) == 0  # noqa: S101

        splits.append(
            {
                "train": list(zip(train_x, train_y)),
                "valid": list(zip(valid_x, valid_y)),
                "test": list(zip(test_x, test_y)),
            },
        )
    return splits


def multilabel_stratified_split(
    x: list,
    y: list,
    train: float,
    valid: float,
    test: float,
    num_folds: int,
    seed: int = 5,
) -> list:
    """Helper to generate stratified splits for multi-label targets.

    Split `x` and `y` in to N number of `num_folds` sets
    of `train`, `valid`, and `test` set. Stratification uses only the
    first label column (`y[:, 0]`); the full label rows are returned.
    `train`, `valid`, and `test` are guaranteed to be mutually
    exclusive.

    Args:
        x (list, np.ndarray):
            List of samples.
        y (list, np.ndarray):
            2-D list of labels, each row belongs to
            the sample at the same index in `x`.
        train (float):
            Percentage to be used for training set.
        valid (float):
            Percentage to be used for validation set.
        test (float):
            Percentage to be used for testing set.
        num_folds (int):
            Number of split generated.
        seed (int):
            Random seed. Default=5.

    Returns:
        A list of splits where each is a dictionary of
        {
            'train': [(sample_A, label_A), (sample_B, label_B), ...],
            'valid': [(sample_C, label_C), (sample_D, label_D), ...],
            'test' : [(sample_E, label_E), (sample_E, label_E), ...],
        }

    Raises:
        AssertionError: If the three ratios do not sum to 1.0, or if the
            generated subsets overlap.

    """
    # abs(): without it any ratios summing to *less* than 1.0 pass the check.
    assert (  # noqa: S101
        abs(train + valid + test - 1.0) < 1.0e-10  # noqa: PLR2004
    ), "Ratios must sum to 1.0 ."

    # Outer split separates test from train+valid; the inner split then
    # divides train+valid in the requested proportion.
    outer_splitter = StratifiedShuffleSplit(
        n_splits=num_folds,
        train_size=train + valid,
        random_state=seed,
    )
    inner_splitter = StratifiedShuffleSplit(
        n_splits=1,
        train_size=train / (train + valid),
        random_state=seed,
    )

    x = np.array(x)
    y = np.array(y)
    splits = []
    # Stratify on the first label column only.
    for train_valid_idx, test_idx in outer_splitter.split(x, y[:,0]):
        test_x = x[test_idx]
        test_y = y[test_idx]

        # Holder for train_valid set
        x_ = x[train_valid_idx]
        y_ = y[train_valid_idx]

        # Split train_valid into train and valid set
        train_idx, valid_idx = next(iter(inner_splitter.split(x_, y_[:,0])))
        valid_x = x_[valid_idx]
        valid_y = y_[valid_idx]

        train_x = x_[train_idx]
        train_y = y_[train_idx]

        # Integrity check
        assert len(set(train_x).intersection(set(valid_x))) == 0  # noqa: S101
        assert len(set(valid_x).intersection(set(test_x))) == 0  # noqa: S101
        assert len(set(train_x).intersection(set(test_x))) == 0  # noqa: S101

        splits.append(
            {
                "train": list(zip(train_x, train_y)),
                "valid": list(zip(valid_x, valid_y)),
                "test": list(zip(test_x, test_y)),
            },
        )
    return splits


class StratifiedSampler(Sampler):
    """Batch sampler whose batches are the test folds of a stratified K-fold.

    Args:
        labels (list): List of labels, must be in the same ordering as input
            samples provided to the `SlideGraphDataset` object.
        batch_size (int): Size of the batch.

    Returns:
        List of indices to query from the `SlideGraphDataset` object.

    """

    def __init__(self: Sampler, labels: list, batch_size: int = 10) -> None:
        """Record the labels and derive the number of batches per pass."""
        self.labels = labels
        self.batch_size = batch_size
        # One fold per batch: floor(len(labels) / batch_size) folds.
        self.num_splits = int(len(labels) / self.batch_size)
        self.num_steps = self.num_splits

    def _sampling(self: Sampler) -> list:
        """Draw a fresh shuffled partition; no seed is fixed here."""
        folds = StratifiedKFold(n_splits=self.num_splits, shuffle=True)
        positions = np.arange(len(self.labels))  # idx holder
        # Each fold's held-out indices form one batch.
        batches = []
        for _, batch_idx in folds.split(positions, self.labels):
            batches.append(batch_idx)
        return batches

    def __iter__(self: Sampler) -> Iterator:
        """Iterate over the index batches of a newly drawn partition."""
        return iter(self._sampling())

    def __len__(self: Sampler) -> int:
        """The length of the sampler.

        This is the number of batches produced per pass, i.e. the usual
        `steps = dataset_size / batch_size` (rounded down), so that the
        epoch/step hierarchy matches ordinary sampling.

        """
        return self.num_steps

from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Sampler



class StratifiedSampler_multilabel(Sampler):
    """Batch sampler for multi-label targets, stratified on the first label.

    Each pass yields `len(self)` batches. Every batch is the held-out part of
    one stratified shuffle split computed on the first label column, so
    batches are drawn independently and may share samples.

    Args:
        labels (list): 2-D list of labels, must be in the same ordering as
            input samples provided to the `SlideGraphDataset` object.
        batch_size (int): Size of the batch.

    Returns:
        List of indices to query from the `SlideGraphDataset` object.

    """

    def __init__(self, labels: list, batch_size: int = 10):
        """Initialize StratifiedSampler_multilabel."""
        self.batch_size = batch_size
        self.labels = np.array(labels)
        # One split per step so that iteration yields as many batches as
        # __len__ reports (a single split used to be drawn regardless).
        # NOTE(review): random_state is fixed, so every pass produces the
        # same batches — confirm this is intended.
        self.sss = StratifiedShuffleSplit(
            n_splits=max(1, len(self)),
            test_size=batch_size / len(labels),
            random_state=42,
        )

    def _sampling(self):
        """Return one array of indices per batch."""
        indices = np.arange(len(self.labels))
        # Stratify on the first label column only.
        return [tidx for _, tidx in self.sss.split(indices, self.labels[:, 0])]

    def __iter__(self):
        """Define Iterator."""
        return iter(self._sampling())

    def __len__(self):
        """The length of the sampler (number of batches per pass)."""
        return len(self.labels) // self.batch_size



In [169]:
from src.dset import multilabel_stratified_split

# Split configuration: 20% test; the remaining 80% is divided 70/30 into train/valid.
NUM_FOLDS = 1
TEST_RATIO = 0.2
TRAIN_RATIO = 0.8 * 0.7
VALID_RATIO = 0.8 * 0.3

# if SPLIT_PATH and os.path.exists(SPLIT_PATH):
#     splits = joblib.load(SPLIT_PATH)
# else:
x = np.array(wsi_names)
# Labels are keyed by graph file name ("<sample id>.json").
with open(LABELSPATH, 'r') as f:
    labels_dict = json.load(f)
print(labels_dict)
y = np.array([labels_dict[wsi_name+'.json'] for wsi_name in wsi_names])
# y[np.where(y==-1)] = 0

# splits = multilabel_stratified_split(x, y, TRAIN_RATIO, VALID_RATIO, TEST_RATIO, NUM_FOLDS)

if LABEL_TYPE == "multilabel":
    splits = multilabel_stratified_split(x, y, TRAIN_RATIO, VALID_RATIO, TEST_RATIO, NUM_FOLDS)
else:
    splits = stratified_split(x, y, TRAIN_RATIO, VALID_RATIO, TEST_RATIO, NUM_FOLDS)

# The training directory may not exist yet on a fresh workspace.
SPLIT_PATH.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(splits, SPLIT_PATH)


{'13901_MYC_532.json': 0, '13901_MYC_557.json': 0, '13902_MYC_513.json': 1, '13902_MYC_514.json': 1, '13903_MYC_535.json': 1, '13903_MYC_560.json': 1, '13904_MYC_539.json': 0, '13904_MYC_550.json': 0, '13908_MYC_527.json': 1, '13908_MYC_555.json': 1, '13911_MYC_520.json': 1, '13911_MYC_533.json': 1, '13912_MYC_530.json': 0, '13913_MYC_536.json': 0, '13913_MYC_554.json': 0, '13914_MYC_534.json': 0, '13914_MYC_556.json': 0, '13915_MYC_516.json': 1, '13915_MYC_547.json': 1, '13917_MYC_526.json': 1, '13917_MYC_551.json': 1, '13919_MYC_518.json': 0, '13919_MYC_528.json': 0, '13920_MYC_549.json': 1, '13922_MYC_558.json': 1, '13923_MYC_525.json': 1, '13923_MYC_538.json': 1, '13924_MYC_515.json': 1, '13924_MYC_548.json': 1, '13952_MYC_2587.json': 1, '13952_MYC_2589.json': 1, '13953_MYC_2611.json': 1, '13953_MYC_2639.json': 1, '13954_MYC_2592.json': 1, '13955_MYC_2637.json': 0, '13957_MYC_2646.json': 0, '13958_MYC_2600.json': 1, '13958_MYC_2640.json': 1, '13959_MYC_2593.json': 1, '13959_MYC_264

['/home/amrit/data/proj_data/MLG_project/DLBCL-Morph/training/splits_MYC_OS.dat']

In [19]:
SPLIT_PATH

PosixPath('/home/amrit/data/proj_data/MLG_project/DLBCL-Morph/training/splits_MYC_OS.dat')

# Look at samples

In [170]:

from __future__ import annotations

from typing import Callable
import ujson as json
import numpy as np
import pathlib
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold

import torch
from torch.utils.data import Sampler
from torch_geometric.data import Data, Dataset


class SlideGraphDataset_surv(Dataset):
    """Load slide graphs, and their survival annotations, from JSON files.

    Args:
        info_list (list): When `mode` contains `train` or `valid`, each
            entry is a `[uid, label]` pair; otherwise each entry is a bare
            `uid`. The graph is read from `graph_dir / f"{uid}.json"`.
        mode (str): Name of the data mode `info_list` belongs to.
        graph_dir (pathlib.Path): Directory holding the graph `.json` files.
        preproc (callable): Optional pre-processing function applied to the
            node feature matrix stored under the key `"x"`.

    """

    def __init__(
        self: Dataset,
        info_list: list,
        mode: str = "train",
        graph_dir: pathlib.Path = None,
        preproc: Callable | None = None,
    ) -> None:
        """Store the sample list and loading options."""
        self.info_list = info_list
        self.mode = mode
        self.graph_dir = graph_dir
        self.preproc = preproc

    def __getitem__(self: Dataset, idx: int) -> Dataset:
        """Read one graph from disk and return it as a dictionary.

        Labelled modes yield the keys `graph`, `label`, `surv_event` and
        `surv_time`; any other mode yields only `graph`.
        """
        entry = self.info_list[idx]
        has_label = any(tag in self.mode for tag in ("train", "valid"))
        if has_label:
            wsi_code, label = entry
            # torch.tensor keeps a scalar label as a 0-d tensor
            label = torch.tensor(label)
        else:
            wsi_code = entry

        graph_path = self.graph_dir / f"{wsi_code}.json"
        with graph_path.open() as fptr:
            graph_dict = json.load(fptr)

        # Survival fields are taken out first so that they are not turned
        # into graph attributes below.
        surv_event = graph_dict.pop("surv_event")
        surv_time = graph_dict.pop("surv_time")

        arrays = {key: np.array(value) for key, value in graph_dict.items()}
        if self.preproc is not None:
            arrays["x"] = self.preproc(arrays["x"])

        tensors = {key: torch.tensor(value) for key, value in arrays.items()}
        graph = Data(**tensors)

        if not has_label:
            return {"graph": graph}
        return {
            "graph": graph,
            "label": label,
            "surv_event": surv_event,
            "surv_time": surv_time,
        }

    def __len__(self: Dataset) -> int:
        """Number of samples in `info_list`."""
        return len(self.info_list)

    def len(self):
        """Alias of `__len__`."""
        return self.__len__()

    def get(self, idx):
        """Alias of `__getitem__`."""
        return self.__getitem__(idx)


In [171]:
from src.utils import load_json, rm_n_mkdir, mkdir, recur_find_ext
from src.dset import SlideGraphDataset, stratified_split, StratifiedSampler
import warnings
warnings.filterwarnings("ignore")

from torch_geometric.loader import DataLoader

nodes_preproc_func = None


def sample_data(  # noqa: C901, PLR0912, PLR0915
    dataset_dict: dict,
    GRAPH_DIR = None) -> list:
    """Build a dataset and a loader for every subset in `dataset_dict`.

    Every dataset is created in "train" mode, so each sample carries its
    label and survival fields. Only the pair built for the LAST subset of
    `dataset_dict` is returned.

    Args:
        dataset_dict (dict): Maps a subset name to its sample list.
        GRAPH_DIR: Directory containing the graph `.json` files.

    Returns:
        The `(dataset, loader)` pair of the last subset.
    """
    for name, entries in dataset_dict.items():
        # No batch sampler is used here, so shuffling and dropping the last
        # incomplete batch are enabled only for the subset called "train".
        is_train = name == "train"

        ds = SlideGraphDataset_surv(
            entries,
            mode="train",
            graph_dir=GRAPH_DIR,
            preproc=nodes_preproc_func,
        )
        loader = DataLoader(
            ds,
            batch_sampler=None,
            drop_last=is_train,
            shuffle=is_train,
            num_workers=6,
            batch_size=32,
        )

    return ds, loader


# Loader name -> key of the split it is drawn from.
SUBSET_SOURCES = {
    "train": "train",
    "infer-train": "train",
    "infer-valid-A": "valid",
    "infer-valid-B": "test",
}

for split_idx, split in enumerate(splits):
    new_split = {name: split[source] for name, source in SUBSET_SOURCES.items()}
    # sample_data hands back the dataset/loader of the last subset only.
    ds, loader = sample_data(new_split, GRAPHSDIR)

# Inspect a single sample ...
sample = ds[3]
display(sample)

# ... and the first collated batch.
for _step, batch_data in enumerate(loader):
    print(batch_data)
    break

{'graph': Data(x=[95, 128], edge_index=[2, 506], y=1, coordinates=[95, 2]),
 'label': tensor(1),
 'surv_event': [0],
 'surv_time': [12.72]}

{'graph': DataBatch(x=[2125, 128], edge_index=[2, 11016], y=[32], coordinates=[2125, 2], batch=[2125], ptr=[33]), 'label': tensor([1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0,
        1, 0, 1, 1, 0, 1, 0, 1]), 'surv_event': [tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0,
        1, 1, 0, 0, 1, 0, 1, 0])], 'surv_time': [tensor([ 8.1100, 10.5200,  2.7800, 12.7200,  4.7800,  9.0700,  8.2700,  7.4200,
         8.5400,  6.5900,  7.7800,  8.6300,  4.4000,  4.2300,  8.3600,  8.3400,
         7.6200,  3.8100,  9.0800, 11.6700,  8.3300,  9.8400,  1.3000,  5.4100,
        10.8900,  3.4400,  8.4400,  8.7200,  2.0600,  8.0400,  4.9000, 13.8800])]}


# Train model

In [172]:

"""Import modules required to run the Jupyter notebook."""
from __future__ import annotations

# Clear logger to use tiatoolbox.logger
import logging

if logging.getLogger().hasHandlers():
    logging.getLogger().handlers.clear()

logging.basicConfig(
    level=logging.INFO,
)

import copy
import random
import warnings
from pathlib import Path

# Third party imports
import joblib
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datetime import datetime
from sklearn import metrics

import ujson as json
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from tiatoolbox import logger

from tiatoolbox.utils.misc import save_as_json


from src.utils import load_json, rm_n_mkdir, mkdir, recur_find_ext
from src.dset import SlideGraphDataset, stratified_split, StratifiedSampler, StratifiedSampler_multilabel
from src.model import SlideGraphArch
from src.utils import ScalarMovingAverage

warnings.filterwarnings("ignore")
mpl.rcParams["figure.dpi"] = 300  # for high resolution figure in notebook


nodes_preproc_func = None


import numpy as np

def multilabel_accuracy(y_pred, y_true, threshold=0.5):
    """
    Calculate per-label and averaged accuracy for multilabel predictions.

    Scores strictly greater than `threshold` are treated as a positive
    prediction, everything else (including a score exactly equal to the
    threshold) as negative.

    Parameters:
    - y_pred: 2D array-like of predicted scores, one column per label
    - y_true: 2D array-like or binary indicator matrix, true labels
    - threshold: decision threshold applied to `y_pred` (default 0.5)

    Returns:
    - overall_accuracy: mean of the per-label accuracies, rounded to 4
      decimals (this is NOT the exact-match / subset accuracy)
    - label_accuracies: array with the accuracy of each individual label,
      rounded to 4 decimals

    Raises:
    - ValueError: if `y_true` and `y_pred` do not have the same shape
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Ensure the shapes are the same
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes of y_true and y_pred must be the same.")

    # Binarise in a single step; thresholding with two separate in-place
    # assignments left scores equal to the threshold untouched, so they
    # could never match a 0/1 label.
    y_bin = y_pred > threshold

    label_accuracies = np.mean(y_true == y_bin, axis=0).round(4)
    overall_accuracy = np.mean(label_accuracies).round(4)

    return overall_accuracy, label_accuracies

def run_once(  # noqa: C901, PLR0912, PLR0915
    dataset_dict: dict,
    num_epochs: int,
    save_dir: str | Path,
    pretrained: str | None = None,
    loader_kwargs: dict | None = None,
    arch_kwargs: dict | None = None,
    optim_kwargs: dict | None = None,
    *,
    on_gpu: bool = True,
    GRAPH_DIR = None,
    LABEL_TYPE = None
) -> tuple:
    """Run the training and/or evaluation loop once.

    The running mode of each subset is defined by its name within
    `dataset_dict`. `train` is reserved for the subset used for training.
    Names containing `infer` together with `train` or `valid` are evaluated
    against their labels after every epoch. Every non-`train` loader is
    expected to yield `surv_event` and `surv_time` alongside the graph.

    Args:
        dataset_dict (dict): Maps a subset name to its sample list.
        num_epochs (int): Number of passes over all loaders.
        save_dir (str | Path): Directory receiving `stats.json`,
            `best_score.json`, `progress.png` and the model checkpoints.
        pretrained: Arguments unpacked into `model.load` when not `None`.
        loader_kwargs (dict): Keyword arguments for the data loaders. For
            the `train` subset only `batch_size` is used (by the sampler).
        arch_kwargs (dict): Keyword arguments for `SlideGraphArch`.
        optim_kwargs (dict): Keyword arguments for `torch.optim.Adam`.
        on_gpu (bool): Run on "cuda" when true, otherwise on "cpu".
        GRAPH_DIR: Directory containing the graph `.json` files.
        LABEL_TYPE: `'multilabel'` selects the multilabel sampler and
            metric; any other value uses arg-max accuracy.

    Returns:
        tuple: `(pred, gtruth, surv_event_list, surv_time_list,
        best_score_num)`. The first four come from the last evaluated
        loader of the last epoch (`None` if no loader was evaluated).
        `best_score_num` holds, per evaluated loader, the predictions,
        ground truth and survival arrays of its best-scoring epoch.

    """
    if loader_kwargs is None:
        loader_kwargs = {}

    if arch_kwargs is None:
        arch_kwargs = {}

    if optim_kwargs is None:
        optim_kwargs = {}

    # `save_dir` may be given as a string, but the `/` joins below need a Path.
    save_dir = Path(save_dir)

    device = "cuda" if on_gpu else "cpu"

    model = SlideGraphArch(**arch_kwargs)
    print(model)
    if pretrained is not None:
        model.load(*pretrained)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), **optim_kwargs)

    # Create the graph dataset holder for each subset info then
    # pipe them through torch/torch geometric specific loader
    # for loading in multi-thread.
    loader_dict = {}
    for subset_name, subset in dataset_dict.items():
        _loader_kwargs = copy.deepcopy(loader_kwargs)
        batch_sampler = None
        if subset_name == "train":
            # The training loader is driven by a stratified batch sampler,
            # so the generic loader kwargs are dropped for it.
            _loader_kwargs = {}
            if LABEL_TYPE == 'multilabel':
                sampler_cls = StratifiedSampler_multilabel
            else:
                sampler_cls = StratifiedSampler
            batch_sampler = sampler_cls(
                labels=[v[1] for v in subset],
                batch_size=loader_kwargs["batch_size"],
            )

        ds = SlideGraphDataset_surv(subset, mode=subset_name, preproc=nodes_preproc_func, graph_dir=GRAPH_DIR)
        loader_dict[subset_name] = DataLoader(
            ds,
            batch_sampler=batch_sampler,
            drop_last=subset_name == "train" and batch_sampler is None,
            shuffle=subset_name == "train" and batch_sampler is None,
            **_loader_kwargs,
        )

    # Best accuracy per loader (written to best_score.json) and the raw
    # arrays belonging to each loader's best epoch (returned to the caller).
    best_score = {
        "infer-train-accuracy": 0,
        "infer-valid-A-accuracy": 0,
        "infer-valid-B-accuracy": 0,
    }
    best_score_num = {}

    # Defined up-front so the final `return` is valid even when no labelled
    # inference loader is ever evaluated (e.g. `num_epochs == 0`).
    pred = gtruth = surv_event_list = surv_time_list = None

    # Statistics, the progress plot and checkpoints only make sense when a
    # training loader is present.
    is_training = "train" in loader_dict

    for epoch in range(num_epochs):
        logger.info("EPOCH: %03d", epoch)
        for loader_name, loader in loader_dict.items():
            # * EPOCH START
            step_output = []
            step_surv_event = []
            step_surv_time = []

            ema = ScalarMovingAverage()
            for _step, batch_data in enumerate(tqdm(loader, disable=loader_name != "train")):
                # * STEP COMPLETE CALLBACKS
                if loader_name == "train":
                    outputs = model.train_batch(model, batch_data, optimizer, on_gpu=on_gpu)
                    ema({"loss": outputs[0]})
                else:
                    output = model.infer_batch(model, batch_data, on_gpu=on_gpu)

                    batch_size = batch_data["graph"].num_graphs

                    surv_event_batch = batch_data["surv_event"]
                    surv_time_batch = batch_data["surv_time"]

                    # Iterate over output head and retrieve
                    # each as N x item, each item may be of
                    # arbitrary dimensions
                    output = [np.split(v, batch_size, axis=0) for v in output]
                    # pairing such that it will be
                    # N batch size x H head list
                    output = list(zip(*output))
                    step_output.extend(output)

                    # Same per-sample regrouping for the survival fields.
                    surv_event_batch = list(zip(*surv_event_batch))
                    step_surv_event.extend(surv_event_batch)

                    surv_time_batch = list(zip(*surv_time_batch))
                    step_surv_time.extend(surv_time_batch)

            # * EPOCH COMPLETE

            # Callbacks to process output
            logging_dict = {}

            if loader_name == "train":
                for val_name, val in ema.tracking_dict.items():
                    logging_dict[f"train-EMA-{val_name}"] = val.item()
            elif "infer" in loader_name and any(v in loader_name for v in ["train", "valid"]):
                # Expand the list of N dataset size x H heads
                # back to a list of H Head each with N samples.
                output = list(zip(*step_output))
                pred, gtruth = output
                pred = np.squeeze(np.array(pred))
                gtruth = np.squeeze(np.array(gtruth))

                surv_event_list = list(zip(*step_surv_event))
                surv_event_list = np.squeeze(np.array(surv_event_list))

                surv_time_list = list(zip(*step_surv_time))
                surv_time_list = np.squeeze(np.array(surv_time_list))

                if LABEL_TYPE == 'multilabel':
                    curr_score, label_accuracies = multilabel_accuracy(pred, gtruth)
                    logging_dict[f"{loader_name}-accuracy"] = float(curr_score)
                else:
                    curr_score = round(metrics.accuracy_score(np.argmax(pred, axis=1), gtruth), 3)
                    logging_dict[f"{loader_name}-accuracy"] = float(curr_score)

                # `.get` keeps loader names that were not pre-seeded above
                # from raising a KeyError. `>=` means the latest epoch wins
                # a tie.
                score_key = f"{loader_name}-accuracy"
                if curr_score >= best_score.get(score_key, 0):
                    best_score[score_key] = curr_score
                    best_score[f"{loader_name}-best_epoch"] = epoch
                    if LABEL_TYPE == 'multilabel':
                        best_score[f"{loader_name}-label_accuracies"] = label_accuracies.tolist()

                    best_score_num[f"{loader_name}-pred"] = pred
                    best_score_num[f"{loader_name}-gt"] = gtruth
                    best_score_num[f"{loader_name}-surv_time_list"] = surv_time_list
                    best_score_num[f"{loader_name}-surv_event_list"] = surv_event_list

            # Callbacks for logging and saving
            for val_name, val in logging_dict.items():
                if "raw" not in val_name:
                    logging.info("%s: %f", val_name, val)
            if not is_training:
                continue

            # Track the statistics
            new_stats = {}
            if (save_dir / "stats.json").exists():
                old_stats = load_json(save_dir/"stats.json")
                # Save a backup first
                save_as_json(old_stats, save_dir/"stats.old.json", exist_ok=True)
                new_stats = copy.deepcopy(old_stats)
                # JSON object keys are strings; restore the integer epochs.
                new_stats = {int(k): v for k, v in new_stats.items()}

            old_epoch_stats = {}
            if epoch in new_stats:
                old_epoch_stats = new_stats[epoch]
            old_epoch_stats.update(logging_dict)
            new_stats[epoch] = old_epoch_stats
            save_as_json(new_stats, save_dir/"stats.json", exist_ok=True)

            # Save the dictionary to a JSON file
            with open(save_dir/"best_score.json", 'w') as json_file:
                json.dump(best_score, json_file, indent=4)

        if not is_training:
            # Without a training loader no statistics were collected, so
            # there is nothing to plot or to checkpoint.
            continue

        # Redraw the progress curves from the statistics gathered so far.
        plt.figure()
        for pkey in new_stats[0].keys():
            vals = [new_stats[eitr][pkey] for eitr in range(epoch+1)]
            plt.plot(np.arange(len(vals)), vals, label=pkey)
        plt.title("Best_acc" + str(list(best_score.values())))
        plt.tight_layout()
        plt.legend()
        plt.savefig(save_dir/'progress.png')
        plt.close()

        if epoch % 25 == 0:
            # Save the pytorch model
            model.save(
                f"{save_dir}/epoch={epoch:03d}.weights.pth",
                f"{save_dir}/epoch={epoch:03d}.aux.dat",
            )
            print("best_score" , best_score)

    print(best_score)

    return pred, gtruth, surv_event_list, surv_time_list, best_score_num

In [173]:
# # we must define the function after training/loading
# def nodes_preproc_func(node_features: np.ndarray) -> np.ndarray:
#     """Pre-processing function for nodes."""
#     return node_scaler.transform(node_features)
# Node features are fed to the model as stored; no scaler is applied.
nodes_preproc_func = None

splits = joblib.load(SPLIT_PATH)
loader_kwargs = dict(num_workers=6, batch_size=32)

# Two output units for the "OS" label type, six for any other label type.
NCLASSES = 2 if LABEL_TYPE == "OS" else 6

# Architecture hyper-parameters; they are also encoded in the run folder name.
conv, pooling, aggr = "EdgeConv", "mean", "max"
dropout = 0.15
layers = [64, 32, 32]

arch_kwargs = dict(
    dim_features=NUM_NODE_FEATURES,
    dim_target=NCLASSES,
    layers=layers,
    dropout=dropout,
    pooling=pooling,
    conv=conv,
    aggr=aggr,
    CLASSIFICATION_TYPE=LABEL_TYPE,
)

session_stamp = datetime.now().strftime('%m_%d_%H_%M')
RUN_OUTPUT_DIR = TRAIN_DIR / f"session_{STAIN}_{conv}_{pooling}_{aggr}_{dropout}_{layers}_{session_stamp}"
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR = RUN_OUTPUT_DIR / "model"

optim_kwargs = dict(lr=5.0e-4, weight_decay=1.0e-4)

NUM_EPOCHS = 100
for split_idx, split in enumerate(splits):
    # The test subset ("infer-valid-B") is deliberately left out of this run.
    new_split = {
        "train": split["train"],
        "infer-train": split["train"],
        "infer-valid-A": split["valid"],
    }
    split_save_dir = MODEL_DIR / f"{split_idx:02d}/"
    rm_n_mkdir(split_save_dir)
    reset_logging(split_save_dir)
    # `output` keeps the result of the most recent split only.
    output = run_once(
        new_split,
        NUM_EPOCHS,
        save_dir=split_save_dir,
        arch_kwargs=arch_kwargs,
        loader_kwargs=loader_kwargs,
        optim_kwargs=optim_kwargs,
        on_gpu=ON_GPU,
        GRAPH_DIR=GRAPHSDIR,
        LABEL_TYPE=LABEL_TYPE,
    )

|2023-12-11|21:45:38.338| [INFO] EPOCH: 000


SlideGraphArch(
  (first_h): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (nns): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=128, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (convs): ModuleList(
    (0): EdgeConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    ))
    (1): EdgeConv(nn=Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05,

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 12.41it/s]
|2023-12-11|21:45:38.744| [INFO] train-EMA-loss: 0.863734
|2023-12-11|21:45:39.127| [INFO] infer-train-accuracy: 0.500000
|2023-12-11|21:45:39.440| [INFO] infer-valid-A-accuracy: 0.494000
|2023-12-11|21:45:39.725| [INFO] EPOCH: 001


best_score {'infer-train-accuracy': 0.5, 'infer-valid-A-accuracy': 0.494, 'infer-valid-B-accuracy': 0, 'infer-train-best_epoch': 0, 'infer-valid-A-best_epoch': 0}


100%|██████████| 5/5 [00:00<00:00, 12.01it/s]
|2023-12-11|21:45:40.144| [INFO] train-EMA-loss: 0.668020
|2023-12-11|21:45:40.540| [INFO] infer-train-accuracy: 0.556000
|2023-12-11|21:45:40.843| [INFO] infer-valid-A-accuracy: 0.584000
|2023-12-11|21:45:41.135| [INFO] EPOCH: 002
100%|██████████| 5/5 [00:00<00:00, 12.17it/s]
|2023-12-11|21:45:41.549| [INFO] train-EMA-loss: 0.701116
|2023-12-11|21:45:41.915| [INFO] infer-train-accuracy: 0.601000
|2023-12-11|21:45:42.219| [INFO] infer-valid-A-accuracy: 0.636000
|2023-12-11|21:45:42.523| [INFO] EPOCH: 003
100%|██████████| 5/5 [00:00<00:00, 12.92it/s]
|2023-12-11|21:45:42.913| [INFO] train-EMA-loss: 0.642792
|2023-12-11|21:45:43.288| [INFO] infer-train-accuracy: 0.601000
|2023-12-11|21:45:43.627| [INFO] infer-valid-A-accuracy: 0.597000
|2023-12-11|21:45:43.919| [INFO] EPOCH: 004
100%|██████████| 5/5 [00:00<00:00, 13.18it/s]
|2023-12-11|21:45:44.301| [INFO] train-EMA-loss: 0.619204
|2023-12-11|21:45:44.682| [INFO] infer-train-accuracy: 0.63500

best_score {'infer-train-accuracy': 0.725, 'infer-valid-A-accuracy': 0.636, 'infer-valid-B-accuracy': 0, 'infer-train-best_epoch': 25, 'infer-valid-A-best_epoch': 2}


100%|██████████| 5/5 [00:00<00:00, 10.65it/s]
|2023-12-11|21:46:16.240| [INFO] train-EMA-loss: 0.538121
|2023-12-11|21:46:16.651| [INFO] infer-train-accuracy: 0.713000
|2023-12-11|21:46:16.981| [INFO] infer-valid-A-accuracy: 0.532000
|2023-12-11|21:46:17.280| [INFO] EPOCH: 027
100%|██████████| 5/5 [00:00<00:00, 12.70it/s]
|2023-12-11|21:46:17.676| [INFO] train-EMA-loss: 0.451188
|2023-12-11|21:46:18.077| [INFO] infer-train-accuracy: 0.742000
|2023-12-11|21:46:18.416| [INFO] infer-valid-A-accuracy: 0.558000
|2023-12-11|21:46:18.713| [INFO] EPOCH: 028
100%|██████████| 5/5 [00:00<00:00, 12.59it/s]
|2023-12-11|21:46:19.112| [INFO] train-EMA-loss: 0.442870
|2023-12-11|21:46:19.504| [INFO] infer-train-accuracy: 0.691000
|2023-12-11|21:46:19.832| [INFO] infer-valid-A-accuracy: 0.519000
|2023-12-11|21:46:20.136| [INFO] EPOCH: 029
100%|██████████| 5/5 [00:00<00:00, 12.00it/s]
|2023-12-11|21:46:20.555| [INFO] train-EMA-loss: 0.391929
|2023-12-11|21:46:20.948| [INFO] infer-train-accuracy: 0.66900

best_score {'infer-train-accuracy': 0.803, 'infer-valid-A-accuracy': 0.636, 'infer-valid-B-accuracy': 0, 'infer-train-best_epoch': 46, 'infer-valid-A-best_epoch': 2}


100%|██████████| 5/5 [00:00<00:00, 11.83it/s]
|2023-12-11|21:46:52.358| [INFO] train-EMA-loss: 0.197974
|2023-12-11|21:46:52.741| [INFO] infer-train-accuracy: 0.781000
|2023-12-11|21:46:53.094| [INFO] infer-valid-A-accuracy: 0.506000
|2023-12-11|21:46:53.408| [INFO] EPOCH: 052
100%|██████████| 5/5 [00:00<00:00, 12.05it/s]
|2023-12-11|21:46:53.825| [INFO] train-EMA-loss: 0.180647
|2023-12-11|21:46:54.263| [INFO] infer-train-accuracy: 0.792000
|2023-12-11|21:46:54.585| [INFO] infer-valid-A-accuracy: 0.494000
|2023-12-11|21:46:54.934| [INFO] EPOCH: 053
100%|██████████| 5/5 [00:00<00:00, 11.88it/s]
|2023-12-11|21:46:55.357| [INFO] train-EMA-loss: 0.185316
|2023-12-11|21:46:55.737| [INFO] infer-train-accuracy: 0.680000
|2023-12-11|21:46:56.077| [INFO] infer-valid-A-accuracy: 0.545000
|2023-12-11|21:46:56.427| [INFO] EPOCH: 054
100%|██████████| 5/5 [00:00<00:00, 12.99it/s]
|2023-12-11|21:46:56.816| [INFO] train-EMA-loss: 0.160134
|2023-12-11|21:46:57.174| [INFO] infer-train-accuracy: 0.70200

best_score {'infer-train-accuracy': 0.893, 'infer-valid-A-accuracy': 0.636, 'infer-valid-B-accuracy': 0, 'infer-train-best_epoch': 75, 'infer-valid-A-best_epoch': 2}


100%|██████████| 5/5 [00:00<00:00, 13.03it/s]
|2023-12-11|21:47:28.135| [INFO] train-EMA-loss: 0.204511
|2023-12-11|21:47:28.511| [INFO] infer-train-accuracy: 0.921000
|2023-12-11|21:47:28.882| [INFO] infer-valid-A-accuracy: 0.506000
|2023-12-11|21:47:29.217| [INFO] EPOCH: 077
100%|██████████| 5/5 [00:00<00:00, 11.47it/s]
|2023-12-11|21:47:29.655| [INFO] train-EMA-loss: 0.062215
|2023-12-11|21:47:30.112| [INFO] infer-train-accuracy: 0.843000
|2023-12-11|21:47:30.487| [INFO] infer-valid-A-accuracy: 0.558000
|2023-12-11|21:47:30.819| [INFO] EPOCH: 078
100%|██████████| 5/5 [00:00<00:00, 12.28it/s]
|2023-12-11|21:47:31.228| [INFO] train-EMA-loss: 0.075478
|2023-12-11|21:47:31.655| [INFO] infer-train-accuracy: 0.921000
|2023-12-11|21:47:31.996| [INFO] infer-valid-A-accuracy: 0.519000
|2023-12-11|21:47:32.306| [INFO] EPOCH: 079
100%|██████████| 5/5 [00:00<00:00, 12.29it/s]
|2023-12-11|21:47:32.715| [INFO] train-EMA-loss: 0.159792
|2023-12-11|21:47:33.125| [INFO] infer-train-accuracy: 0.84300

{'infer-train-accuracy': 0.972, 'infer-valid-A-accuracy': 0.636, 'infer-valid-B-accuracy': 0, 'infer-train-best_epoch': 97, 'infer-valid-A-best_epoch': 2}


In [220]:
from sksurv.nonparametric import kaplan_meier_estimator
import matplotlib.pyplot as plt

def plot_km_curve_222(df, save_name= "test1.png", title="KM curve on Test set"):
    """Plot Kaplan-Meier survival curves for the two predicted groups.

    One step curve with its confidence band is drawn for the rows where
    `df["preds"]` equals 0 and one for the rows where it equals 1. A group
    without any rows is skipped, so the plot is still produced when all
    predictions fall into a single group.

    Args:
        df: DataFrame with the columns `preds`, `surv_event` (boolean) and
            `surv_time`.
        save_name: Path the figure is written to.
        title: Title of the figure.
    """
    plt.figure()
    color_dict = {1: 'b',
                0: 'r',
                2: 'y'}

    # Legend names of the groups: prediction 0 is shown as "High" risk and
    # prediction 1 as "Low" risk.
    treatment_type_name = {0: 'High',
                1: 'Low',
                2: 'y'}

    for treatment_type in (0, 1):
        mask_treat = df["preds"] == treatment_type
        if not mask_treat.any():
            # The estimator cannot be fitted on an empty group.
            continue

        color = color_dict[treatment_type]
        time_treatment, survival_prob_treatment, conf_int = kaplan_meier_estimator(
            df["surv_event"][mask_treat],
            df["surv_time"][mask_treat],
            conf_type="log-log",
        )

        plt.step(time_treatment, survival_prob_treatment, where="post",
                 label=f"Risk group = {treatment_type_name[treatment_type]}", color=color)
        plt.fill_between(time_treatment, conf_int[0], conf_int[1], alpha=0.1, step="post", color=color)

    plt.ylim(0, 1)
    plt.ylabel("Survival")
    plt.xlabel("time $t$")
    plt.title(title)
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(save_name)
    plt.show()



def get_km_plot_multi(output):
    """Build the risk-group table from the first output head.

    The first column of `output[0]` is split at its mean: scores at or
    above the mean become group 1, the others group 0.

    Args:
        output: Sequence of `(pred, gt, surv_event, surv_time)`, where
            `pred` and `gt` are 2D arrays and only their first column is
            used. `output` is not modified.

    Returns:
        DataFrame with the columns `preds` (int), `gt`, `surv_event`
        (bool) and `surv_time`.
    """
    # Work on a copy: slicing yields a view, and binarising it in place
    # would overwrite the caller's prediction array.
    scores = np.array(output[0][:, 0], dtype=float)
    gt = output[1][:, 0]

    threshold = np.mean(scores)
    print(np.mean(scores), np.median(scores))

    # Binarise in one step. Assigning 1 and then 0 in place compared the
    # freshly written 1s against the threshold again, which turned every
    # prediction into 0 whenever the threshold was above 1.
    preds = (scores >= threshold).astype(int)

    output_df = pd.DataFrame([preds, gt, output[2], output[3]]).T
    output_df.columns = ("preds", "gt", "surv_event", "surv_time")

    output_df['surv_event'] = output_df['surv_event'].astype(bool)
    output_df['preds'] = output_df['preds'].astype(int)

    # plot_km_curve_222(output_df)

    return output_df


def get_km_plot(output):
    """Summarise and plot KM curves for the "infer-valid-A" subset.

    Reads the best-epoch arrays from the last element of `output` and
    produces two plots: one grouped by the ground-truth label (saved as
    "train.png") and one grouped by the arg-max prediction (saved as
    "test.png"). For each, the group/event counts are printed and the
    per-group means are displayed.

    Returns:
        The DataFrame built for the prediction-based grouping.
    """
    best = output[-1]
    orig_preds = best['infer-valid-A-pred']
    gt = best['infer-valid-A-gt']
    surv_time_list = best['infer-valid-A-surv_time_list']
    surv_event_list = best['infer-valid-A-surv_event_list']

    def _summarise_and_plot(groups, save_name):
        # One row per sample; `groups` decides which curve a sample joins.
        frame = pd.DataFrame([groups, gt, surv_event_list, surv_time_list]).T
        frame.columns = ("preds", "gt", "surv_event" ,"surv_time")
        frame['surv_event'] = frame['surv_event'].astype(bool)
        frame['preds'] = frame['preds'].astype(int)
        print(frame.groupby(['preds', 'surv_event']).count()['gt'])
        display(frame.groupby(['preds']).mean())
        plot_km_curve_222(frame, save_name=save_name)
        return frame

    # First pass: groups are the ground-truth labels themselves.
    _summarise_and_plot(gt.copy(), "train.png")

    # Second pass: groups are the predicted classes.
    return _summarise_and_plot(orig_preds.argmax(axis=1), "test.png")

output_df = get_km_plot(output)

preds  surv_event
0      False         18
       True          20
1      False         31
       True           8
Name: gt, dtype: int64


Unnamed: 0_level_0,gt,surv_event,surv_time
preds,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.0,0.526316,4.417632
1,1.0,0.205128,9.741026


preds  surv_event
0      False          6
       True           8
1      False         43
       True          20
Name: gt, dtype: int64


Unnamed: 0_level_0,gt,surv_event,surv_time
preds,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.142857,0.571429,5.455
1,0.587302,0.31746,7.48254


# Accuracy metrics

In [None]:
# Hand-written example input for multilabel_accuracy: 6 samples x 6 labels.
# The values include negatives and numbers above 1.
# NOTE(review): these look like raw model scores rather than probabilities,
# while multilabel_accuracy thresholds at 0.5 - confirm this is intended.
pred = [[1.46648765e+00,-8.77613187e-01,-2.51502216e-01,-1.10481453e+00
,6.64071560e-01,-5.99045038e-01]
,[-1.80002856e+00,8.66671383e-01,-6.64208949e-01,1.20571077e+00
,-2.28768730e+00,1.23205519e+00]
,[-8.38510573e-01,2.11250722e-01,4.82147932e-03,5.65632761e-01
,-1.94593978e+00,6.11969411e-01]
,[1.70954013e+00,-7.06104040e-01,-6.74402267e-02,-8.06845903e-01
,1.92917514e+00,-1.58036017e+00]
,[1.23513281e+00,-6.98642850e-01,-7.90928781e-01,-6.50385141e-01
,9.64580655e-01,-5.73552132e-01]
,[-9.15723860e-01,3.66611242e-01,3.33443969e-01,4.22663540e-01
,-7.69037247e-01,1.99665681e-01]]

# Matching binary ground-truth matrix (same 6 x 6 layout as `pred`).
gt = [[1.,1.,0.,0.,0.,0.]
,[0.,0.,0.,0.,0.,1.]
,[1.,1.,0.,1.,1.,0.]
,[1.,0.,0.,0.,0.,0.]
,[0.,1.,1.,0.,0.,1.]
,[1.,1.,1.,1.,1.,0.]]


def multilabel_accuracy(y_pred, y_true, threshold=0.5):
    """
    Calculate per-label and averaged accuracy for multilabel predictions.

    Scores strictly greater than `threshold` are treated as a positive
    prediction, everything else (including a score exactly equal to the
    threshold) as negative.

    Parameters:
    - y_pred: 2D array-like of predicted scores, one column per label
    - y_true: 2D array-like or binary indicator matrix, true labels
    - threshold: decision threshold applied to `y_pred` (default 0.5)

    Returns:
    - overall_accuracy: mean of the per-label accuracies, rounded to 4
      decimals (this is NOT the exact-match / subset accuracy)
    - label_accuracies: array with the accuracy of each individual label,
      rounded to 4 decimals

    Raises:
    - ValueError: if `y_true` and `y_pred` do not have the same shape
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Ensure the shapes are the same
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes of y_true and y_pred must be the same.")

    # Binarise in a single step; thresholding with two separate in-place
    # assignments left scores equal to the threshold untouched, so they
    # could never match a 0/1 label.
    y_bin = y_pred > threshold

    label_accuracies = np.mean(y_true == y_bin, axis=0).round(4)
    overall_accuracy = np.mean(label_accuracies).round(4)

    return overall_accuracy, label_accuracies

multilabel_accuracy(pred, gt)