# Constants


In [None]:

import pandas as pd
import os
import shutil
from os import path as osp

import torch
from tiatoolbox.utils.misc import select_device
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import joblib
import json
import glob

# from src.intensity import add_features_and_create_new_dicts

from src.featureextraction import get_cell_features, add_features_and_create_new_dicts
from src.train import stratified_split, recur_find_ext, run_once, rm_n_mkdir ,reset_logging
from src.graph_construct import create_graph_with_pooled_patch_nodes, get_pids_labels_for_key


# Device selection: ON_GPU is False, so select_device (tiatoolbox) is asked for
# the non-GPU device. ON_GPU is also forwarded to run_once in the training cell.
ON_GPU = False
device = select_device(on_gpu=ON_GPU)

# Fixed seed for Python's `random`, a dedicated NumPy Generator (`rng`) and torch.
# NOTE(review): the legacy global NumPy RNG (np.random.seed) is not seeded here —
# confirm that no downstream helper relies on it for reproducibility.
SEED = 5
random.seed(SEED)
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


# Root folder of the project data (machine-specific absolute path).
BASEDIR = '/home/amrit/data/proj_data/MLG_project/DLBCL-Morph'


# Stain to process; exactly one of the assignments below should be active.
# STAIN = 'MYC'
# STAIN = 'BCL2'
STAIN = 'HE'

# Input/output locations derived from BASEDIR and STAIN.
FIDIR = f'{BASEDIR}/outputs'
CLINPATH = f'{BASEDIR}/clinical_data_cleaned.csv'
ANNPATH = f'{BASEDIR}/annotations_clean.csv'
FEATSDIR = f'{BASEDIR}/outputs/files/{STAIN}'
# .npz file that receives the feature scaler's mean/var (np.savez in the scaler cell).
FEATSCALERPATH = f"{FEATSDIR}/0_feat_scaler.npz"
# NOTE: PATCH_SIZE is reassigned to 300 in the graph-preparation cell.
PATCH_SIZE = 224
# OUTPUT_SIZE is part of the image file names built in the feature-extraction cell.
OUTPUT_SIZE = PATCH_SIZE*8

WORKSPACE_DIR = Path(BASEDIR)
# Earlier path layout, kept for reference:
# GRAPH_DIR = WORKSPACE_DIR / f"graphs{STAIN}"
# LABELS_PATH = WORKSPACE_DIR / "graphs/0_labels.txt"


# Graph construction
# PATCH_SIZE = 300
SKEW_NOISE = 0.0001
MIN_CELLS_PER_PATCH = 10
CONNECTIVITY_DISTANCE = 500

# Label mode: 'multilabel' (active) or 'OS'; later cells branch on this value.
LABEL_TYPE = 'multilabel' #'OS' #

# Output folder for the per-sample graph JSON files and the JSON file with their labels.
GRAPHSDIR = Path(f'{BASEDIR}/graphs/{STAIN}')
LABELSPATH = f'{BASEDIR}/graphs/{STAIN}_labels.json'

# NOTE: NUM_EPOCHS and NCLASSES are both reassigned in the training cell.
NUM_EPOCHS = 100
NUM_NODE_FEATURES = 128
NCLASSES = 3

# Training artefacts: the split file is per stain, the session folder is
# time-stamped (month_day_hour_minute) and is created immediately on execution.
TRAIN_DIR = WORKSPACE_DIR / "training"
SPLIT_PATH = TRAIN_DIR / f"splits_{STAIN}.dat"
RUN_OUTPUT_DIR = TRAIN_DIR / f"session_{STAIN}_{datetime.now().strftime('%m_%d_%H_%M')}"
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR = RUN_OUTPUT_DIR / "model"

# Extract features: 20 + 4 minutes

In [None]:
# Read the annotation CSV, keep the usable regions of the selected stain and
# build, for every kept region, the three file paths used by the
# intensity-feature extraction step.
df = pd.read_csv(ANNPATH)

