In [None]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
import os

import numpy as np
import pandas as pd

from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt

import io

import ipywidgets as widgets
from PIL import Image
import io

import plotly.graph_objects as go
import plotly.express as px


import sys
sys.path.append("/mnt/raid/C1_ML_Analysis/source/us-famli-pl/src/")
from nets.classification import RopeEffnetV2s
from loaders.ultrasound_dataset import USDataModuleBlindSweepWTag, USDatasetBlindSweepWTag

from tqdm.notebook import tqdm
import SimpleITK as sitk

def _frame_to_png_bytes(frame2d: np.ndarray) -> bytes:
    """Encode a single frame as PNG bytes for ``ipywidgets.Image``.

    NaN/+inf/-inf are mapped to 0, the frame is min/max-normalized to uint8
    using its own range (a constant frame becomes all zeros) and encoded as an
    RGB PNG.

    Parameters
    ----------
    frame2d : np.ndarray
        One frame of shape (H, W), (H, W, 1) or (H, W, 3), any numeric dtype.
        Single-channel frames are replicated to three channels.

    Returns
    -------
    bytes
        The PNG-encoded image.

    Raises
    ------
    ValueError
        If the frame is not (H, W), (H, W, 1) or (H, W, 3).
    """
    f = np.asarray(frame2d)
    # Handle NaNs/infs safely so min/max normalization is well-defined.
    f = np.nan_to_num(f, nan=0.0, posinf=0.0, neginf=0.0)

    # An RGB PNG needs (H, W, 3): give 2-D frames a channel axis, then replicate grayscale.
    if f.ndim == 2:
        f = f[..., None]
    if f.ndim != 3 or f.shape[-1] not in (1, 3):
        raise ValueError(f"Expected a frame of shape (H,W), (H,W,1) or (H,W,3). Got {f.shape}")
    if f.shape[-1] == 1:
        f = np.repeat(f, 3, axis=-1)

    # Normalize to uint8 using this frame's own min/max.
    fmin = float(f.min())
    fmax = float(f.max())
    if fmax > fmin:
        u8 = ((f - fmin) / (fmax - fmin) * 255.0).astype(np.uint8)
    else:
        u8 = np.zeros_like(f, dtype=np.uint8)

    # Pillow infers "RGB" from a uint8 (H, W, 3) array; the explicit `mode=`
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    img = Image.fromarray(u8)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def visualize_sequence(
    images: np.ndarray,
    scores: np.ndarray,
    scores_pred: np.ndarray,
    *,
    fps: float = 10.0,
    title: str = "",
):
    """Build an interactive Jupyter viewer for a frame sequence and its per-frame scores.

    Layout
    ------
    - Left: an HTML box with the frame index, GT score and predicted score, above
      the current frame rendered as PNG in an ``ipywidgets.Image``.
    - Right: a Plotly ``FigureWidget`` with the GT and predicted score curves,
      markers for the current frame, an "x" at the maximum predicted score and a
      dashed vertical line at the current frame.
    - Controls: a ``Play`` widget linked to a frame slider.

    Parameters
    ----------
    images : np.ndarray
        Frames of shape (T, H, W) or (T, H, W, C) with C in {1, 3, 4}.
        Single-channel input is replicated to RGB; a 4th (alpha) channel is dropped.
    scores : array-like
        Shape (T,): ground-truth score per frame.
    scores_pred : array-like
        Shape (T,): predicted score per frame.
    fps : float
        Playback rate of the Play widget.
    title : str
        Plot title.

    Returns
    -------
    ipywidgets.VBox
        The assembled viewer (display it, e.g. as the last expression of a cell).

    Raises
    ------
    ValueError
        If ``images`` has an unsupported shape or no frames, or if a score array's
        length differs from the number of frames.
    """
    images = np.asarray(images)
    if images.ndim == 3:  # (T,H,W): add a channel axis
        images = images[..., None]
    if images.ndim != 4 or images.shape[-1] not in (1, 3, 4):
        raise ValueError(
            f"`images` must have shape (T,H,W) or (T,H,W,C) with C in (1, 3, 4). Got {images.shape}"
        )
    if images.shape[-1] == 1:
        # Frames are encoded as RGB PNGs, so replicate grayscale to three channels.
        images = np.repeat(images, 3, axis=-1)
    elif images.shape[-1] == 4:
        # Drop the alpha channel for the same reason.
        images = images[..., :3]

    T = images.shape[0]
    if T == 0:
        raise ValueError("`images` must contain at least one frame.")

    scores_np = np.asarray(scores, dtype=float).reshape(-1)
    if scores_np.shape[0] != T:
        raise ValueError(f"`scores` length must match T={T}. Got {scores_np.shape[0]}")

    scores_pred_np = np.asarray(scores_pred, dtype=float).reshape(-1)
    if scores_pred_np.shape[0] != T:
        raise ValueError(f"`scores_pred` length must match T={T}. Got {scores_pred_np.shape[0]}")

    # --- Widgets ---
    slider = widgets.IntSlider(value=0, min=0, max=T - 1, step=1, description="Frame", continuous_update=True)
    play = widgets.Play(
        value=0, min=0, max=T - 1, step=1,
        interval=int(1000 / max(fps, 1e-6)),
        description="Play",
    )
    # Frontend link: the Play widget drives the slider; the slider observer below redraws.
    widgets.jslink((play, "value"), (slider, "value"))

    # Image display (PNG bytes)
    img_w = widgets.Image(value=_frame_to_png_bytes(images[0]), format="png")

    # Metadata panel
    meta = widgets.HTML()

    def _meta_html(i: int) -> str:
        return (
            f"<b>frame:</b> {i}"
            f"<br><b>score (gt):</b> {scores_np[i]:.3f}"
            f"<br><b>score (pred):</b> {scores_pred_np[i]:.3f}"
        )

    meta.value = _meta_html(0)

    # Plotly FigureWidget. Trace order matters: _update() addresses the
    # current-frame markers as fig.data[2] and fig.data[3].
    x = np.arange(T)
    fig = go.FigureWidget()

    # Lines
    fig.add_scatter(x=x, y=scores_np,      mode="lines", name="score (gt)")
    fig.add_scatter(x=x, y=scores_pred_np, mode="lines", name="score (pred)")

    # Current markers (one for each series)
    fig.add_scatter(x=[0], y=[float(scores_np[0])],      mode="markers", name="current (gt)")
    fig.add_scatter(x=[0], y=[float(scores_pred_np[0])], mode="markers", name="current (pred)")

    # Maximum predicted score; NaNs are ignored (frame 0 if every prediction is NaN).
    max_pred_idx = 0 if np.isnan(scores_pred_np).all() else int(np.nanargmax(scores_pred_np))
    fig.add_scatter(
        x=[max_pred_idx],
        y=[float(scores_pred_np[max_pred_idx])],
        mode="markers",
        name="max score (pred)",
        marker=dict(color="green", size=10, symbol="x"),
    )

    # y-range of the vertical line: cover both series, ignoring non-finite values.
    both = np.concatenate([scores_np, scores_pred_np])
    finite = both[np.isfinite(both)]
    y_min, y_max = (float(finite.min()), float(finite.max())) if finite.size else (0.0, 1.0)

    fig.update_layout(
        title=title,
        xaxis_title="Frame",
        yaxis_title="Score",
        margin=dict(l=40, r=10, t=40, b=40),
        shapes=[
            dict(
                type="line",
                x0=0, x1=0,
                y0=y_min, y1=y_max,
                xref="x", yref="y",
                line=dict(width=2, dash="dash"),
            )
        ],
    )

    def _update(i: int):
        """Redraw the image, current-frame markers, vertical line and metadata for frame i."""
        img_w.value = _frame_to_png_bytes(images[i])

        with fig.batch_update():
            # marker traces are indices 2 and 3
            fig.data[2].x = (i,)
            fig.data[2].y = (float(scores_np[i]),)

            fig.data[3].x = (i,)
            fig.data[3].y = (float(scores_pred_np[i]),)

            fig.layout.shapes[0].update(x0=i, x1=i, y0=y_min, y1=y_max)

        meta.value = _meta_html(i)

    slider.observe(lambda ch: _update(ch["new"]), names="value")

    controls = widgets.HBox([play, slider])
    left = widgets.VBox([meta, img_w])
    right = widgets.VBox([fig])
    ui = widgets.VBox([controls, widgets.HBox([left, right])])

    return ui

