In [None]:
from torch.utils.data import DataLoader
import torch
from FINAL.Configs.config import *
from FINAL.Models.PJ.models import Model, Model_LAHST
from torch import nn, optim
from tqdm import tqdm
from FINAL.Data.Multi_Modal_Dataset import Multi_Modal_Dataset, tolerant_collate
import json
from collections import defaultdict
import os
import gc
from Multi_Modal_Continuous.utils import get_tokenizer
from Multi_Modal_Continuous.metrics import MyMetrics
from FINAL.loop_pj import inference
import numpy as np
import pandas as pd
import wandb
from FINAL.loop_pj import select_sequence
import torch.cuda.amp as amp
import time as timing
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

In [None]:
def plot_gate_importance(gates, type_gates, note_idx=None, top_k=10, figsize=(12, 4), save_path=None):
    """
    Plot gating values for events enriching a particular note.

    Every event is drawn as a bar. The ``top_k`` events with the largest
    gating values are colored by their event type and annotated with their
    event index; all other bars use a neutral color.

    Parameters:
    - gates: 1D torch.Tensor, list or np.array of gating values.
    - type_gates: Sequence giving the event type of each gate, indexed like
      ``gates``. 'Lab', 'Drug' and 'Microbio' have dedicated colors; any
      other type is drawn in gray.
    - note_idx: Optional, index of the note (to annotate title).
    - top_k: Number of top gating values to annotate.
    - figsize: Size of the plot.
    - save_path: Path to save figure if needed.
    """
    # Tensors must be detached and moved to the CPU first; lists and arrays
    # are converted directly.
    if hasattr(gates, "detach"):
        gates = gates.detach().to("cpu")
    gates = np.array(gates)
    event_ids = np.arange(len(gates))

    # Define colors
    color_map = {'Lab': 'green', 'Drug': 'orange', 'Microbio': 'red'}
    default_color = 'skyblue'
    fallback_color = 'gray'
    bar_colors = [default_color] * len(gates)

    # Identify top-k events (largest first) and assign their color.
    # Reversing before slicing keeps top_k=0 empty; a `[-0:]` slice would
    # select every event instead.
    top_indices = np.argsort(gates)[::-1][:max(top_k, 0)]
    used_types = set()
    for i in top_indices:
        event_type = type_gates[i]
        used_types.add(event_type)
        bar_colors[i] = color_map.get(event_type, fallback_color)

    # Plot
    plt.figure(figsize=figsize)
    plt.bar(event_ids, gates, color=bar_colors, edgecolor='black')

    # Annotate top-k indices
    for i in top_indices:
        plt.text(i, gates[i] + 0.01, f'{i}', ha='center', va='bottom', fontsize=8, rotation=90)

    # Add legend for top-k event types; unknown types use the same fallback
    # color as their bars instead of raising KeyError.
    legend_handles = [
        Patch(color=color_map.get(etype, fallback_color), label=etype)
        for etype in sorted(used_types)
    ]
    if legend_handles:
        plt.legend(handles=legend_handles, title=f"Event Type (Top-{top_k})", loc='upper right')

    plt.xlabel("Event Index")
    plt.ylabel("Gating Value")
    title = f"Gating Values for Note {note_idx}" if note_idx is not None else "Gating Values per Event"
    plt.title(title)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()


def get_vector_at_index(list_of_lists, event_types, index):
    """Look up a row by its global position across consecutive groups of rows.

    ``list_of_lists`` is treated as one flat sequence made by concatenating
    its groups in order; ``index`` addresses a row of that flat sequence.

    Parameters:
    - list_of_lists: Sequence of groups; each group supports ``len`` and
      integer indexing.
    - event_types: Sequence parallel to ``list_of_lists`` holding the event
      types of each group.
    - index: Non-negative global row index.

    Returns:
    - Tuple ``(row, types)`` where ``types`` are the event types of the group
      containing the row, with every 'Text' entry removed.

    Raises:
    - IndexError: if ``index`` is negative or beyond the last row.
    """
    # A negative index would otherwise pass the range check on the first
    # group and silently return a row counted from that group's end.
    if index < 0:
        raise IndexError("Index out of range")

    current_start = 0
    for group_idx, sublist in enumerate(list_of_lists):
        num_rows = len(sublist)
        if index < current_start + num_rows:
            # Filter the types only for the group that is actually returned.
            non_text_types = [etype for etype in event_types[group_idx] if etype != 'Text']
            return sublist[index - current_start], non_text_types
        current_start += num_rows
    raise IndexError("Index out of range")

def set_size(width, fraction=0.8):
    """Compute figure dimensions that need no rescaling when included in LaTeX.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Share of that width the figure should take up

    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches; the height is the width
            multiplied by the golden ratio conjugate (~0.618)
    """
    # 72.27 TeX points make one inch.
    inches_per_pt = 1 / 72.27

    # Height-to-width ratio, see https://disq.us/p/2940ij3
    golden_ratio = (5**.5 - 1) / 2

    width_in = (width * fraction) * inches_per_pt
    height_in = width_in * golden_ratio
    return (width_in, height_in)