# .copy() makes the stain subset an independent frame, so adding the 'area'
# column below does not write into a slice of the frame returned by read_csv
# (which triggers pandas' SettingWithCopyWarning).
df = df[df['stain'] == STAIN].copy()
df['area'] = (df['xe'] - df['xs']) * (df['ye'] - df['ys']) / 10000

# Keep regions whose scaled area is at least 150 and whose four box
# coordinates are all non-negative.
coords_valid = (df[['xs', 'ys', 'xe', 'ye']] >= 0).all(axis=1)
df = df[(df['area'] >= 150) & coords_valid]

# reset_index() moves the original row label into an 'index' column; that
# label is part of every file name built below.
df = df.reset_index()

# Row window to process; narrow it (e.g. end_index = 1) for a quick dry run.
start_index = 0
end_index = len(df.index)

datpaths = []
imgpaths = []
updatpaths = []

selected = df.iloc[start_index:end_index]
for df_index, patient_id, stain, tma_id in zip(
    selected['index'], selected['patient_id'], selected['stain'], selected['tma_id']
):
    unique_id = f"{patient_id}_{stain}_{df_index}"

    img_file_name = f"{FIDIR}/images/{patient_id}/{patient_id}_{stain}_{tma_id}_{OUTPUT_SIZE}_{df_index}.png"
    dat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/0.dat"
    updat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/{unique_id}.dat"

    datpaths.append(dat_file_name)
    imgpaths.append(img_file_name)
    updatpaths.append(updat_file_name)
    


In [None]:
# Sanity check: show the first three entries of `updatpaths` as the cell output.
updatpaths[:3]

In [None]:
# Author's runtime note: about 16 minutes.
# Optional clean start (disabled): delete and recreate FEATSDIR before extracting.
# shutil.rmtree(FEATSDIR)
# if not osp.exists(FEATSDIR):
#     os.makedirs(FEATSDIR)
# Project helper from src.featureextraction; presumably reads each
# datpaths/imgpaths pair and writes the matching updatpaths file — TODO confirm
# against the helper's implementation.
add_features_and_create_new_dicts(datpaths, imgpaths, updatpaths)

In [None]:
# 4 minutes
# Fit a StandardScaler incrementally over the 'intensity_feats' of every cell
# in every feature file, then store its mean/variance in FEATSCALERPATH.
gns = StandardScaler()
for featpath in tqdm(updatpaths):
    try:
        celldatadict = joblib.load(featpath)
        cellsfeats = np.array([v['intensity_feats'] for v in celldatadict.values()])
        gns.partial_fit(cellsfeats)
    except Exception as err:
        # Best effort: one unreadable or malformed file must not stop the fit,
        # but report why it was skipped. Catching Exception (not a bare
        # except) keeps Ctrl-C / kernel interrupts working.
        print(f"{featpath}: {type(err).__name__}: {err}")

# partial_fit never succeeded -> the scaler has no statistics to save.
if not hasattr(gns, "mean_"):
    raise RuntimeError("No feature file could be loaded; feature scaler was not fitted.")

np.savez(FEATSCALERPATH, mean=gns.mean_, var=gns.var_)

# Round-trip check (disabled):
# dd = np.load(FEATSCALERPATH)
# print(dd['mean'], gns.mean_)
# print(dd['var'], gns.var_)

In [None]:
# Show where the scaler statistics are stored (cell output).
FEATSCALERPATH

# Prepare graphs : 4 minutes

In [None]:
from src.graph_construct import get_pids_multilabels_for_key

# Clinical table -> per-patient class labels. Missing values are replaced by 0
# before the label helpers see the frame.
df = pd.read_csv(CLINPATH)
df = df.fillna(0)
# display(df.info())