In [None]:
import numpy as np
import imageio.v3 as iio
import matplotlib.pyplot as plt

def _to_uint8_rgb_per_frame(frame: np.ndarray) -> np.ndarray:
    """
    Convert a single frame (H,W) or (H,W,C) to uint8 RGB using per-frame min/max,
    matching the spirit of your widget normalization.
    """
    f = np.asarray(frame)
    f = np.nan_to_num(f, nan=0.0, posinf=0.0, neginf=0.0)

    if f.ndim == 2:
        f = f[..., None]  # (H,W,1)
    if f.shape[-1] == 1:
        f = np.repeat(f, 3, axis=-1)
    elif f.shape[-1] != 3:
        raise ValueError(f"Expected 1 or 3 channels, got {f.shape[-1]}")

    fmin = float(f.min())
    fmax = float(f.max())
    if fmax > fmin:
        u8 = ((f - fmin) / (fmax - fmin) * 255.0).astype(np.uint8)
    else:
        u8 = np.zeros_like(f, dtype=np.uint8)

    return u8  # (H,W,3) uint8


def export_widget_like_mp4(
    images: np.ndarray,
    scores: np.ndarray,
    scores_pred: np.ndarray,
    out_mp4: str,
    *,
    fps: float = 10.0,
    title: str = "Sequence",
    show_filename: bool = False,
    filenames=None,
    dpi: int = 120,
):
    """
    Render an MP4 that mirrors the interactive viewer layout, one video frame per image frame.

    Each video frame shows:
      left:  the image (per-frame min/max normalized) with a text box holding the
             frame index, GT score, predicted score and optionally the filename;
      right: GT and predicted score curves, markers for the current frame, a dashed
             vertical line at the current frame and an "x" at the maximum prediction.

    Parameters
    ----------
    images : array-like
        (T,H,W), (T,H,W,1) or (T,H,W,3).
    scores, scores_pred : array-like
        Shape (T,).
    out_mp4 : str
        Output path, written with imageio using codec libx264 and pixel format
        yuv420p (an ffmpeg backend such as imageio-ffmpeg must be installed).
    fps : float
        Output frame rate.
    title : str
        Title of the score plot.
    show_filename : bool
        If True, ``filenames`` (shape (T,)) is shown in the text box.
    filenames : array-like, optional
        Required when ``show_filename`` is True.
    dpi : int
        Figure resolution. Rendered frames are cropped to even width/height
        because yuv420p cannot encode odd dimensions.

    Returns
    -------
    str
        ``out_mp4``.

    Raises
    ------
    ValueError
        On an empty sequence or shape/length mismatches.
    """
    imgs = np.asarray(images)
    if imgs.ndim == 3:  # (T,H,W)
        imgs = imgs[..., None]
    if imgs.ndim != 4:
        raise ValueError(f"images must be (T,H,W) or (T,H,W,C). Got {imgs.shape}")

    T = imgs.shape[0]
    if T == 0:
        raise ValueError("images must contain at least one frame.")
    scores = np.asarray(scores, dtype=float).reshape(-1)
    scores_pred = np.asarray(scores_pred, dtype=float).reshape(-1)
    if scores.shape[0] != T or scores_pred.shape[0] != T:
        raise ValueError(f"scores and scores_pred must have length T={T}. Got {scores.shape[0]} and {scores_pred.shape[0]}")

    if show_filename:
        if filenames is None:
            raise ValueError("show_filename=True requires filenames=")
        filenames = np.asarray(filenames).reshape(-1)
        if filenames.shape[0] != T:
            raise ValueError(f"filenames must have length T={T}. Got {filenames.shape[0]}")

    x = np.arange(T)
    both = np.concatenate([scores, scores_pred])
    finite = both[np.isfinite(both)]
    y_min, y_max = (float(finite.min()), float(finite.max())) if finite.size else (0.0, 1.0)
    # Pad the y-range so markers at the extremes are not clipped and identical limits are avoided.
    y_pad = 0.05 * (y_max - y_min) if y_max > y_min else 0.5
    # Maximum predicted score; NaNs are ignored (frame 0 if every prediction is NaN).
    max_pred_idx = 0 if np.isnan(scores_pred).all() else int(np.nanargmax(scores_pred))

    frames = []

    # One figure, updated per frame (faster than recreating it).
    fig = plt.figure(figsize=(12, 5), dpi=dpi)
    try:
        gs = fig.add_gridspec(1, 2, width_ratios=[1.0, 1.6])
        ax_img = fig.add_subplot(gs[0, 0])
        ax_plot = fig.add_subplot(gs[0, 1])

        # --- Left panel: image + metadata text box ---
        ax_img.axis("off")
        im_artist = ax_img.imshow(_to_uint8_rgb_per_frame(imgs[0]), interpolation="nearest")
        meta_text = ax_img.text(
            0.02, 0.98, "",
            transform=ax_img.transAxes,
            va="top", ha="left",
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", alpha=0.8)
        )

        # --- Right panel: score curves ---
        ax_plot.set_title(title)
        ax_plot.set_xlabel("Frame")
        ax_plot.set_ylabel("Score")
        ax_plot.set_xlim(0, max(T - 1, 1))  # avoid identical x-limits for a single frame
        ax_plot.set_ylim(y_min - y_pad, y_max + y_pad)

        ax_plot.plot(x, scores, label="score (gt)")
        ax_plot.plot(x, scores_pred, label="score (pred)")

        # current markers
        (cur_gt,) = ax_plot.plot([0], [float(scores[0])], marker="o", linestyle="None", label="current (gt)")
        (cur_pred,) = ax_plot.plot([0], [float(scores_pred[0])], marker="o", linestyle="None", label="current (pred)")

        # max pred marker
        ax_plot.plot([max_pred_idx], [float(scores_pred[max_pred_idx])], marker="x", linestyle="None", markersize=10, label="max score (pred)")

        # vertical line at the current frame
        vline = ax_plot.axvline(0, linestyle="--")

        ax_plot.legend(loc="best")

        def _meta(i: int) -> str:
            s = f"frame: {i}\nscore (gt): {scores[i]:.3f}\nscore (pred): {scores_pred[i]:.3f}"
            if show_filename:
                s += f"\nfile: {filenames[i]}"
            return s

        meta_text.set_text(_meta(0))
        fig.tight_layout()

        # Render loop
        for i in range(T):
            im_artist.set_data(_to_uint8_rgb_per_frame(imgs[i]))
            meta_text.set_text(_meta(i))

            cur_gt.set_data([i], [float(scores[i])])
            cur_pred.set_data([i], [float(scores_pred[i])])
            vline.set_xdata([i, i])

            # draw -> RGBA view of the renderer buffer
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            h, w = rgba.shape[:2]
            # yuv420p needs even width/height: drop a trailing row/column if odd.
            # Drop alpha and copy, since buffer_rgba() is a view into the renderer's memory.
            frames.append(rgba[: h - h % 2, : w - w % 2, :3].copy())
    finally:
        # Close even on error so figures do not accumulate in the kernel.
        plt.close(fig)

    frames_np = np.stack(frames, axis=0)  # (T,H,W,3)
    iio.imwrite(out_mp4, frames_np, fps=fps, codec="libx264", pixelformat="yuv420p")
    return out_mp4