def return_attn_scores(lwan, encoding, all_tokens=True, cutoffs=None):
    """Run the label attention module and return its weights and label scores.

    One attention pass is computed per chunk position: row ``i`` of the batch
    may only attend to the tokens of chunks ``0..i``; everything later is
    masked out through ``key_padding_mask``.

    Parameters:
    - lwan: label attention module exposing ``seq_len``, ``all_tokens``,
      ``hidden_size``, ``num_labels``, ``device``, ``multiheadattn``,
      ``label_queries`` and ``label_weights``.
    - encoding: tensor whose first dimension is (number of chunks x tokens
      per chunk); presumably (Nc x T) x H as in the original notes.
    - all_tokens: accepted for interface compatibility; the function reads
      ``lwan.all_tokens`` instead.
    - cutoffs: accepted for interface compatibility; not used.

    Returns:
    - (attn_output_weights, score): the attention weights returned by
      ``lwan.multiheadattn`` and the per-label scores obtained by summing
      ``attn_output * label_weights`` over the hidden dimension.
    """
    tokens_per_chunk = lwan.seq_len
    if not lwan.all_tokens:
        tokens_per_chunk = 1  # only use the [CLS]-token representation
    n_chunks = int(encoding.shape[0] / tokens_per_chunk)

    # mask[i, j] is True (= ignore key j) unless j lies within the first
    # (i + 1) chunks, i.e. j < (i + 1) * tokens_per_chunk.
    key_positions = torch.arange(n_chunks * tokens_per_chunk)
    visible_until = torch.arange(1, n_chunks + 1) * tokens_per_chunk
    mask = (key_positions.unsqueeze(0) >= visible_until.unsqueeze(1)).to(device=lwan.device)

    # The same label queries and the same encoding are replicated once per
    # mask row, so each batch entry sees a different prefix of the chunks.
    batch = mask.shape[0]
    attn_output, attn_output_weights = lwan.multiheadattn.forward(
        query=lwan.label_queries.repeat(batch, 1, 1),
        key=encoding.repeat(batch, 1, 1),
        value=encoding.repeat(batch, 1, 1),
        key_padding_mask=mask,
        need_weights=True,
    )

    label_weights = lwan.label_weights.unsqueeze(0).view(
        1, lwan.num_labels, lwan.hidden_size
    )
    score = torch.sum(attn_output * label_weights, dim=2)
    return attn_output_weights, score



