# Constants


In [25]:

import pandas as pd
import os
import shutil
from os import path as osp

import torch
from tiatoolbox.utils.misc import select_device
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import joblib
import json
import glob

# from src.intensity import add_features_and_create_new_dicts

from src.featureextraction import get_cell_features, add_features_and_create_new_dicts
from src.train import stratified_split, recur_find_ext, run_once, rm_n_mkdir ,reset_logging
from src.graph_construct import create_graph_with_pooled_patch_nodes, get_pids_labels_for_key


# Set ON_GPU = True to request a GPU device from tiatoolbox's select_device.
ON_GPU = False
device = select_device(on_gpu=ON_GPU)

# Seed every RNG the pipeline may touch so splits and training are repeatable.
SEED = 5
random.seed(SEED)
# `rng` is a separate Generator object; the legacy global NumPy RNG used by
# np.random.* calls (e.g. inside helper modules) is an independent stream and
# must be seeded explicitly, otherwise it stays non-deterministic.
np.random.seed(SEED)
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


# Project root. Override with the DLBCL_MORPH_BASEDIR environment variable to
# run on another machine without editing the notebook.
BASEDIR = os.environ.get(
    'DLBCL_MORPH_BASEDIR', '/home/amrit/data/proj_data/MLG_project/DLBCL-Morph'
)

# Stain to process: 'MYC', 'BCL2' or 'HE'.
STAIN = 'MYC'

# Inputs and intermediate outputs.
FIDIR = f'{BASEDIR}/outputs'
CLINPATH = f'{BASEDIR}/clinical_data_cleaned.csv'
ANNPATH = f'{BASEDIR}/annotations_clean.csv'
FEATSDIR = f'{FIDIR}/files/{STAIN}'
FEATSCALERPATH = f"{FEATSDIR}/0_feat_scaler.npz"
# Image tiling; OUTPUT_SIZE is part of the image file names.
# NOTE: the graph-preparation cell later reassigns PATCH_SIZE to 300.
PATCH_SIZE = 224
OUTPUT_SIZE = PATCH_SIZE * 8

WORKSPACE_DIR = Path(BASEDIR)

# Graph construction
SKEW_NOISE = 0.0001
MIN_CELLS_PER_PATCH = 10
CONNECTIVITY_DISTANCE = 500

# 'multilabel' (several clinical targets) or 'OS' (single OS target).
LABEL_TYPE = 'multilabel'

GRAPHSDIR = Path(f'{BASEDIR}/graphs/{STAIN}')
LABELSPATH = f'{BASEDIR}/graphs/{STAIN}_labels.json'

# Training (NUM_EPOCHS and NCLASSES are reassigned in the training cell).
NUM_EPOCHS = 100
NUM_NODE_FEATURES = 128
NCLASSES = 3

TRAIN_DIR = WORKSPACE_DIR / "training"
SPLIT_PATH = TRAIN_DIR / f"splits_{STAIN}.dat"


# Extract features (runtime ~20 + ~4 minutes)

In [None]:
# Read the annotation CSV and keep valid crops for the selected stain.
df = pd.read_csv(ANNPATH)

# .copy(): the column assignment below must write to a real frame, not to a
# view of the filtered slice (SettingWithCopy / chained-assignment pitfall).
df = df[df['stain'] == STAIN].copy()
# Box area from the annotation extents, divided by 1e4.
df['area'] = (df['xe'] - df['xs']) * (df['ye'] - df['ys']) / 10000
# Drop small crops and any box with a negative coordinate.
keep = (
    (df['area'] >= 150)
    & (df['xs'] >= 0)
    & (df['ys'] >= 0)
    & (df['xe'] >= 0)
    & (df['ye'] >= 0)
)
# reset_index() keeps the original row label in an 'index' column; it is
# used below as the per-annotation id in the file names.
df = df[keep].reset_index()

# For each kept annotation, build:
#   imgpaths   - the crop image,
#   datpaths   - the base cell dict (0.dat),
#   updatpaths - the output cell dict that will carry the added features.
datpaths = []
imgpaths = []
updatpaths = []

for df_index, patient_id, stain, tma_id in zip(
    df['index'], df['patient_id'], df['stain'], df['tma_id']
):
    unique_id = f"{patient_id}_{stain}_{df_index}"

    img_file_name = f"{FIDIR}/images/{patient_id}/{patient_id}_{stain}_{tma_id}_{OUTPUT_SIZE}_{df_index}.png"
    dat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/0.dat"
    updat_file_name = f"{FIDIR}/files/{stain}/{patient_id}/{df_index}/{unique_id}.dat"

    datpaths.append(dat_file_name)
    imgpaths.append(img_file_name)
    updatpaths.append(updat_file_name)
    


In [None]:
updatpaths[0:3]  # peek at the first three output feature-dict paths

In [None]:
# Runtime: ~16 minutes.
# Optional clean slate before re-extracting (currently disabled):
#     shutil.rmtree(FEATSDIR)
#     os.makedirs(FEATSDIR, exist_ok=True)
add_features_and_create_new_dicts(
    datpaths,
    imgpaths,
    updatpaths,
)