In [None]:
mount_point = '/mnt/raid/C1_ML_Analysis'

dev_id = 0
device = torch.device(f'cuda:{dev_id}' if torch.cuda.is_available() else 'cpu')

print(device)

# Checkpoint of the selected run. The path is built from `version` so the two cannot drift apart.
# Note from training: "Best one -> 0.14, 0.12, 0.12, 0.08, 0.04"
version = 'v1.8'
model_fn = os.path.join(mount_point, 'train_output/classification/RopeEffnet/', version, 'epoch=44-val_select=0.769.ckpt')

# Inference only: eval() switches layers such as dropout/BatchNorm (if present) to
# inference behaviour. eval() returns the module itself.
model = RopeEffnetV2s.load_from_checkpoint(model_fn, map_location=device).eval()

In [None]:
# Display the hyperparameters attached to the loaded model.
model.hparams

In [None]:

# DataModule configuration for inference on the 2025-11-12 EFW splits:
# batch size 1, and -1 for num_frames / max_sweeps (presumably "no limit" — confirm in the loader).
params = dict(
    csv_train=os.path.join(mount_point, "CSV_files/efw_2025-11-12_train.csv"),
    csv_valid=os.path.join(mount_point, "CSV_files/efw_2025-11-12_val.csv"),
    csv_test=os.path.join(mount_point, "CSV_files/efw_2025-11-12_test.csv"),
    mount_point=mount_point,
    batch_size=1,
    prefetch_factor=1,
    num_frames=-1,
    img_column="file_path",
    tag_column="tag",
    ga_column=None,
    id_column=None,
    frame_column=None,
    class_column=None,
    presentation_column=None,
    efw_column=None,
    max_sweeps=-1,
    csv_train_ac=None,
    num_frames_val=-1,
    num_frames_test=-1,
    num_workers=1,
    drop_last=False,
)