def load_token_dicts_from_json(filepath="token_dicts.json"):
    """Read the token dictionaries stored as one JSON object.

    JSON object keys are always strings, so for every top-level entry whose
    name contains "ind2tok" the inner keys are converted back to ``int``.
    All other entries are returned exactly as loaded.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        raw = json.load(handle)

    def _restore(name, mapping):
        # Only index -> token tables need their keys turned back into ints.
        if "ind2tok" in name:
            return {int(idx): tok for idx, tok in mapping.items()}
        return mapping

    return {name: _restore(name, mapping) for name, mapping in raw.items()}

def load_config(config_path):
    """Return the parsed content of a JSON configuration file.

    Raises FileNotFoundError when ``config_path`` does not exist.
    """
    if os.path.exists(config_path):
        with open(config_path, "r") as handle:
            return json.load(handle)

    raise FileNotFoundError(f"Config file {config_path} not found.")

def view_model(model, output_param=True, output_struct=False):
    """Print a summary of a model's trainable parameters and/or its modules.

    Parameters:
    - model: object exposing ``named_parameters()`` and ``named_modules()``
      (e.g. a ``torch.nn.Module``).
    - output_param: when True, print every trainable parameter with its
      element count followed by the total.
    - output_struct: when True, print the class name and qualified name of
      every sub-module.

    Returns:
    - The total number of trainable parameters when ``output_param`` is True,
      otherwise None.
    """
    total = None
    if output_param:
        total = 0
        print("Detailed Parameter Breakdown:")
        print("--------------------------------")
        for name, param in model.named_parameters():
            if param.requires_grad:
                param_count = param.numel()
                print(f"{name:60}: {param_count}")
                total += param_count
        print("--------------------------------")
        print(f"Total Trainable Parameters: {total:,}")

    # Handled independently of the parameter report: returning right after
    # the parameter section used to make this branch unreachable whenever
    # output_param was True.
    if output_struct:
        print("Model Structure (Modules):")
        print("--------------------------------")
        for name, module in model.named_modules():
            if name == "":
                print(f"{type(module).__name__}: [root module]")
            else:
                print(f"{type(module).__name__}: {name}")

    return total


In [None]:
# Evaluation setup: load the configuration, the validation data and the
# trained LAHST model.
config_path = "/FINAL/Models/PJ/Saved_models/LAHST/baseline/config_MMULA_evaluate.json"

compute_all_set = True
config = load_config(config_path)
# qualitative_evaluation = config["qualitative_evaluation"]

print("loading tokens...")
token_dicts = load_token_dicts_from_json(config["event_tokens_file"])
tokenizer = get_tokenizer(config["base_checkpoint"])

val_set = Multi_Modal_Dataset(name="VALIDATION", file_path=config["file_path"], splits=config["splits"],
                               token_dicts=token_dicts,
                               mimic_dir=config["mimic_dir"], tokenizer=tokenizer,
                               saved_path=config["saved_path"])

validation_loader = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False,
                               collate_fn=tolerant_collate)

# Pick the first available backend: Apple MPS, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# model = Model(config, device=device)
model = Model_LAHST(config=config, device=device)

param = view_model(model)

# Map the checkpoint onto the device selected above. A hard-coded "mps"
# target makes loading fail on machines without an MPS backend.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
# only load checkpoints from a trusted source.
checkpoint = torch.load(
    os.path.join(config["project_path"], f"Saved_models/LAHST/baseline/BEST_{config['run_name']}.pth"),
    map_location=device,
    weights_only=False
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

In [None]:
# Qualitative analysis of a single validation admission: run the model chunk
# window by chunk window, compute the label attention weights and plot them
# as heatmaps, one per temporal cutoff.
do_sample = True

if do_sample:
    # evaluation metrics (not used further in this cell)
    my_metrics = MyMetrics(debug=config["debug"])

    # sample = val_set[605]  # hadm == 137969
    sample = val_set[173]  # hadm == 111403
    hadm, event_type_sequence, sequences, sequences_tokenized, timestamps, labels= sample

    # NOTE(review): this call unpacks 11 values, while the select_sequence
    # call in the quantitative evaluation loop unpacks 9 — confirm the
    # function's return arity; one of the two call sites presumably fails.
    event_type_seq, seq, seq_tokenized, tstamp, input_ids, attention_mask, seq_ids, category_ids, cutoffs, pos_encodings, rev_encodings \
        = select_sequence(event_type_sequence, sequences, sequences_tokenized, timestamps, tokenizer,
                          mode="eval")  # to change to train

    print(input_ids.shape)

    # Keep only the labels the model predicts.
    labels = labels[: model.num_labels]
    # Number of distinct sequence ids (not used further in this cell).
    avail_docs = seq_ids.max().item() + 1
    # note_end_chunk_ids = data["note_end_chunk_ids"]
    print(cutoffs)

    complete_sequence_output = []
    complete_gate_lists = []
    # run through data in chunks of max_chunks; t counts the windows and
    # selects the matching entry of event_type_seq / seq_tokenized
    # NOTE(review): inference runs without torch.no_grad(), so autograd
    # state is presumably kept for every window — confirm this is intended.
    t = 0
    for i in tqdm(range(0, input_ids.shape[0], model.max_chunks)):
        # only get the document embeddings
        sequence_output, gate_lists = model(
            event_types=event_type_seq[t], seq_tokenized=seq_tokenized[t],
            input_ids=input_ids[i: i + model.max_chunks].to(
                device, dtype=torch.long
            ),
            attention_mask=attention_mask[i: i + model.max_chunks].to(
                device, dtype=torch.long
            ),
            seq_ids=seq_ids[i: i + model.max_chunks].to(
                device, dtype=torch.long
            ),
            category_ids=category_ids[i: i + model.max_chunks].to(
                device, dtype=torch.long
            ),
            cutoffs=None,  # the cutoffs are not passed to the model here
            is_evaluation=True,
            # note_end_chunk_ids=note_end_chunk_ids,
        )
        complete_sequence_output.append(sequence_output)
        # The raw gate tensors are kept (no reduction over the last axis):
        # complete_gate_lists.append(torch.mean(gate_lists, dim=2))
        complete_gate_lists.append(gate_lists)
        t += 1
    # concatenate the sequence output
    sequence_output = torch.cat(complete_sequence_output, dim=0)

    # run through LWAN to get the attention weights and the scores
    attn_output_weights, scores = return_attn_scores(model.label_attn, sequence_output, cutoffs=cutoffs)

    # Indices of the positive labels of this admission.
    # NOTE(review): the bound 50 is hard-coded although labels was truncated
    # to model.num_labels above — this raises IndexError if num_labels < 50.
    labels_sample = []
    for i in range(50):
        if labels[i] == 1:
            print(i)
            labels_sample.append(i)




    # Visualize attn_output_weights[cutoff_idx, viz_label, :] as a heatmap;
    # viz_label should be one of the positive labels printed above.


    save_fig = True
    if save_fig:

        # Label whose attention weights are plotted.
        # viz_label = 18
        viz_label = 10

        # Optional (disabled): additionally plot the gate values of the most
        # attended chunk for every cutoff.
        plot_gates = False
        if plot_gates:
            for cutoff in cutoffs.keys():
                print(f"Cutoff {cutoff}")
                cutoff_idx = cutoffs[cutoff]
                print(f"Label {viz_label}")
                plt.figure(figsize=set_size(438.17227, fraction=0.8))  # already used
                attn_weights = attn_output_weights[cutoff_idx, viz_label, :].cpu().detach().numpy().reshape(1, -1)
                # Column of the largest attention weight (first one on ties).
                most_important_chunk_idx = np.where(attn_weights == attn_weights.max())[1][0]

                gates, types_gates = get_vector_at_index(complete_gate_lists, event_type_seq, index=most_important_chunk_idx)

                # NOTE(review): plot_gate_importance opens and shows its own
                # figure, so the heatmap below is presumably not drawn on the
                # figure created above — confirm.
                plot_gate_importance(gates, types_gates, note_idx=most_important_chunk_idx, top_k=10,
                                     save_path=f"/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/Multi_Modal_Multi_head_Residual_Attn_LN/figures/gate_weights_ch{most_important_chunk_idx}_{cutoff}.png")

                # make heatmap value range from the min to the max
                sns.heatmap(
                    attn_weights,
                    cmap="Blues",
                    vmin=attn_weights.min(),
                    vmax=attn_weights.max(),
                )
                # Alternative (disabled): show lines between the notes,
                # which are identified by seq_ids
                # for i in range(1, seq_ids.shape[0]):
                #     if seq_ids[i-1] != seq_ids[i]:
                #         plt.axvline(x=i, color="black")
                # show thin dotted lines at each cutoff
                # NOTE(review): this inner loop rebinds `cutoff` and
                # `cutoff_idx` of the enclosing loop; neither is read again
                # before the enclosing loop reassigns them.
                for cutoff in cutoffs.keys():
                    cutoff_idx = cutoffs[cutoff]
                    plt.axvline(x=cutoff_idx + 1, color="black", linestyle=":")

                plt.show()

        # One attention heatmap per cutoff, saved to disk and shown.
        for cutoff in cutoffs.keys():
            print(f"Cutoff {cutoff}")
            cutoff_idx = cutoffs[cutoff]
            print(f"Label {viz_label}")
            plt.figure(figsize=set_size(438.17227, fraction=0.68))
            attn_weights = attn_output_weights[cutoff_idx, viz_label, :].cpu().detach().numpy().reshape(1, -1)
            # make heatmap value range from the min to the max
            sns.heatmap(
                attn_weights,
                cmap="Blues",
                vmin=attn_weights.min(),
                vmax=attn_weights.max(),
            )
            # Alternative (disabled): show lines between the notes,
            # which are identified by seq_ids
            # for i in range(1, seq_ids.shape[0]):
            #     if seq_ids[i-1] != seq_ids[i]:
            #         plt.axvline(x=i, color="black")
            # show thin dotted lines at each cutoff except 'all'
            cutoff_idxs = []
            for cutoffx in cutoffs.keys():
                if cutoffx != 'all':
                    cutoff_idx = cutoffs[cutoffx]
                    cutoff_idxs.append(cutoff_idx)
                    plt.axvline(x=cutoff_idx + 1, color="black", linestyle=":")
            plt.xlabel("Chunk position", fontsize=10)
            # remove y axis ticks
            plt.yticks([])
            # place an x tick at every 5th chunk position, unrotated
            plt.xticks(list(range(0, attn_weights.shape[1], 5)), list(range(0, attn_weights.shape[1], 5)), rotation=0,
                       fontsize=8)
            # NOTE(review): the output directory is an absolute, user-specific
            # path and the first filename field is the constant 14.
            plt.savefig(
                f"/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/Saved_models/LAHST/baseline/figures/new_attention_weights_{14}_{viz_label}_{cutoff}.png",
                bbox_inches="tight")
            plt.show()


# Quantitative analysis: for the selected admission, collect the label
# attention weight of every chunk and the share of each event type in the
# chunk's gate values, grouped by cutoff and by chunk category, and dump the
# collected values to JSON files in the working directory.
do_quantitative_eval = True

print("doing quantitative qualitative evaluation...")
if do_quantitative_eval:
    do_quantitative_analysis = True

    if do_quantitative_analysis:
        cutoff_names = ["2d", "5d", "13d", "noDS", 'all']

        # Attention weights per cutoff and chunk category (category ids 0..14).
        weights_per_class = {cutoff: {c: [] for c in range(15)} for cutoff in cutoff_names}
        # Not filled by the current analysis; kept so the name stays defined.
        samples_per_class = {cutoff: {c: [] for c in range(15)} for cutoff in cutoff_names}

        # Option 1: Soft normalized gating values
        soft_contributions = {cutoff: {c: defaultdict(list) for c in range(15)} for cutoff in cutoff_names}

        # Option 2: Top-k gating values (top 10)
        topk_contributions = {cutoff: {c: defaultdict(list) for c in range(15)} for cutoff in cutoff_names}
        top_k = 10

        # Option 3: Thresholded gating values
        threshold_contributions = {cutoff: {c: defaultdict(list) for c in range(15)} for cutoff in cutoff_names}
        threshold = 0.1  # can be adjusted

        with torch.no_grad(), torch.cuda.amp.autocast():
            for batch_idx, (hadm, event_type_sequence, sequences, sequences_tokenized, timestamps, labels) in tqdm(enumerate(
                    validation_loader)):
                # NOTE(review): 9 values are unpacked here, whereas the
                # single-sample cell unpacks 11 values from select_sequence —
                # confirm the function's return arity.
                event_type_seq, seq, seq_tokenized, tstamp, input_ids, attention_mask, seq_ids, category_ids, cutoffs \
                    = select_sequence(event_type_sequence, sequences, sequences_tokenized, timestamps, tokenizer,
                                      mode="eval", num_text_chunks=16)  # to change to train

                # Only this admission is analysed; every other batch is skipped.
                if hadm != 113344:
                    continue

                complete_sequence_output = []
                complete_gate_lists = []
                # run through data in chunks of max_chunks; t counts the
                # windows and selects the matching event_type_seq entry
                t = 0
                model.max_chunks = 16

                for i in range(0, input_ids.shape[0], model.max_chunks):
                    # only get the document embeddings
                    sequence_output, gate_lists = model(
                        event_types=event_type_seq[t], seq_tokenized=seq_tokenized[t],
                        input_ids=input_ids[i: i + model.max_chunks].to(
                            device, dtype=torch.long
                        ),
                        attention_mask=attention_mask[i: i + model.max_chunks].to(
                            device, dtype=torch.long
                        ),
                        seq_ids=seq_ids[i: i + model.max_chunks].to(
                            device, dtype=torch.long
                        ),
                        category_ids=category_ids[i: i + model.max_chunks].to(
                            device, dtype=torch.long
                        ),
                        cutoffs=None,
                        is_evaluation=True,
                    )
                    complete_sequence_output.append(sequence_output)
                    # Reduce each gate vector to its norm along dim 2 and
                    # keep the result as a NumPy array on the CPU.
                    complete_gate_lists.append(torch.norm(gate_lists.detach().cpu(), dim=2).numpy())
                    t += 1

                # concatenate the sequence output
                sequence_output = torch.cat(complete_sequence_output, dim=0)
                del complete_sequence_output
                torch.cuda.empty_cache()
                attn_output_weights, scores = return_attn_scores(model.label_attn, sequence_output, cutoffs=cutoffs)

                # Indices of the positive labels (the first 50 labels are inspected).
                labels_sample = [i for i in range(50) if labels[i] == 1]

                for cutoff in cutoffs.keys():
                    # For 'all' the last position of sequence_output is used
                    # instead of cutoffs['all'].
                    cutoff_idx = cutoffs[cutoff] if cutoff != 'all' else len(sequence_output) - 1
                    for l in labels_sample:
                        for chunk in range(cutoff_idx + 1):
                            try:
                                c = category_ids[chunk].item()
                                weight = attn_output_weights[cutoff_idx, l, chunk].item()

                                gates, types_gates = get_vector_at_index(complete_gate_lists, event_type_seq,
                                                                         index=chunk)

                                # Values are cast to built-in float before
                                # being stored: NumPy scalars such as float32
                                # are not JSON serializable.
                                total_gate = np.sum(gates) + 1e-8

                                # Option 1: soft normalized contribution
                                for g_val, g_type in zip(gates, types_gates):
                                    contribution = float(g_val / total_gate)
                                    soft_contributions[cutoff][c][g_type].append(contribution)

                                # Option 2: top-k
                                top_indices = np.argsort(gates)[-top_k:]
                                for idx in top_indices:
                                    g_type = types_gates[idx]
                                    contribution = float(gates[idx] / total_gate)
                                    topk_contributions[cutoff][c][g_type].append(contribution)

                                # Option 3: threshold (stores the raw gate
                                # value of events whose share reaches the threshold)
                                for g_val, g_type in zip(gates, types_gates):
                                    contribution = g_val / total_gate
                                    if contribution >= threshold:
                                        threshold_contributions[cutoff][c][g_type].append(float(g_val))

                                weights_per_class[cutoff][c].append(weight)

                            except Exception as e:
                                # Best-effort debugging aid: report the
                                # failing position and stop this chunk loop.
                                print(f"[ERROR] chunk={chunk}, cutoff_idx={cutoff_idx}, label={l}")
                                print(
                                    f"category_ids shape: {category_ids.shape}, attn_output_weights shape: {attn_output_weights.shape}")
                                print(f"cutoff={cutoff}, class={c if 'c' in locals() else 'undefined'}")
                                print(f"Exception: {type(e).__name__} - {e}")
                                break  # optionally stop here to debug

                del sequence_output, gate_lists, input_ids, attention_mask, seq_ids, category_ids
                torch.cuda.empty_cache()
                gc.collect()

        # Write every result with a context manager so each file is flushed
        # and closed.
        outputs = {
            "weights_per_class.json": weights_per_class,
            "soft_contributions_per_class.json": soft_contributions,
            "topk_contributions_per_class.json": topk_contributions,
            "threshold_contributions_per_class.json": threshold_contributions,
        }
        for out_name, payload in outputs.items():
            with open(out_name, 'w') as out_file:
                json.dump(payload, out_file)

In [None]:
import os
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Optional
from itertools import cycle
import pandas as pd
from collections import defaultdict
from collections import Counter
from tqdm import tqdm

def get_labels(splits):
    """Build index mappings for the 50 most frequent ICD codes of a splits file.

    Parameters:
    - splits: path (or anything ``pd.read_csv`` accepts) of a CSV with the
      columns ``SPLIT_50``, ``HADM_ID`` and ``absolute_code``; the latter
      holds the codes of an admission as a Python-literal string.

    Returns:
    - (c2ind, ind2c): code -> index and index -> code dictionaries for the
      50 most common codes, index 0 being the most frequent code.

    Codes are counted once per admission and split, using the first row of
    each ``HADM_ID`` within the split.
    """
    split = pd.read_csv(splits)
    unique_splits = split.SPLIT_50.unique()

    # Collect all ICD codes into a counter
    icd_counter = Counter()
    for spl in tqdm(unique_splits):
        temp = split[split["SPLIT_50"] == spl]
        print(f"fetching the codes in {spl}")
        # One pass per split: keep the first row of every admission instead
        # of re-filtering the whole split for each HADM_ID.
        first_rows = temp.drop_duplicates(subset="HADM_ID", keep="first")
        for codes in first_rows["absolute_code"]:
            # SECURITY NOTE: eval executes arbitrary code found in the CSV;
            # only use trusted split files. If the column only ever holds
            # plain list literals, ast.literal_eval is the safe replacement.
            icd_counter.update(eval(codes))

    # Get the top 50 most common ICD codes
    top_50_icds = [icd for icd, _ in icd_counter.most_common(50)]

    # Build mapping dicts only for the top 50
    c2ind = {icd: i for i, icd in enumerate(top_50_icds)}
    ind2c = dict(enumerate(top_50_icds))

    print("done")
    return c2ind, ind2c

# ---- Configuration ----
# Metrics plotted by plot_grouped_bars when no explicit `metrics` list is given.
MAIN_METRICS = ["f1_macro", "f1_micro", "auc_macro", "auc_micro", "p_5"]
# Preferred display order of the time windows in plot_grouped_bars; windows
# found in the data but missing here are appended in sorted order.
WINDOWS_ORDER = ["2d", "5d", "13d", "noDS", "all"]  # adjust order as you like

def compare_models_by_label(models, metric_key="f1_by_class", slice_key="all", tie_mode="share"):
    """
    Determine, label by label, which model scores best on a per-class metric.

    Args
    ----
    models: dict[str, dict]
        model_name -> metrics dict; ``metrics[slice_key][metric_key]`` must be
        a list of floats, one per label
    metric_key: str
        "f1_by_class" or "auc_by_class"
    slice_key: str
        e.g., "all", "2d", "5d", "13d", "noDS"
    tie_mode: str
        What to do when several models share the best score of a label:
        - "share": every tied model wins the label
        - "skip": the label gets no winner
        - "first": the first model of ``models`` wins if it is among the tied
          ones, otherwise the first tied model
        Any other value behaves like "share".

    Returns
    -------
    result: dict with:
        - "winners_by_label": list of tuples [(label_idx, [(model, score), ...]), ...]
        - "labels_won_by_model": dict[model] -> sorted list of label indices it wins
        - "win_counts": dict[model] -> number of labels won
        - "per_label_scores": dict[label_idx] -> dict[model] -> score

    Label indices are 1-based. Non-finite scores and labels beyond the end of
    a model's list are ignored for that model.
    """
    per_model = {
        model_name: np.array(metrics[slice_key][metric_key], dtype=float)
        for model_name, metrics in models.items()
    }
    n_labels = max((len(values) for values in per_model.values()), default=0)
    first_model = next(iter(per_model), None)

    winners_by_label = []
    labels_won = defaultdict(list)
    per_label_scores = {}

    for position in range(n_labels):
        label_id = position + 1  # 1-based label index for readability

        # Scores of the models that actually have a finite value here.
        available = {
            model_name: float(values[position])
            for model_name, values in per_model.items()
            if position < len(values) and np.isfinite(values[position])
        }
        if not available:
            continue  # no model has data for this label

        per_label_scores[label_id] = available

        top_score = max(available.values())
        leaders = [
            model_name
            for model_name, value in available.items()
            if np.isclose(value, top_score, rtol=1e-9, atol=1e-12)
        ]

        if len(leaders) > 1:
            if tie_mode == "skip":
                continue
            if tie_mode == "first":
                leaders = [first_model if first_model in leaders else leaders[0]]

        winners_by_label.append(
            (label_id, [(model_name, available[model_name]) for model_name in leaders])
        )
        for model_name in leaders:
            labels_won[model_name].append(label_id)

    return {
        "winners_by_label": winners_by_label,
        "labels_won_by_model": {name: sorted(won) for name, won in labels_won.items()},
        "win_counts": {name: len(won) for name, won in labels_won.items()},
        "per_label_scores": per_label_scores,
    }

def load_metrics_folder(folder: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Read every ``*.json`` file directly inside ``folder`` as one model's metrics.

    Expected structure per file:
      {
        "2d": {"f1_macro": ..., "f1_micro": ..., ...},
        "5d": {...}, "13d": {...}, "noDS": {...}, "all": {...}
      }

    Returns: {model_name: metrics_dict}, where model_name is the file name
    without its extension.
    """
    loaded = {}
    pattern = os.path.join(folder, "*.json")
    for json_path in glob.glob(pattern):
        stem, _extension = os.path.splitext(os.path.basename(json_path))
        with open(json_path, "r") as handle:
            loaded[stem] = json.load(handle)
    return loaded