# Single source of truth for the multilabel targets; the label columns read
# from the merged frame below are '<key>_class' for each key.
MULTILABEL_KEYS = ['OS', 'MYC IHC', 'BCL2 IHC', 'BCL6 IHC', 'CD10 IHC', 'MUM1 IHC']

if LABEL_TYPE == 'multilabel':
    df_labels = get_pids_multilabels_for_key(df, key_list=MULTILABEL_KEYS, nclasses=2)
else:
    df_labels = get_pids_labels_for_key(df, key='OS', nclasses=3)

# Graph construction
PATCH_SIZE = 300
SKEW_NOISE = 0.0001
MIN_CELLS_PER_PATCH = 10
CONNECTIVITY_DISTANCE = 500

# Collect every per-region feature file under FEATSDIR, skipping the raw
# '0.dat' inputs and the 'file_map.dat' bookkeeping files.
featpaths = np.sort(glob.glob(f'{FEATSDIR}/**/*.dat', recursive=True))
featpaths = [x for x in featpaths if ("/0.dat" not in x) and ("/file_map.dat" not in x)]
# File names start with '<patient_id>_', so the id is the first '_' field.
pids = [int(osp.basename(featpath).split('_')[0]) for featpath in featpaths]
df_featpaths = pd.DataFrame(zip(pids, featpaths), columns=['patient_id', 'featpath'])

# merge to find datapoints with graph data and labels
df_data = df_featpaths.merge(df_labels, on='patient_id')
# df_data = df_data[:12]

featpaths_data = df_data['featpath'].to_list()
if LABEL_TYPE == 'multilabel':
    # One list of six class values per sample, in MULTILABEL_KEYS order.
    label_columns = [f'{key}_class' for key in MULTILABEL_KEYS]
    labels_data = df_data[label_columns].to_numpy().tolist()
else:
    # Single-label mode: previously this cell still selected the six
    # multilabel columns, which the single-key label table is not expected to
    # contain. NOTE(review): 'OS_class' is taken from the earlier
    # commented-out variant of this line — confirm against
    # get_pids_labels_for_key.
    labels_data = df_data['OS_class'].to_list()
display(labels_data[:5])


In [None]:


# One output graph JSON per feature file, named after the feature file's stem.
outgraphpaths_data = [f"{GRAPHSDIR}/{osp.basename(featpath).split('.')[0]}.json" for featpath in featpaths_data]

# save labels (keyed by graph file name); make sure the target folder exists,
# otherwise the very first run fails while opening LABELSPATH.
labels_dict = {osp.basename(graphpath): label for graphpath, label in zip(outgraphpaths_data, labels_data)}
Path(LABELSPATH).parent.mkdir(parents=True, exist_ok=True)
with open(LABELSPATH, 'w') as f:
    json.dump(labels_dict, f)

# read normalizer stats from file and pass to fn
dd = np.load(FEATSCALERPATH)
cell_feat_norm_stats = (dd['mean'], dd['var'])

# create final graphs data
# Always rebuild from an empty folder. The existence check avoids the
# FileNotFoundError that an unconditional shutil.rmtree raises on a first run.
if GRAPHSDIR.exists():
    shutil.rmtree(GRAPHSDIR)
GRAPHSDIR.mkdir(parents=True)
create_graph_with_pooled_patch_nodes(
    featpaths_data,
    labels_data,
    outgraphpaths_data,
    PATCH_SIZE,
    cell_feat_norm_stats=cell_feat_norm_stats,
    MIN_CELLS_PER_PATCH=MIN_CELLS_PER_PATCH,
    CONNECTIVITY_DISTANCE=CONNECTIVITY_DISTANCE,
)
# Author's note: the global patch statistics below are not available when the
# graphs are built in parallel mode.
# print(np.mean(global_patch_stats), np.std(global_patch_stats))

# Create data splits