dm = USDataModuleBlindSweepWTag(**params)
dm.setup()


In [None]:
def run_test(dm, model, dl, ds, output_csv):
    """Run frame-level inference over ``dl``; return (and cache) one row per frame.

    If ``output_csv`` already exists it is loaded and returned as-is (no inference).
    Otherwise each batch is passed through ``model``. The softmax over the last
    logit axis is collapsed into a scalar score, the expectation over the levels
    [0, 0.25, 0.5, 0.75, 1.0], so the score lies in [0, 1]. Results are written
    to ``output_csv`` unless the global checkpoint path ``model_fn`` names a
    ``last.ckpt``.

    Parameters
    ----------
    dm :
        Unused; kept for call-site compatibility.
    model : torch.nn.Module
        The classifier. It is switched to eval mode here, and inputs are sent to the
        device of its parameters.
    dl : DataLoader
        Must iterate ``ds`` in order with batch_size=1, because row ``idx`` of
        ``ds.df`` provides the file path of batch ``idx``.
    ds :
        Dataset exposing a ``df`` DataFrame with a ``file_path`` column.
    output_csv : str
        Cache path.

    Returns
    -------
    pd.DataFrame
        Columns ``frame_index``, ``file_path``, ``score_pred`` and ``score``
        (a zero-filled placeholder for the GT score).

    Raises
    ------
    ValueError
        If ``dl`` has a batch size other than 1.
    """
    if os.path.exists(output_csv):
        print(f"Predictions file {output_csv} already exists. Loading it.")
        df = pd.read_csv(output_csv)
        return df

    # File paths are looked up by batch index below, which is only valid for batch_size=1.
    if getattr(dl, "batch_size", None) not in (None, 1):
        raise ValueError(f"run_test expects a DataLoader with batch_size=1. Got {dl.batch_size}")

    # Without eval(), dropout/BatchNorm layers (if present) behave as in training
    # and the predicted scores depend on the batch.
    model.eval()
    # Send inputs to wherever the weights live (fallback: the global `device`).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    levels = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0], device=run_device)

    df_scores = []
    with torch.no_grad():
        for idx, X_d in tqdm(enumerate(dl), total=len(dl)):
            file_path = ds.df.iloc[idx]['file_path']
            imgs = X_d['img']
            num_frames = imgs.shape[2]
            # Swap axes 1 and 2 before the forward pass
            # (presumably [B, C, T, H, W] -> [B, T, C, H, W]; confirm with the loader).
            logits = model(imgs.permute(0, 2, 1, 3, 4).to(run_device))  # [B, N, C]
            probs = torch.softmax(logits, dim=-1)
            score = (probs * levels.to(probs.device)).sum(dim=-1)  # [B, N]

            # batch_size is 1: flatten to one score per output position. reshape(-1)
            # instead of squeeze(), which would yield a 0-d array for a one-frame sweep.
            score_np = score.reshape(-1).cpu().numpy()
            df_scores.append(pd.DataFrame({
                'frame_index': np.arange(num_frames),
                'file_path': file_path,
                'score_pred': score_np,
                'score': np.zeros_like(score_np),  # Placeholder for GT score if needed
            }))

    if not df_scores:
        # Empty loader: nothing to concatenate or cache.
        return pd.DataFrame(columns=['frame_index', 'file_path', 'score_pred', 'score'])

    df = pd.concat(df_scores, ignore_index=True)

    # Predictions are not cached when the global checkpoint path `model_fn` is a `last.ckpt`.
    if os.path.basename(model_fn) != "last.ckpt":
        df.to_csv(output_csv, index=False)
        print(f"Saved predictions to {output_csv}")
    return df