def _scatter_per_class(series, models, slice_key, macro_key, metric_label):
    """Show one scatter figure of a per-class metric, one marker style per model."""
    # Nice distinct markers to help separate models
    marker_cycler = cycle(["o", "s", "^", "D", "P", "X", "v", "<", ">", "*"])

    plt.figure(figsize=(12, 6))
    for name, values in series.items():
        # 1-based class indices sized to this very series, so x and y always
        # have the same length.
        x = np.arange(1, len(values) + 1)
        plt.scatter(x, values, alpha=0.7, label=name, marker=next(marker_cycler))

    # Horizontal macro average lines if available
    for metrics in models.values():
        macro = metrics[slice_key].get(macro_key, None)
        if macro is not None:
            plt.axhline(macro, linestyle="--", linewidth=0.8, alpha=0.5)

    plt.xlabel("Class index")
    plt.ylabel(f"{metric_label} score")
    plt.title(f"Per-class {metric_label} scores ({slice_key})")
    plt.ylim(0, 1.05)
    plt.grid(alpha=0.3)
    plt.legend(title="Model", ncols=2, fontsize=9)
    plt.tight_layout()
    plt.show()


def plot_labels_by_model(models, slice_key="all"):
    """
    Show per-class F1 and AUC scatter plots comparing several models.

    models: dict[str, dict] mapping model_name -> metrics dict
            expects metrics[slice_key] with keys 'f1_by_class' and 'auc_by_class'
            optionally 'f1_macro' and 'auc_macro' for horizontal reference lines
    slice_key: one of 'all', '2d', '5d', '13d', 'noDS', etc.

    Two figures are shown, F1 first and AUC second. Raises KeyError before
    anything is drawn when a model lacks ``slice_key`` or one of the
    per-class lists.
    """
    # --- Collect both series up front so missing keys fail before plotting ---
    f1_series = {}
    auc_series = {}
    for name, metrics in models.items():
        f1_series[name] = np.array(metrics[slice_key]["f1_by_class"], dtype=float)
        auc_series[name] = np.array(metrics[slice_key]["auc_by_class"], dtype=float)

    # The class axis of each series is derived from that series itself. It
    # used to be sized from the F1 lists only, which broke the AUC plot when
    # an AUC list was longer than every F1 list.
    _scatter_per_class(f1_series, models, slice_key, "f1_macro", "F1")
    _scatter_per_class(auc_series, models, slice_key, "auc_macro", "AUC")