In [None]:
# Runtime: ~4 minutes.
# Fit a StandardScaler incrementally over every cell's 'intensity_feats'
# (one file at a time, so the full matrix never has to be in memory) and save
# its mean/var; the graph-building cell loads them as cell_feat_norm_stats.
gns = StandardScaler()
failed_paths = []
for featpath in tqdm(updatpaths):
    try:
        celldatadict = joblib.load(featpath)
        cellsfeats = np.array([cell['intensity_feats'] for cell in celldatadict.values()])
        gns.partial_fit(cellsfeats)
    except Exception as err:  # best-effort: skip unreadable/empty files, but report why
        failed_paths.append(featpath)
        print(f"skipped {featpath}: {err!r}")

if failed_paths:
    print(f"{len(failed_paths)}/{len(updatpaths)} feature files were skipped")
# partial_fit never succeeded -> mean_/var_ do not exist; fail with a clear message.
if not hasattr(gns, "mean_"):
    raise RuntimeError("No feature file could be loaded; the scaler was never fitted.")

np.savez(FEATSCALERPATH, mean=gns.mean_, var=gns.var_)

In [None]:
# Echo where the scaler statistics were written.
FEATSCALERPATH

# Prepare graphs (runtime ~4 minutes)

In [27]:
from src.graph_construct import get_pids_multilabels_for_key

# Clinical table; missing values are replaced by 0 before deriving labels.
df = pd.read_csv(CLINPATH).fillna(0)

def get_pids_multilabels_for_key(df, key_list, nclasses=3, idkey='patient_id'):
    """Derive one ``<key>_class`` label column per clinical key in *key_list*.

    Rules per key:
    - ``"OS"``: binned into *nclasses* quantile classes. A row gets class ``i``
      when its value exceeds the ``i / nclasses`` quantile; rows at or below the
      minimum stay in class 0.
    - ``"MYC IHC"``, ``"BCL2 IHC"``, ``"BCL6 IHC"``: 1 when the value is >= 40,
      otherwise 0 (NaN also maps to 0).
    - any other key: copied unchanged.

    Args:
        df: Table containing *idkey* and every column in *key_list*.
        key_list: Sequence of clinical column names to turn into labels.
        nclasses: Number of quantile bins used for ``"OS"``.
        idkey: Name of the patient-id column to keep.

    Returns:
        A new DataFrame with *idkey*, the raw *key_list* columns and the derived
        ``<key>_class`` columns (in *key_list* order). *df* is not modified.
    """
    # Explicit copy so adding columns never writes through a view of `df`.
    out = df[[idkey] + list(key_list)].copy()

    for key in key_list:
        ckey = key + '_class'
        if key == "OS":
            out[ckey] = 0
            # Lower edges of the quantile bins: 0, 1/n, ..., (n-1)/n.
            separators = np.linspace(0, 1, nclasses + 1)[:-1]
            for isep, sep in enumerate(separators):
                # Higher bins overwrite lower ones; .loc assigns in place.
                out.loc[out[key] > out[key].quantile(sep), ckey] = isep
        elif key in ('MYC IHC', 'BCL2 IHC', 'BCL6 IHC'):
            # Vectorised threshold instead of chained `df[c][mask] = v`, which
            # silently does nothing under pandas copy-on-write.
            out[ckey] = (out[key] >= 40).astype(int)
        else:
            out[ckey] = out[key]
    return out

# Clinical columns used as targets in multilabel mode; their order defines the
# order of each sample's label vector.
MULTILABEL_KEYS = ['OS', 'MYC IHC', 'BCL2 IHC', 'BCL6 IHC', 'CD10 IHC', 'MUM1 IHC']

if LABEL_TYPE == 'multilabel':
    df_labels = get_pids_multilabels_for_key(df, key_list=MULTILABEL_KEYS, nclasses=2)
    label_cols = [f'{key}_class' for key in MULTILABEL_KEYS]
else:
    df_labels = get_pids_labels_for_key(df, key='OS', nclasses=3)
    # NOTE(review): assumes get_pids_labels_for_key adds an 'OS_class' column — confirm.
    label_cols = 'OS_class'

# Graph construction parameters. PATCH_SIZE here (300) overrides the 224 used
# for image tiling in the constants cell.
PATCH_SIZE = 300
SKEW_NOISE = 0.0001
MIN_CELLS_PER_PATCH = 10
CONNECTIVITY_DISTANCE = 500

# Feature dicts written by the extraction step, named <patient_id>_<stain>_<index>.dat.
# Skip the base per-crop dicts (0.dat) and file_map.dat. Matching on the file
# name rather than a '/'-prefixed substring also works with Windows separators.
featpaths = sorted(glob.glob(f'{FEATSDIR}/**/*.dat', recursive=True))
featpaths = [p for p in featpaths if osp.basename(p) not in ('0.dat', 'file_map.dat')]
pids = [int(osp.basename(p).split('_')[0]) for p in featpaths]
df_featpaths = pd.DataFrame(zip(pids, featpaths), columns=['patient_id', 'featpath'])

# Inner merge: keep only samples that have both a feature file and labels.
df_data = df_featpaths.merge(df_labels, on='patient_id')