In [None]:
# Find every graph JSON below GRAPHSDIR; the file stems are the sample names.
wsi_paths = recur_find_ext(GRAPHSDIR, [".json"])
wsi_names = [Path(graph_path).stem for graph_path in wsi_paths]
assert len(wsi_paths) > 0, "No files found."  # noqa: S101

# Cell output: both counts and the first five sample names.
(len(wsi_paths), len(wsi_names), wsi_names[:5])

In [None]:
from src.dset import multilabel_stratified_split

# Fold count and split fractions: 20% test, the remaining 80% divided 70/30
# into train and validation.
NUM_FOLDS = 1
TEST_RATIO = 0.2
TRAIN_RATIO = 0.8 * 0.7
VALID_RATIO = 0.8 * 0.3

# Cached splits are deliberately not reused; they are rebuilt on every run.
# if SPLIT_PATH and os.path.exists(SPLIT_PATH):
#     splits = joblib.load(SPLIT_PATH)

# Samples are the graph names, targets come from the saved labels file
# (its keys are the graph file names, i.e. '<name>.json').
x = np.array(wsi_names)
with open(LABELSPATH, 'r') as f:
    labels_dict = json.load(f)
print(labels_dict)
y = np.array([labels_dict[wsi_name+'.json'] for wsi_name in wsi_names])
# y[np.where(y==-1)] = 0

# Pick the splitter that matches the label mode, then persist the result.
split_fn = multilabel_stratified_split if LABEL_TYPE == "multilabel" else stratified_split
splits = split_fn(x, y, TRAIN_RATIO, VALID_RATIO, TEST_RATIO, NUM_FOLDS)

joblib.dump(splits, SPLIT_PATH)


# Look at samples

In [None]:
# NOTE(review): these imports rebind rm_n_mkdir, recur_find_ext and
# stratified_split, which the first cell imported from src.train — confirm
# that both modules expose the same implementations.
from src.utils import load_json, rm_n_mkdir, mkdir, recur_find_ext
from src.dset import SlideGraphDataset, stratified_split, StratifiedSampler
import warnings
# Suppress all warnings from this point on.
warnings.filterwarnings("ignore")

from torch_geometric.loader import DataLoader

# No node-feature preprocessing; sample_data passes this as `preproc` to
# SlideGraphDataset.
nodes_preproc_func = None

def sample_data(  # noqa: C901, PLR0912, PLR0915
    dataset_dict: dict,
    GRAPH_DIR = None) -> tuple:
    """Build a dataset and loader for every subset and return the last pair.

    Every entry of ``dataset_dict`` is wrapped in a ``SlideGraphDataset``
    (mode ``"train"``, module-level ``nodes_preproc_func`` as ``preproc``) and
    a ``DataLoader`` with 6 workers and batch size 32. Only the subset named
    ``"train"`` is shuffled and drops its last incomplete batch.

    Parameters:
    - dataset_dict: mapping of subset name -> subset, passed unchanged to
      ``SlideGraphDataset``
    - GRAPH_DIR: forwarded to ``SlideGraphDataset`` as ``graph_dir``

    Returns:
    - (ds, loader) of the LAST subset in ``dataset_dict`` iteration order;
      the pairs built for earlier subsets are discarded.

    Raises:
    - ValueError: if ``dataset_dict`` is empty (previously this surfaced as an
      UnboundLocalError at the return statement).
    """
    if not dataset_dict:
        raise ValueError("dataset_dict must contain at least one subset.")

    _loader_kwargs = {
        "num_workers": 6,
        "batch_size": 32,
    }

    for subset_name, subset in dataset_dict.items():
        ds = SlideGraphDataset(subset, mode="train", preproc=nodes_preproc_func, graph_dir=GRAPH_DIR)

        # Hook for a custom sampler (e.g. the imported StratifiedSampler);
        # currently always None, so plain batching is used.
        batch_sampler = None

        # Shuffling/dropping applies to the training subset only, and only
        # when no custom batch sampler controls the batching.
        plain_train = subset_name == "train" and batch_sampler is None

        loader = DataLoader(
            ds,
            batch_sampler=batch_sampler,
            drop_last=plain_train,
            shuffle=plain_train,
            **_loader_kwargs,
        )

    return ds, loader