# Frame-level predictions on the test split, plus the max prediction per file.
test_dl = dm.test_dataloader()
test_ds = dm.test_ds
test_pred_csv = params['csv_test'].replace('.csv', f'_rope{version}_predictions.csv')
df_test_pred = run_test(dm, model, test_dl, test_ds, output_csv=test_pred_csv)

test_pred_max = (
    df_test_pred
    .groupby("file_path")
    .agg(pred_max=("score_pred", "max"))
    .reset_index()
)
df_test_pred = df_test_pred.merge(test_pred_max, on="file_path")

In [None]:

# Frame-level predictions on the training split, joined with its CSV metadata,
# plus the max prediction per study.
train_df = pd.read_csv(params['csv_train'])
train_ds = USDatasetBlindSweepWTag(
    train_df,
    mount_point=mount_point,
    img_column=params['img_column'],
    tag_column=params['tag_column'],
    id_column=None,
    transform=dm.test_transform,  # test-time transform for inference
    max_sweeps=-1,
    num_frames=-1,
)
train_dl = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)

train_pred_csv = params['csv_train'].replace('.csv', f'_rope{version}_predictions.csv')
df_train_pred = (
    run_test(dm, model, train_dl, train_ds, output_csv=train_pred_csv)
    .merge(train_df, on="file_path")
)
train_pred_max = (
    df_train_pred
    .groupby("study_id")
    .agg(pred_max=("score_pred", "max"))
    .reset_index()
)
df_train_pred = df_train_pred.merge(train_pred_max, on="study_id")