def plot_grouped_bars(
    models_to_metrics: Dict[str, Dict[str, Dict[str, Any]]],
    metrics: Optional[List[str]] = None,
    windows_order: Optional[List[str]] = None,
    save_dir: Optional[str] = None,
    show: bool = True,
):
    """
    Draw one grouped bar chart per metric: models on the x-axis, one bar per
    time window within each model's group.

    models_to_metrics: {model_name: {window: {metric: value, ...}, ...}, ...}
    metrics: metric names to plot; defaults to MAIN_METRICS.
    windows_order: windows to plot, in order; by default the windows found in
        the data, those listed in WINDOWS_ORDER first and the rest sorted.
    save_dir: if set, each figure is saved there as "<metric>.png" (the
        directory is created when missing).
    show: show each figure when True, otherwise close it.

    Missing values are drawn as bars of height 0.
    """
    metrics = MAIN_METRICS if metrics is None else metrics

    if windows_order is None:
        # determine a stable window order present in the data
        seen_windows = set()
        for per_window in models_to_metrics.values():
            seen_windows.update(per_window.keys())
        known = [w for w in WINDOWS_ORDER if w in seen_windows]
        extras = sorted(seen_windows - set(WINDOWS_ORDER))
        windows_order = known + extras

    model_names = list(models_to_metrics.keys())
    n_models, n_windows = len(model_names), len(windows_order)
    x = np.arange(n_models)
    bar_width = 0.8 / max(1, n_windows)

    def _value(model_name, window, metric):
        # NaN stands for "window or metric not reported by this model".
        per_window = models_to_metrics[model_name]
        if window not in per_window:
            return np.nan
        return per_window[window].get(metric, np.nan)

    for metric in metrics:
        plt.figure(figsize=(max(8, n_models * 1.2), 5))

        # Each window contributes one bar to every model's group.
        for slot, window in enumerate(windows_order):
            heights = np.array(
                [_value(model_name, window, metric) for model_name in model_names],
                dtype=float,
            )
            plt.bar(x + (slot - n_windows/2) * bar_width + bar_width/2, np.nan_to_num(heights, nan=0.0),
                    width=bar_width, label=window)

        plt.title(metric)
        plt.xticks(x, model_names, rotation=20, ha="right")
        plt.ylabel(metric)
        plt.legend(title="Window", ncols=min(n_windows, 4))
        plt.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(os.path.join(save_dir, f"{metric}.png"), dpi=150)

        if not show:
            plt.close()
        else:
            plt.show()