# Subset name expected by sample_data -> part of the split that feeds it.
subset_sources = {
    "train": "train",
    "infer-train": "train",
    "infer-valid-A": "valid",
    "infer-valid-B": "test",
}

# After the loop, `ds` and `loader` belong to the last split processed.
for split_idx, split in enumerate(splits):
    new_split = {name: split[part] for name, part in subset_sources.items()}
    ds, loader = sample_data(new_split, GRAPHSDIR)

# Show one dataset item ...
sample = ds[3]
display(sample)

# ... and only the first batch produced by the loader.
for _step, batch_data in enumerate(loader):
    print(batch_data)
    break

In [None]:
# from src.model import SlideGraphArch

# arch_kwargs = {
#     "dim_features": NUM_NODE_FEATURES,
#     "dim_target": 3,
#     "layers": [64, 32, 32],
#     "dropout": 0.3,
#     "pooling": "mean",
#     "conv": "EdgeConv",
#     "aggr": "max",
#     "CLASSIFICATION_TYPE" : LABEL_TYPE
# }

# model = SlideGraphArch(**arch_kwargs)

# for _step, batch_data in enumerate(tqdm(loader, disable=True)):
#     device = select_device(on_gpu=False)
#     wsi_graphs = batch_data["graph"].to(device)
#     model = model.to(device)

#     # Data type conversion
#     wsi_graphs.x = wsi_graphs.x.type(torch.float32)

#     # Inference mode
#     model.eval()
#     # Do not compute the gradient (not training)
#     with torch.inference_mode():
#         wsi_output, _ = model(wsi_graphs)
#         # print("wsi_output" ,wsi_output)
#     # print("xxx")
#     break
    
# # pred = torch.nn.functional.softmax(wsi_output.type(torch.float32))
# # gt = batch_data['label'].type(torch.float32)
# # display(pred.shape , gt.shape)
# # # criterion = torch.nn.BCELoss()
# # # criterion = torch.nn.()

# # loss = torch.nn.functional.cross_entropy(pred, gt)
# # loss

# Train model

In [None]:
# Node features are fed to the model unchanged. A scaler-based preprocessing
# function (node_scaler.transform) would have to be defined here after such a
# scaler has been trained or loaded.
nodes_preproc_func = None

# Reload the splits from disk so this cell does not depend on the split cell
# having been run in the same session.
splits = joblib.load(SPLIT_PATH)
loader_kwargs = {"num_workers": 6, "batch_size": 32}

# dim_target handed to the model: 3 when LABEL_TYPE is "OS", 6 otherwise.
NCLASSES = 3 if LABEL_TYPE == "OS" else 6

# A deeper variant with layers [32, 32, 16, 8] and no CLASSIFICATION_TYPE
# entry was used earlier.
arch_kwargs = {
    "dim_features": NUM_NODE_FEATURES,
    "dim_target": NCLASSES,
    "layers": [64, 32, 32],
    "dropout": 0.3,
    "pooling": "mean",
    "conv": "EdgeConv",
    "aggr": "max",
    "CLASSIFICATION_TYPE": LABEL_TYPE,
}

optim_kwargs = {"lr": 5.0e-3, "weight_decay": 1.0e-4}

NUM_EPOCHS = 30

# Subset name expected by run_once -> part of the split that feeds it.
subset_sources = {
    "train": "train",
    "infer-train": "train",
    "infer-valid-A": "valid",
    "infer-valid-B": "test",
}