In [None]:

# Frame-level predictions on the validation split, joined with its CSV metadata,
# plus the max prediction per study.
val_df = pd.read_csv(params['csv_valid'])
val_ds = USDatasetBlindSweepWTag(
    val_df,
    mount_point=mount_point,
    img_column=params['img_column'],
    tag_column=params['tag_column'],
    id_column=None,
    transform=dm.test_transform,  # test-time transform for inference
    max_sweeps=-1,
    num_frames=-1,
)
val_dl = DataLoader(val_ds, batch_size=1, num_workers=2, shuffle=False)

val_pred_csv = params['csv_valid'].replace('.csv', f'_rope{version}_predictions.csv')
df_val_pred = (
    run_test(dm, model, val_dl, val_ds, output_csv=val_pred_csv)
    .merge(val_df, on="file_path")
)
val_pred_max = (
    df_val_pred
    .groupby("study_id")
    .agg(pred_max=("score_pred", "max"))
    .reset_index()
)
df_val_pred = df_val_pred.merge(val_pred_max, on="study_id")

In [None]:
# Score levels; the same values run_test uses to turn softmax outputs into a score.
bins = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

def to_class(x, levels=bins):
    """Return the index of the level closest to score ``x`` (ties go to the lower index).

    ``levels`` defaults to the ``bins`` array as it was when this cell ran, so a
    later reassignment of the global name ``bins`` does not silently change the mapping.
    """
    return int(np.argmin(np.abs(np.asarray(levels) - x)))

# Map each frame's continuous score to the index of its nearest level.
df_test_pred["y_pred"] = df_test_pred["score_pred"].map(to_class)

In [None]:
# Inspect one training sweep from a study whose max predicted score is below 0.3.
file_paths = (
    df_train_pred
    .query('pred_max < 0.3')['file_path']
    .drop_duplicates()
)
print(len(file_paths))

idx = 10  # position in the filtered list, not a DataFrame index label
file_path = file_paths.iloc[idx]
print(file_path)

q = df_train_pred.query('file_path == @file_path')
score = q['score']
score_pred = q['score_pred']