def plot_grouped_by_time(
        models_to_metrics,
        metrics=("f1_macro", "f1_micro", "auc_macro", "auc_micro", "p_5", "LRAP"),
        windows_order=("2d", "5d", "13d", "noDS", "all"),
        decimals=2,
        save_dir=None,
        show=True,
):
    """
    Draw one grouped bar chart per metric, with time windows on the x-axis.

    Each x-axis group is a time window; the bars inside a group are the models
    (one legend entry per model). The value is printed above every bar that has
    one. Missing values are drawn as zero-height bars and are not annotated.

    Parameters:
    - models_to_metrics: {model_name: {window: {metric: value, ...}, ...}, ...}
    - metrics: Metric names to plot; one figure is created per metric.
    - windows_order: Windows to display, in this order. Windows that no model
      contains are dropped.
    - decimals: Number of decimals used for the value printed on each bar.
    - save_dir: If truthy, directory (created when missing) where each figure is
      saved as "grouped_by_time__<metric>.png".
    - show: If True, call plt.show(); otherwise close the figure.
    """
    model_names = list(models_to_metrics.keys())
    n_models = len(model_names)
    # Keep only the windows that at least one model actually reports.
    windows_order = [w for w in windows_order if any(w in m for m in models_to_metrics.values())]
    n_windows = len(windows_order)

    x = np.arange(n_windows)
    # A group spans 0.8 of the slot between ticks; max(1, ...) avoids dividing by zero.
    bar_width = 0.8 / max(1, n_models)

    def _val(mname, window, metric):
        """Return the metric value for (model, window), or NaN when it is missing."""
        try:
            return models_to_metrics[mname][window].get(metric, np.nan)
        except KeyError:
            return np.nan

    def _annotate(ax, rects, values, decimals=3):
        """Write each non-NaN value just above its bar."""
        for rect, v in zip(rects, values):
            if np.isnan(v):
                continue
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2.0,
                height + 0.01,  # little padding above bar
                f"{v:.{decimals}f}",
                ha="center", va="bottom", fontsize=9, rotation=0
            )

    for metric in metrics:
        fig, ax = plt.subplots(figsize=(max(8, n_windows * 1.6), 5))

        max_val = 0.0
        for i, mname in enumerate(model_names):
            vals_np = np.array([_val(mname, w, metric) for w in windows_order], dtype=float)
            # NaNs become zero-height bars; the raw values are kept for the labels.
            heights = np.nan_to_num(vals_np, nan=0.0)
            # Reducing an empty array raises ValueError, which used to abort the
            # whole call when none of the requested windows was present.
            if heights.size:
                max_val = max(max_val, heights.max())

            rects = ax.bar(
                # Centre the group of n_models bars on the window's tick.
                x + (i - n_models / 2) * bar_width + bar_width / 2,
                heights,
                width=bar_width,
                label=mname
            )
            _annotate(ax, rects, vals_np, decimals=decimals)

        # Titles and labels
        ax.set_title(f"Performance on {metric}", fontsize=14, fontweight="bold")
        ax.set_xticks(x, windows_order)
        ax.set_xlabel("Time Window")
        ax.set_ylabel(metric)

        # Legend outside
        ax.legend(title="Model", ncols=1, bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)

        # Add grid and headroom: at least up to 1.0, and 12% above the tallest bar.
        ax.set_ylim(0, max(1.0 if max_val <= 1.0 else max_val, max_val * 1.12))
        ax.grid(axis="y", linestyle="--", alpha=0.4)

        fig.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(os.path.join(save_dir, f"grouped_by_time__{metric}.png"),
                        dpi=150, bbox_inches="tight")

        if show:
            plt.show()
        else:
            plt.close(fig)