MODEL_DIR = Path(MODEL_DIR)
# One training run per split; each gets a fresh, zero-padded sub-folder.
for split_idx, split in enumerate(splits):
    new_split = {name: split[part] for name, part in subset_sources.items()}
    split_save_dir = MODEL_DIR / f"{split_idx:02d}/"
    rm_n_mkdir(split_save_dir)
    reset_logging(split_save_dir)
    run_once(
        new_split,
        NUM_EPOCHS,
        save_dir=split_save_dir,
        arch_kwargs=arch_kwargs,
        loader_kwargs=loader_kwargs,
        optim_kwargs=optim_kwargs,
        on_gpu=ON_GPU,
        GRAPH_DIR=GRAPHSDIR,
        LABEL_TYPE=LABEL_TYPE,
    )

# Accuracy metrics

In [10]:
# Hard-coded example: 6 samples x 6 labels of prediction scores.
# NOTE(review): the values include negatives, so they look like raw model
# outputs rather than probabilities — confirm before choosing a threshold.
pred = [[1.46648765e+00,-8.77613187e-01,-2.51502216e-01,-1.10481453e+00
,6.64071560e-01,-5.99045038e-01]
,[-1.80002856e+00,8.66671383e-01,-6.64208949e-01,1.20571077e+00
,-2.28768730e+00,1.23205519e+00]
,[-8.38510573e-01,2.11250722e-01,4.82147932e-03,5.65632761e-01
,-1.94593978e+00,6.11969411e-01]
,[1.70954013e+00,-7.06104040e-01,-6.74402267e-02,-8.06845903e-01
,1.92917514e+00,-1.58036017e+00]
,[1.23513281e+00,-6.98642850e-01,-7.90928781e-01,-6.50385141e-01
,9.64580655e-01,-5.73552132e-01]
,[-9.15723860e-01,3.66611242e-01,3.33443969e-01,4.22663540e-01
,-7.69037247e-01,1.99665681e-01]]

# Matching binary ground-truth matrix (same 6 x 6 layout as `pred`).
gt = [[1.,1.,0.,0.,0.,0.]
,[0.,0.,0.,0.,0.,1.]
,[1.,1.,0.,1.,1.,0.]
,[1.,0.,0.,0.,0.,0.]
,[0.,1.,1.,0.,0.,1.]
,[1.,1.,1.,1.,1.,0.]]


def multilabel_accuracy(y_pred, y_true, threshold=0.5):
    """
    Calculate per-label and mean accuracy for multilabel predictions.

    Scores are binarised as ``score > threshold`` (a score exactly equal to
    the threshold becomes 0) and compared element-wise with ``y_true``. The
    inputs are not modified.

    Parameters:
    - y_pred: 2D array-like of prediction scores or binary labels,
      one row per sample and one column per label
    - y_true: 2D array-like binary indicator matrix with the same shape
    - threshold: cut-off applied to ``y_pred`` (default 0.5, the previous
      hard-coded value)

    Returns:
    - overall_accuracy: mean of the rounded per-label accuracies, rounded to
      4 decimals. NOTE: this is NOT subset (exact-match) accuracy.
    - label_accuracies: array with the accuracy of each label column,
      rounded to 4 decimals

    Raises:
    - ValueError: if ``y_true`` and ``y_pred`` have different shapes
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_pred)

    # Validate before doing any work on the inputs.
    if y_true.shape != y_scores.shape:
        raise ValueError("Shapes of y_true and y_pred must be the same.")

    # A single comparison binarises every score, including values that are
    # exactly equal to the threshold (previously left untouched and therefore
    # always counted as wrong).
    y_binary = (y_scores > threshold).astype(int)

    label_accuracies = np.mean(y_true == y_binary, axis=0).round(4)
    overall_accuracy = np.mean(label_accuracies).round(4)

    return overall_accuracy, label_accuracies

# Example run on `pred` / `gt`; the cell output is (overall_accuracy, label_accuracies).
multilabel_accuracy(pred, gt)

(0.4722, array([0.5   , 0.1667, 0.6667, 0.6667, 0.1667, 0.6667]))