featpaths_data = df_data['featpath'].to_list()
labels_data = df_data[label_cols].to_numpy().tolist()
display(labels_data[:5])


[[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
 [1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
 [1.0, 0.0, 1.0, 0.0, 0.0, 0.0]]

In [28]:


# One graph JSON per feature file, named after the feature file's stem.
outgraphpaths_data = [
    f"{GRAPHSDIR}/{osp.basename(featpath).split('.')[0]}.json" for featpath in featpaths_data
]

# Start from an empty graph directory. Guard the rmtree: on a first run the
# directory does not exist yet and shutil.rmtree would raise FileNotFoundError.
if osp.exists(GRAPHSDIR):
    shutil.rmtree(GRAPHSDIR)
# Also creates the parent graphs/ directory that LABELSPATH is written into.
os.makedirs(GRAPHSDIR, exist_ok=True)

# Labels keyed by graph file name; read back when building the data splits.
labels_dict = {
    osp.basename(graphpath): label for graphpath, label in zip(outgraphpaths_data, labels_data)
}
with open(LABELSPATH, 'w') as f:
    json.dump(labels_dict, f)

# Node-feature normalisation stats saved by the scaler-fitting cell; the
# context manager closes the underlying .npz file.
with np.load(FEATSCALERPATH) as dd:
    cell_feat_norm_stats = (dd['mean'], dd['var'])

create_graph_with_pooled_patch_nodes(
    featpaths_data,
    labels_data,
    outgraphpaths_data,
    PATCH_SIZE,
    cell_feat_norm_stats=cell_feat_norm_stats,
    MIN_CELLS_PER_PATCH=MIN_CELLS_PER_PATCH,
    CONNECTIVITY_DISTANCE=CONNECTIVITY_DISTANCE,
)

  0%|          | 0/319 [00:00<?, ?it/s]

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()

100%|██████████| 319/319 [01:07<00:00,  4.74it/s]


# Create data splits

In [29]:
# Every graph JSON written above; its stem (e.g. '13901_MYC_532') is the sample id.
wsi_paths = recur_find_ext(GRAPHSDIR, [".json"])
wsi_names = list(map(lambda graph_path: Path(graph_path).stem, wsi_paths))
assert len(wsi_paths) > 0, "No files found."  # noqa: S101

(len(wsi_paths), len(wsi_names), wsi_names[:5])

(319,
 319,
 ['13901_MYC_532',
  '13901_MYC_557',
  '13902_MYC_513',
  '13902_MYC_514',
  '13903_MYC_535'])

In [30]:
from src.dset import multilabel_stratified_split

# One fold: 20% held out for test; the remaining 80% is split 70/30 into
# train/valid. Splits are always recomputed and overwrite SPLIT_PATH.
NUM_FOLDS = 1
TEST_RATIO = 0.2
TRAIN_RATIO = 0.8 * 0.7
VALID_RATIO = 0.8 * 0.3

with open(LABELSPATH, 'r') as f:
    labels_dict = json.load(f)
print(labels_dict)

# Samples are graph ids; targets are looked up under "<id>.json".
x = np.array(wsi_names)
y = np.array([labels_dict[f"{name}.json"] for name in wsi_names])

split_fn = multilabel_stratified_split if LABEL_TYPE == "multilabel" else stratified_split
splits = split_fn(x, y, TRAIN_RATIO, VALID_RATIO, TEST_RATIO, NUM_FOLDS)

joblib.dump(splits, SPLIT_PATH)


{'13901_MYC_532.json': [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], '13901_MYC_557.json': [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], '13902_MYC_513.json': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0], '13902_MYC_514.json': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0], '13903_MYC_535.json': [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], '13903_MYC_560.json': [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], '13904_MYC_539.json': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], '13904_MYC_550.json': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], '13908_MYC_527.json': [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], '13908_MYC_555.json': [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], '13911_MYC_520.json': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0], '13911_MYC_533.json': [1.0, 0.0, 0.0, 0.0, 0.0, 1.0], '13912_MYC_530.json': [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], '13913_MYC_536.json': [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], '13913_MYC_554.json': [0.0, 1.0, 0.0, 1.0, 0.0, 0.0], '13914_MYC_534.json': [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], '13914_MYC_556.json': [0.0, 0.0, 1.0, 0.0, 0.0, 1.0], '13915_MYC_516.json': [1.0, 0.0, 0.0, 1.0, 0.0, 1.0], '13915_MYC_547.json': [1.0,

['/home/amrit/data/proj_data/MLG_project/DLBCL-Morph/training/splits_MYC.dat']

# Look at samples

In [31]:
from src.utils import load_json, rm_n_mkdir, mkdir, recur_find_ext
from src.dset import SlideGraphDataset, stratified_split, StratifiedSampler
import warnings
# NOTE(review): this ignores *every* warning for the rest of the session
# (e.g. pandas SettingWithCopy / chained-assignment warnings from re-run
# cells). Consider narrowing it with `category=` / `module=`.
warnings.filterwarnings("ignore")

from torch_geometric.loader import DataLoader

# No per-node preprocessing; sample_data passes this to SlideGraphDataset as `preproc`.
nodes_preproc_func = None

def sample_data(
    dataset_dict: dict,
    GRAPH_DIR=None,
    subset_name=None,
) -> tuple:
    """Build a dataset and DataLoader for one subset of a split, for inspection.

    Args:
        dataset_dict: Mapping of subset name (e.g. ``"train"``,
            ``"infer-valid-B"``) to that subset's sample list, as accepted by
            ``SlideGraphDataset``.
        GRAPH_DIR: Directory holding the graph JSON files.
        subset_name: Subset to load. Defaults to the last key of *dataset_dict*.

    Returns:
        ``(dataset, loader)`` for the chosen subset. The loader shuffles and
        drops the last incomplete batch only for the ``"train"`` subset.

    Raises:
        ValueError: If *dataset_dict* is empty.
        KeyError: If *subset_name* is not a key of *dataset_dict*.
    """
    if not dataset_dict:
        raise ValueError("dataset_dict is empty; nothing to sample.")
    if subset_name is None:
        subset_name = list(dataset_dict)[-1]
    subset = dataset_dict[subset_name]

    # Only the requested subset is materialised; every subset is opened with
    # mode="train".
    ds = SlideGraphDataset(subset, mode="train", preproc=nodes_preproc_func, graph_dir=GRAPH_DIR)
    is_train = subset_name == "train"
    loader = DataLoader(
        ds,
        batch_sampler=None,
        drop_last=is_train,
        shuffle=is_train,
        num_workers=6,
        batch_size=32,
    )
    return ds, loader


# Build loaders for every split (NUM_FOLDS = 1, so just one); the objects from
# the last split are inspected below.
for split in splits:
    new_split = {
        "train": split["train"],
        "infer-train": split["train"],
        "infer-valid-A": split["valid"],
        "infer-valid-B": split["test"],
    }
    ds, loader = sample_data(new_split, GRAPHSDIR)

# One graph sample ...
sample = ds[3]
display(sample)

# ... and the first collated batch.
for batch_data in loader:
    print(batch_data)
    break

{'graph': Data(x=[95, 128], edge_index=[2, 506], y=[6], coordinates=[95, 2]),
 'label': tensor([1., 0., 1., 0., 0., 1.], dtype=torch.float64)}

{'graph': DataBatch(x=[2125, 128], edge_index=[2, 11016], y=[192], coordinates=[2125, 2], batch=[2125], ptr=[33]), 'label': tensor([[1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [1., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [1., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 1., 0.],
        [1., 0., 0., 0., 0., 1.],
        [1., 1., 1., 0., 0., 1.],
        [0., 0., 1., 0., 1., 1.],
        [0., 1., 1., 0., 0., 1.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 1., 0., 1.],
        [1., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.],
        [1., 0., 1., 0., 1., 1.],
        [1., 0., 0., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1.],
        [0., 0., 1., 1., 0., 0.],
        [1., 1., 1., 0., 1., 0.],
        [0., 0., 0., 1., 1

In [None]:
# from src.model import SlideGraphArch

# arch_kwargs = {
#     "dim_features": NUM_NODE_FEATURES,
#     "dim_target": 3,
#     "layers": [64, 32, 32],
#     "dropout": 0.3,
#     "pooling": "mean",
#     "conv": "EdgeConv",
#     "aggr": "max",
#     "CLASSIFICATION_TYPE" : LABEL_TYPE
# }

# model = SlideGraphArch(**arch_kwargs)

# for _step, batch_data in enumerate(tqdm(loader, disable=True)):
#     device = select_device(on_gpu=False)
#     wsi_graphs = batch_data["graph"].to(device)
#     model = model.to(device)

#     # Data type conversion
#     wsi_graphs.x = wsi_graphs.x.type(torch.float32)

#     # Inference mode
#     model.eval()
#     # Do not compute the gradient (not training)
#     with torch.inference_mode():
#         wsi_output, _ = model(wsi_graphs)
#         # print("wsi_output" ,wsi_output)
#     # print("xxx")
#     break
    
# # pred = torch.nn.functional.softmax(wsi_output.type(torch.float32))
# # gt = batch_data['label'].type(torch.float32)
# # display(pred.shape , gt.shape)
# # # criterion = torch.nn.BCELoss()
# # # criterion = torch.nn.()

# # loss = torch.nn.functional.cross_entropy(pred, gt)
# # loss

# Train model

In [33]:
# No node-feature preprocessing. (A scaler-based preprocessing function would
# have to be defined after the scaler is trained/loaded.)
nodes_preproc_func = None


splits = joblib.load(SPLIT_PATH)
loader_kwargs = dict(num_workers=6, batch_size=32)

# Output width: 3 classes for OS, otherwise one output per multilabel target.
NCLASSES = 3 if LABEL_TYPE == "OS" else 6

conv, pooling, aggr = "EdgeConv", "mean", "max"
dropout = 0.1
layers = [64, 32, 32]

arch_kwargs = dict(
    dim_features=NUM_NODE_FEATURES,
    dim_target=NCLASSES,
    layers=layers,
    dropout=dropout,
    pooling=pooling,
    conv=conv,
    aggr=aggr,
    CLASSIFICATION_TYPE=LABEL_TYPE,
)

# One directory per run, tagged with the architecture and a minute-resolution timestamp.
run_tag = f"session_{STAIN}_{conv}_{pooling}_{aggr}_{str(dropout)}_{str(layers)}_{datetime.now().strftime('%m_%d_%H_%M')}"
RUN_OUTPUT_DIR = TRAIN_DIR / run_tag
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR = RUN_OUTPUT_DIR / "model"

optim_kwargs = dict(lr=5.0e-4, weight_decay=1.0e-4)

NUM_EPOCHS = 200  # overrides the 100 set in the constants cell

for split_idx, split in enumerate(splits):
    new_split = {
        "train": split["train"],
        "infer-train": split["train"],
        "infer-valid-A": split["valid"],
        "infer-valid-B": split["test"],
    }
    # Fresh output/log directory per fold, e.g. <MODEL_DIR>/00
    split_save_dir = Path(MODEL_DIR) / f"{split_idx:02d}/"
    rm_n_mkdir(split_save_dir)
    reset_logging(split_save_dir)
    run_once(
        new_split,
        NUM_EPOCHS,
        save_dir=split_save_dir,
        arch_kwargs=arch_kwargs,
        loader_kwargs=loader_kwargs,
        optim_kwargs=optim_kwargs,
        on_gpu=ON_GPU,
        GRAPH_DIR=GRAPHSDIR,
        LABEL_TYPE=LABEL_TYPE,
    )

|2023-11-22|18:42:51.384| [INFO] EPOCH: 000


SlideGraphArch(
  (first_h): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (nns): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=128, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (convs): ModuleList(
    (0): EdgeConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    ))
    (1): EdgeConv(nn=Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05,

  0%|          | 0/5 [00:00<?, ?it/s]

 20%|██        | 1/5 [00:00<00:00, 15.45it/s]
|2023-11-22|18:42:51.451| [INFO] train-EMA-loss: 0.989561
|2023-11-22|18:42:51.752| [INFO] infer-train-accuracy: 0.589900
|2023-11-22|18:42:52.019| [INFO] infer-valid-A-accuracy: 0.623400
|2023-11-22|18:42:52.271| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|18:42:52.512| [INFO] EPOCH: 001


best_score {'infer-train-accuracy': 0, 'infer-valid-A-accuracy': 0, 'infer-valid-B-accuracy': 0}


 20%|██        | 1/5 [00:00<00:00, 14.75it/s]
|2023-11-22|18:42:52.583| [INFO] train-EMA-loss: 0.993939
|2023-11-22|18:42:52.917| [INFO] infer-train-accuracy: 0.589900
|2023-11-22|18:42:53.169| [INFO] infer-valid-A-accuracy: 0.623400
|2023-11-22|18:42:53.419| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|18:42:53.668| [INFO] EPOCH: 002
 20%|██        | 1/5 [00:00<00:00, 15.59it/s]
|2023-11-22|18:42:53.734| [INFO] train-EMA-loss: 0.964859
|2023-11-22|18:42:54.055| [INFO] infer-train-accuracy: 0.589900
|2023-11-22|18:42:54.312| [INFO] infer-valid-A-accuracy: 0.623400
|2023-11-22|18:42:54.564| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|18:42:54.825| [INFO] EPOCH: 003
 20%|██        | 1/5 [00:00<00:00, 14.22it/s]
|2023-11-22|18:42:54.898| [INFO] train-EMA-loss: 0.957918
|2023-11-22|18:42:55.238| [INFO] infer-train-accuracy: 0.589900
|2023-11-22|18:42:55.493| [INFO] infer-valid-A-accuracy: 0.623400
|2023-11-22|18:42:55.741| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|

best_score {'infer-train-accuracy': 0.5908, 'infer-valid-A-accuracy': 0.6277, 'infer-valid-B-accuracy': 0.6172, 'infer-train-label_accuracies': [0.5, 0.8483, 0.4494, 0.6798, 0.6011, 0.4663], 'infer-train-best_epoch': 24, 'infer-valid-A-label_accuracies': [0.4935, 0.8442, 0.4675, 0.7792, 0.6623, 0.5195], 'infer-valid-A-best_epoch': 25, 'infer-valid-B-label_accuracies': [0.5, 0.875, 0.4844, 0.7031, 0.7344, 0.4062], 'infer-valid-B-best_epoch': 25}


 20%|██        | 1/5 [00:00<00:00, 16.09it/s]
|2023-11-22|18:43:21.783| [INFO] train-EMA-loss: 0.842860
|2023-11-22|18:43:22.106| [INFO] infer-train-accuracy: 0.589900
|2023-11-22|18:43:22.360| [INFO] infer-valid-A-accuracy: 0.627700
|2023-11-22|18:43:22.623| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|18:43:22.873| [INFO] EPOCH: 027
 20%|██        | 1/5 [00:00<00:00, 16.74it/s]
|2023-11-22|18:43:22.935| [INFO] train-EMA-loss: 0.836646
|2023-11-22|18:43:23.231| [INFO] infer-train-accuracy: 0.588000
|2023-11-22|18:43:23.572| [INFO] infer-valid-A-accuracy: 0.629900
|2023-11-22|18:43:23.820| [INFO] infer-valid-B-accuracy: 0.617200
|2023-11-22|18:43:24.079| [INFO] EPOCH: 028
 20%|██        | 1/5 [00:00<00:00, 16.47it/s]
|2023-11-22|18:43:24.142| [INFO] train-EMA-loss: 0.827120
|2023-11-22|18:43:24.448| [INFO] infer-train-accuracy: 0.589000
|2023-11-22|18:43:24.706| [INFO] infer-valid-A-accuracy: 0.629900
|2023-11-22|18:43:24.960| [INFO] infer-valid-B-accuracy: 0.622400
|2023-11-22|

best_score {'infer-train-accuracy': 0.6301, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.6458, 'infer-train-label_accuracies': [0.573, 0.8483, 0.5337, 0.6742, 0.6404, 0.5112], 'infer-train-best_epoch': 49, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.875, 0.5781, 0.7031, 0.7188, 0.5625], 'infer-valid-B-best_epoch': 49}


 20%|██        | 1/5 [00:00<00:00, 16.36it/s]
|2023-11-22|18:43:51.435| [INFO] train-EMA-loss: 0.754957
|2023-11-22|18:43:51.756| [INFO] infer-train-accuracy: 0.629200
|2023-11-22|18:43:52.026| [INFO] infer-valid-A-accuracy: 0.666600
|2023-11-22|18:43:52.285| [INFO] infer-valid-B-accuracy: 0.643200
|2023-11-22|18:43:52.553| [INFO] EPOCH: 052
 20%|██        | 1/5 [00:00<00:00, 16.74it/s]
|2023-11-22|18:43:52.614| [INFO] train-EMA-loss: 0.741809
|2023-11-22|18:43:52.934| [INFO] infer-train-accuracy: 0.625500
|2023-11-22|18:43:53.194| [INFO] infer-valid-A-accuracy: 0.662300
|2023-11-22|18:43:53.466| [INFO] infer-valid-B-accuracy: 0.643200
|2023-11-22|18:43:53.736| [INFO] EPOCH: 053
 20%|██        | 1/5 [00:00<00:00, 16.45it/s]
|2023-11-22|18:43:53.799| [INFO] train-EMA-loss: 0.753003
|2023-11-22|18:43:54.110| [INFO] infer-train-accuracy: 0.628300
|2023-11-22|18:43:54.375| [INFO] infer-valid-A-accuracy: 0.658000
|2023-11-22|18:43:54.665| [INFO] infer-valid-B-accuracy: 0.638000
|2023-11-22|

best_score {'infer-train-accuracy': 0.6573, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.651, 'infer-train-label_accuracies': [0.618, 0.8483, 0.5843, 0.7135, 0.6404, 0.5393], 'infer-train-best_epoch': 75, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.875, 0.5938, 0.7344, 0.6562, 0.6094], 'infer-valid-B-best_epoch': 62}


 20%|██        | 1/5 [00:00<00:00, 16.91it/s]
|2023-11-22|18:44:21.457| [INFO] train-EMA-loss: 0.677312
|2023-11-22|18:44:21.767| [INFO] infer-train-accuracy: 0.655400
|2023-11-22|18:44:22.031| [INFO] infer-valid-A-accuracy: 0.642900
|2023-11-22|18:44:22.308| [INFO] infer-valid-B-accuracy: 0.625000
|2023-11-22|18:44:22.587| [INFO] EPOCH: 077
 20%|██        | 1/5 [00:00<00:00, 16.25it/s]
|2023-11-22|18:44:22.651| [INFO] train-EMA-loss: 0.678825
|2023-11-22|18:44:22.953| [INFO] infer-train-accuracy: 0.660100
|2023-11-22|18:44:23.223| [INFO] infer-valid-A-accuracy: 0.642900
|2023-11-22|18:44:23.487| [INFO] infer-valid-B-accuracy: 0.630200
|2023-11-22|18:44:23.766| [INFO] EPOCH: 078
 20%|██        | 1/5 [00:00<00:00, 16.13it/s]
|2023-11-22|18:44:23.830| [INFO] train-EMA-loss: 0.671714
|2023-11-22|18:44:24.150| [INFO] infer-train-accuracy: 0.662900
|2023-11-22|18:44:24.462| [INFO] infer-valid-A-accuracy: 0.636400
|2023-11-22|18:44:24.732| [INFO] infer-valid-B-accuracy: 0.635400
|2023-11-22|

best_score {'infer-train-accuracy': 0.6742, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.651, 'infer-train-label_accuracies': [0.6236, 0.8596, 0.5955, 0.7247, 0.6629, 0.5787], 'infer-train-best_epoch': 99, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.875, 0.5938, 0.7344, 0.6562, 0.6094], 'infer-valid-B-best_epoch': 62}


 20%|██        | 1/5 [00:00<00:00, 16.94it/s]
|2023-11-22|18:44:51.952| [INFO] train-EMA-loss: 0.636070
|2023-11-22|18:44:52.276| [INFO] infer-train-accuracy: 0.675100
|2023-11-22|18:44:52.557| [INFO] infer-valid-A-accuracy: 0.660200
|2023-11-22|18:44:52.829| [INFO] infer-valid-B-accuracy: 0.643200
|2023-11-22|18:44:53.185| [INFO] EPOCH: 102
 20%|██        | 1/5 [00:00<00:00, 16.06it/s]
|2023-11-22|18:44:53.249| [INFO] train-EMA-loss: 0.633591
|2023-11-22|18:44:53.581| [INFO] infer-train-accuracy: 0.671300
|2023-11-22|18:44:53.868| [INFO] infer-valid-A-accuracy: 0.655900
|2023-11-22|18:44:54.132| [INFO] infer-valid-B-accuracy: 0.645800
|2023-11-22|18:44:54.428| [INFO] EPOCH: 103
 20%|██        | 1/5 [00:00<00:00, 16.28it/s]
|2023-11-22|18:44:54.492| [INFO] train-EMA-loss: 0.629223
|2023-11-22|18:44:54.809| [INFO] infer-train-accuracy: 0.668500
|2023-11-22|18:44:55.095| [INFO] infer-valid-A-accuracy: 0.655900
|2023-11-22|18:44:55.360| [INFO] infer-valid-B-accuracy: 0.643200
|2023-11-22|

best_score {'infer-train-accuracy': 0.6751, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.651, 'infer-train-label_accuracies': [0.6348, 0.8596, 0.5899, 0.7247, 0.6517, 0.5899], 'infer-train-best_epoch': 101, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.875, 0.5938, 0.7344, 0.6562, 0.6094], 'infer-valid-B-best_epoch': 62}


 20%|██        | 1/5 [00:00<00:00, 15.81it/s]
|2023-11-22|18:45:22.821| [INFO] train-EMA-loss: 0.606833
|2023-11-22|18:45:23.136| [INFO] infer-train-accuracy: 0.666700
|2023-11-22|18:45:23.399| [INFO] infer-valid-A-accuracy: 0.668800
|2023-11-22|18:45:23.656| [INFO] infer-valid-B-accuracy: 0.627600
|2023-11-22|18:45:23.956| [INFO] EPOCH: 127
 20%|██        | 1/5 [00:00<00:00, 16.14it/s]
|2023-11-22|18:45:24.020| [INFO] train-EMA-loss: 0.605965
|2023-11-22|18:45:24.343| [INFO] infer-train-accuracy: 0.666700
|2023-11-22|18:45:24.611| [INFO] infer-valid-A-accuracy: 0.662300
|2023-11-22|18:45:24.880| [INFO] infer-valid-B-accuracy: 0.627600
|2023-11-22|18:45:25.173| [INFO] EPOCH: 128
 20%|██        | 1/5 [00:00<00:00, 15.96it/s]
|2023-11-22|18:45:25.238| [INFO] train-EMA-loss: 0.602776
|2023-11-22|18:45:25.591| [INFO] infer-train-accuracy: 0.660100
|2023-11-22|18:45:25.862| [INFO] infer-valid-A-accuracy: 0.664500
|2023-11-22|18:45:26.120| [INFO] infer-valid-B-accuracy: 0.630200
|2023-11-22|

best_score {'infer-train-accuracy': 0.6751, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.651, 'infer-train-label_accuracies': [0.6348, 0.8596, 0.5899, 0.7247, 0.6517, 0.5899], 'infer-train-best_epoch': 101, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.875, 0.5938, 0.7344, 0.6562, 0.6094], 'infer-valid-B-best_epoch': 62}


 20%|██        | 1/5 [00:00<00:00, 17.61it/s]
|2023-11-22|18:45:53.420| [INFO] train-EMA-loss: 0.591368
|2023-11-22|18:45:53.734| [INFO] infer-train-accuracy: 0.667600
|2023-11-22|18:45:54.000| [INFO] infer-valid-A-accuracy: 0.664500
|2023-11-22|18:45:54.259| [INFO] infer-valid-B-accuracy: 0.638000
|2023-11-22|18:45:54.553| [INFO] EPOCH: 152
 20%|██        | 1/5 [00:00<00:00, 17.20it/s]
|2023-11-22|18:45:54.613| [INFO] train-EMA-loss: 0.591505
|2023-11-22|18:45:54.935| [INFO] infer-train-accuracy: 0.667600
|2023-11-22|18:45:55.197| [INFO] infer-valid-A-accuracy: 0.666700
|2023-11-22|18:45:55.452| [INFO] infer-valid-B-accuracy: 0.635400
|2023-11-22|18:45:55.748| [INFO] EPOCH: 153
 20%|██        | 1/5 [00:00<00:00, 16.48it/s]
|2023-11-22|18:45:55.810| [INFO] train-EMA-loss: 0.595017
|2023-11-22|18:45:56.126| [INFO] infer-train-accuracy: 0.666700
|2023-11-22|18:45:56.384| [INFO] infer-valid-A-accuracy: 0.668900
|2023-11-22|18:45:56.638| [INFO] infer-valid-B-accuracy: 0.635400
|2023-11-22|

best_score {'infer-train-accuracy': 0.6751, 'infer-valid-A-accuracy': 0.6732, 'infer-valid-B-accuracy': 0.651, 'infer-train-label_accuracies': [0.6348, 0.8596, 0.5899, 0.7247, 0.6517, 0.5899], 'infer-train-best_epoch': 101, 'infer-valid-A-label_accuracies': [0.5844, 0.8442, 0.5714, 0.7922, 0.5974, 0.6494], 'infer-valid-A-best_epoch': 49, 'infer-valid-B-label_accuracies': [0.4375, 0.8906, 0.625, 0.7344, 0.6719, 0.5469], 'infer-valid-B-best_epoch': 171}


 20%|██        | 1/5 [00:00<00:00, 16.07it/s]
|2023-11-22|18:46:24.028| [INFO] train-EMA-loss: 0.588935
|2023-11-22|18:46:24.324| [INFO] infer-train-accuracy: 0.671400
|2023-11-22|18:46:24.581| [INFO] infer-valid-A-accuracy: 0.660200
|2023-11-22|18:46:24.843| [INFO] infer-valid-B-accuracy: 0.653600
|2023-11-22|18:46:25.160| [INFO] EPOCH: 177
 20%|██        | 1/5 [00:00<00:00, 15.57it/s]
|2023-11-22|18:46:25.227| [INFO] train-EMA-loss: 0.582080
|2023-11-22|18:46:25.539| [INFO] infer-train-accuracy: 0.665700
|2023-11-22|18:46:25.800| [INFO] infer-valid-A-accuracy: 0.662400
|2023-11-22|18:46:26.056| [INFO] infer-valid-B-accuracy: 0.656200
|2023-11-22|18:46:26.346| [INFO] EPOCH: 178
 20%|██        | 1/5 [00:00<00:00, 13.38it/s]
|2023-11-22|18:46:26.423| [INFO] train-EMA-loss: 0.579418
|2023-11-22|18:46:26.809| [INFO] infer-train-accuracy: 0.665700
|2023-11-22|18:46:27.063| [INFO] infer-valid-A-accuracy: 0.662400
|2023-11-22|18:46:27.313| [INFO] infer-valid-B-accuracy: 0.658900
|2023-11-22|

# Accuracy metrics

In [None]:
# Hand-copied sample: 6 graphs x 6 targets of model outputs and the matching
# ground-truth label vectors. The outputs contain negative values, so they are
# presumably logits rather than probabilities.
# NOTE(review): multilabel_accuracy thresholds at 0.5 by default — confirm
# that is intended for logits.
pred = [
    [1.46648765e+00, -8.77613187e-01, -2.51502216e-01, -1.10481453e+00, 6.64071560e-01, -5.99045038e-01],
    [-1.80002856e+00, 8.66671383e-01, -6.64208949e-01, 1.20571077e+00, -2.28768730e+00, 1.23205519e+00],
    [-8.38510573e-01, 2.11250722e-01, 4.82147932e-03, 5.65632761e-01, -1.94593978e+00, 6.11969411e-01],
    [1.70954013e+00, -7.06104040e-01, -6.74402267e-02, -8.06845903e-01, 1.92917514e+00, -1.58036017e+00],
    [1.23513281e+00, -6.98642850e-01, -7.90928781e-01, -6.50385141e-01, 9.64580655e-01, -5.73552132e-01],
    [-9.15723860e-01, 3.66611242e-01, 3.33443969e-01, 4.22663540e-01, -7.69037247e-01, 1.99665681e-01],
]

gt = [
    [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0, 1.0, 1.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
]

def multilabel_accuracy(y_pred, y_true, threshold=0.5):
    """Compute per-label and mean accuracy for multilabel predictions.

    Predictions are binarised as ``y_pred > threshold``, so a score exactly
    equal to the threshold counts as 0. The binarised predictions are then
    compared element-wise with *y_true*.

    Args:
        y_pred: 2D array-like of scores, shape (n_samples, n_labels).
        y_true: 2D array-like / binary indicator matrix of the same shape.
        threshold: Decision threshold applied to *y_pred*. The default 0.5
            suits probabilities; for raw logits, 0.0 is the equivalent cut.

    Returns:
        ``(overall_accuracy, label_accuracies)``. ``label_accuracies`` is the
        fraction of samples predicted correctly for each label, rounded to 4
        decimals. ``overall_accuracy`` is the mean of those per-label values
        (Hamming accuracy), rounded to 4 decimals; it is NOT the exact-match
        (subset) accuracy.

    Raises:
        ValueError: If *y_pred* and *y_true* have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes of y_true and y_pred must be the same.")

    # A single comparison maps every score to 0/1, including scores equal to
    # the threshold, and leaves the caller's array untouched.
    y_pred_bin = (y_pred > threshold).astype(int)

    label_accuracies = np.mean(y_true == y_pred_bin, axis=0).round(4)
    overall_accuracy = np.mean(label_accuracies).round(4)

    return overall_accuracy, label_accuracies

# Accuracy of the hand-copied sample above.
multilabel_accuracy(y_pred=pred, y_true=gt)