# ---------- Example usage ----------
if __name__ == "__main__":
    # Every JSON file below holds the evaluation metrics of one model.
    # (Alternative: models = load_metrics_folder("./metrics_runs") reads a whole folder.)
    metric_files = (
        "/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/LAHST/baseline/TEST_MMULA_evaluate.json",
        "/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/LAHST_Hierarchy/TEST_MMULA_evaluate.json",
        "/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/Multi_Modal_GATING/Multi_Modal_LAHST_Active_LN/TEST_Multi_modal.json",
        "/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/Multi_Modal_Multi_head_Residual_Attn_LN/TEST_Multi_modal.json",
        "/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/results/HLAST_MH_Residual_Attn_LN/TEST_MMULA_evaluate.json",
    )
    loaded_runs = []
    for metric_file in metric_files:
        with open(metric_file) as f:
            loaded_runs.append(json.load(f))
    modelA, modelB, modelC, modelD, modelE = loaded_runs

    # Models taking part in the comparison. modelB and modelE are loaded above
    # but currently left out; uncomment their entries to include them.
    models = {
        "LAHST": modelA,
        # "LAHST_H": modelB,
        "Multi_MG": modelC,
        "Multi_MMH": modelD,
        # "HLAST++": modelE,
    }

    # Label mappings; not used further in this block.
    c2ind, ind2c = get_labels("/Users/p-a/PycharmProjects/ICD_Coding/Multi_Modal_Continuous/Data/splits/caml_splits.csv")

    summary_metrics = ("f1_macro", "f1_micro", "auc_macro", "auc_micro", "p_5", "LRAP")
    time_windows = ("2d", "5d", "13d", "noDS", "all")

    # Per-label plots, one call per time window.
    for slice_key in time_windows:
        plot_labels_by_model(models, slice_key=slice_key)

    # Per-label winners: the full slice first, then each time window.
    for slice_key in ("all", "2d", "5d", "13d", "noDS"):
        result = compare_models_by_label(
            models, metric_key="f1_by_class", slice_key=slice_key, tie_mode="share"
        )
        print(f"Win counts {slice_key}:", result["win_counts"])
        print(f"Labels won by model {slice_key}:", result["labels_won_by_model"])
        # Winners of the first 10 labels.
        print(result["winners_by_label"][:10])

        # Same comparison based on AUC; the result is kept but not printed.
        result_auc = compare_models_by_label(
            models, metric_key="auc_by_class", slice_key=slice_key, tie_mode="share"
        )

    # Summary charts: grouped by model, then grouped by time window.
    plot_grouped_bars(
        models_to_metrics=models,
        metrics=list(summary_metrics),
        windows_order=list(time_windows),  # any subset/order is allowed
        save_dir=None,  # e.g., "./plots"
        show=True,
    )
    plot_grouped_by_time(
        models_to_metrics=models,
        metrics=summary_metrics,
        windows_order=time_windows,
        decimals=2,
        save_dir=None,  # or a folder path like "./plots_time_grouped"
        show=True,
    )