img = sitk.ReadImage(os.path.join(mount_point, file_path))
img_np = sitk.GetArrayFromImage(img)  # [T,H,W]

print(img_np.shape, len(score), len(score_pred))

# visualize_sequence expects a trailing channel axis: replicate scalar images to RGB.
if img.GetNumberOfComponentsPerPixel() == 1:
    img_np = np.repeat(img_np[..., None], 3, axis=-1)

visualize_sequence(img_np, score, score_pred, title="/".join(file_path.split("/")[-2:]))

In [None]:
# Optional: uncomment to export the most recently loaded sweep (`img_np`, `score`, `score_pred`)
# as an MP4 in the checkpoint directory (`os.path.dirname(model_fn)`), named after the sweep file.
# out_fn = os.path.join(os.path.dirname(model_fn), os.path.splitext(os.path.basename(file_path))[0]) + ".mp4"

# export_widget_like_mp4(
#     images=img_np,                # (T,H,W,C) or (T,H,W)
#     scores=score,
#     scores_pred=score_pred,
#     out_mp4=out_fn,
#     fps=10.0,
#     title=file_path,
# )

In [None]:
# Columns available in the training CSV.
train_df.columns

In [None]:
# Ground-truth EFW distribution, one row per distinct (study_id, efw_gt) pair.
(
    train_df[['study_id', 'efw_gt']]
    .drop_duplicates()
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Histogram of ground-truth EFW over all rows of train_df (not de-duplicated by study).
train_df['efw_gt'].hist()

In [None]:
# EFW class edges and class centers (the `_g` suffix presumably means grams).
edges_g = torch.tensor([500, 1461.43, 1900, 2500, 3000, 3400, 3750, 3930.09, 4398.58, 5500])
centers_g = torch.tensor([980.7, 1680.7, 2200, 2750, 3200, 3575, 3840.0, 4164.3, 4949.3])

# Rescale the centers to [0, 1] over the edge range [edges_g[0], edges_g[-1]] = [500, 5500].
# NOTE: this rebinds the global `bins`, which the score-level cell also defines.
efw_lo = edges_g[0]
efw_span = edges_g[-1] - edges_g[0]
bins = (centers_g - efw_lo) / efw_span

In [None]:
# Display the current value of `bins` (after the EFW cell, the rescaled EFW class centers).
bins

In [None]:
# Distribution of the per-study maximum predicted score on the training split.
(
    df_train_pred[['study_id', 'pred_max']]
    .drop_duplicates()
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Per-study maximum predicted score on the training split, broken down by `dataset`.
(
    df_train_pred[['study_id', 'pred_max', 'dataset']]
    .drop_duplicates()
    .groupby('dataset')
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Per-study maximum predicted score on the training split, by (`dataset`, `model`).
(
    df_train_pred[['study_id', 'pred_max', 'dataset', 'model']]
    .drop_duplicates()
    .groupby(['dataset', 'model'])
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Histogram of frame-level predicted scores on the training split.
df_train_pred['score_pred'].hist(bins=50)

In [None]:
# Frame-level predictions on the DXA cohort CSV, joined with its metadata,
# plus the max prediction per study.
dxa_fn = "CSV_files/dxa_instance_EFW_with_bw_dataset_022426.csv"
dxa_csv = os.path.join(mount_point, dxa_fn)
dxa_df = pd.read_csv(dxa_csv)
dxa_ds = USDatasetBlindSweepWTag(dxa_df, mount_point=mount_point, img_column='file_path', tag_column='tag', id_column=None, transform=dm.test_transform, max_sweeps=-1, num_frames=-1)
dxa_dl = DataLoader(dxa_ds, batch_size=1, num_workers=2, shuffle=False)

# `dxa_fn` is relative. Build the cache path from the absolute CSV path (like the
# train/val/test caches under mount_point) so it does not depend on the kernel's working directory.
dxa_pred_csv = os.path.splitext(dxa_csv)[0] + f'_rope{version}_predictions.csv'
df_dxa_pred = run_test(dm, model, dxa_dl, dxa_ds, output_csv=dxa_pred_csv)
df_dxa_pred = df_dxa_pred.merge(dxa_df, on="file_path")
df_dxa_pred = df_dxa_pred.merge(df_dxa_pred.groupby("study_id").agg(pred_max=("score_pred","max")).reset_index(), on="study_id")

In [None]:
# Distribution of the per-study maximum predicted score on the DXA cohort.
(
    df_dxa_pred[['study_id', 'pred_max']]
    .drop_duplicates()
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Columns of the merged DXA predictions (prediction columns + DXA CSV metadata + pred_max).
df_dxa_pred.columns

In [None]:
# Per-study maximum predicted score on the DXA cohort, broken down by `model`.
(
    df_dxa_pred[['study_id', 'pred_max', 'model']]
    .drop_duplicates()
    .groupby('model')
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# Per-study maximum predicted score on the training split, by (`dataset`, `model`).
# NOTE(review): same expression as an earlier cell, presumably for side-by-side
# comparison with the DXA table; consider removing the duplicate.
(
    df_train_pred[['study_id', 'pred_max', 'dataset', 'model']]
    .drop_duplicates()
    .groupby(['dataset', 'model'])
    .describe(percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

In [None]:
# First rows of the merged DXA predictions.
df_dxa_pred.head()

In [None]:
# All frame rows belonging to studies whose id starts with DXA-300-0024.
df_dxa_pred.loc[df_dxa_pred['study_id'].str.startswith('DXA-300-0024')]

In [None]:
# Inspect one DXA sweep (tag == "M") of a single study.
file_paths = (
    df_dxa_pred
    .query('study_id == "DXA-300-0024_20230228-bfly-ga-tool-expert" and tag == "M"')['file_path']
    .drop_duplicates()
)

idx = 0  # position in the filtered list, not a DataFrame index label
file_path = file_paths.iloc[idx]
print(file_path)

q = df_dxa_pred.query('file_path == @file_path')
score = q['score']
score_pred = q['score_pred']

img = sitk.ReadImage(os.path.join(mount_point, file_path))
img_np = sitk.GetArrayFromImage(img)  # [T,H,W]

print(img_np.shape, len(score), len(score_pred))

# visualize_sequence expects a trailing channel axis: replicate scalar images to RGB.
if img.GetNumberOfComponentsPerPixel() == 1:
    img_np = np.repeat(img_np[..., None], 3, axis=-1)

visualize_sequence(img_np, score, score_pred, title="/".join(file_path.split("/")[-2:]))

In [None]:

# DataModule built from the 2025-10-31 EFW splits with the settings below
# (96 frames per split, max_sweeps=3, id_column="study_id", efw_column="efw_gt").
args_dict = {
    "mount_point": "/mnt/raid/C1_ML_Analysis/",
    "csv_train": "/mnt/raid/C1_ML_Analysis/CSV_files/efw_2025-10-31_train.csv",
    "csv_valid": "/mnt/raid/C1_ML_Analysis/CSV_files/efw_2025-10-31_val.csv",
    "csv_test": "/mnt/raid/C1_ML_Analysis/CSV_files/efw_2025-10-31_test.csv",
    "img_column": "file_path",
    "tag_column": "tag",
    "id_column": "study_id",
    "max_sweeps": 3,
    "ga_column": None,
    "frame_column": None,
    "presentation_column": None,
    "class_column": None,
    "efw_column": "efw_gt",
    "batch_size": 2,
    "num_frames": 96,
    "num_frames_val": 96,
    "num_frames_test": 96,
    "num_workers": 16,
    "prefetch_factor": 4,
    "drop_last": False,
    "csv_train_ac": "/mnt/raid/C1_ML_Analysis/CSV_files//c3_blindsweep_annotation_labels_merged_train_train.csv"
}

batch_size = 2  # not passed below; the DataModule takes batch_size from args_dict
# The class is used through the top-of-notebook import from loaders.ultrasound_dataset;
# the former `usd.` prefix referenced a module alias that is never imported (NameError).
# NOTE: this rebinds `dm`; cells reading `dm` (e.g. `dm.test_transform`) see this instance if re-run.
dm = USDataModuleBlindSweepWTag(
        **args_dict
    )
dm.setup()