In [None]:
import torch
import pandas as pd
import os

import numpy as np
import pandas as pd

from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt

import io

import ipywidgets as widgets
from PIL import Image
import io

import plotly.graph_objects as go


import sys
sys.path.append("/mnt/raid/C1_ML_Analysis/source/us-famli-pl/src/")
from nets.classification import RopeEffnetV2s
from loaders.ultrasound_dataset import USAnnotatedBlindSweepDataModule

In [None]:

def _frame_to_png_bytes(frame2d: np.ndarray) -> bytes:
    """Encode a single image frame as PNG bytes for ``ipywidgets.Image``.

    The frame is min-max normalised to the full uint8 range [0, 255] before
    encoding; a constant frame becomes all zeros. NaN and +/-inf values are
    replaced with 0 before normalisation.

    Parameters
    ----------
    frame2d : array-like
        One frame of any numeric dtype (anything ``np.asarray`` accepts, e.g. a
        CPU torch tensor). Supported layouts:
          - (H, W)           grayscale
          - (H, W, 1)        grayscale with a singleton channel axis
          - (H, W, 3 | 4)    RGB / RGBA, channels-last
          - (1 | 3 | 4, H, W) channels-first; moved to channels-last

    Returns
    -------
    bytes
        The PNG-encoded image.

    Raises
    ------
    ValueError
        If the frame does not have one of the supported layouts.
    """
    # Work in float64 so normalisation neither overflows integer dtypes nor
    # loses precision on low-precision floats.
    f = np.asarray(frame2d, dtype=np.float64)
    # Handle NaNs/infs safely
    f = np.nan_to_num(f, nan=0.0, posinf=0.0, neginf=0.0)

    # Channels-first input: move the channel axis last (only when the last
    # axis cannot itself be a channel axis, so channels-last input is untouched).
    if f.ndim == 3 and f.shape[-1] not in (1, 3, 4) and f.shape[0] in (1, 3, 4):
        f = np.moveaxis(f, 0, -1)
    # Pillow has no mode for (H, W, 1); drop the singleton channel -> grayscale.
    if f.ndim == 3 and f.shape[-1] == 1:
        f = f[..., 0]
    if not (f.ndim == 2 or (f.ndim == 3 and f.shape[-1] in (3, 4))):
        raise ValueError(
            f"Unsupported frame shape {np.shape(frame2d)}; expected (H,W), "
            "(H,W,1), (H,W,3), (H,W,4) or channels-first (1|3|4,H,W)."
        )

    # Normalize to uint8
    fmin = float(f.min())
    fmax = float(f.max())
    if fmax > fmin:
        u8 = ((f - fmin) / (fmax - fmin) * 255.0).astype(np.uint8)
    else:
        u8 = np.zeros(f.shape, dtype=np.uint8)

    # Let Pillow infer the mode from the uint8 array's shape (L, RGB or RGBA);
    # passing `mode=` explicitly is deprecated in recent Pillow releases.
    img = Image.fromarray(np.ascontiguousarray(u8))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def visualize_sequence(
    images: np.ndarray,
    scores: np.ndarray,
    scores_pred: np.ndarray,
    *,
    fps: float = 10.0,
    title: str = "Sequence",
    show_filename: bool = False,
    filenames=None,
):
    """Build an interactive Jupyter viewer for a frame sequence and its per-frame scores.

    Layout:
      - Top: a Play button linked to a frame slider.
      - Left: a metadata panel (frame index, GT/pred score, optional filename)
        above the current frame, rendered as PNG in an ``ipywidgets.Image``.
      - Right: a Plotly ``FigureWidget`` with the GT and predicted score curves,
        one marker per curve at the current frame, and a dashed vertical line
        at the current frame.

    Parameters
    ----------
    images : array-like
        Frames with shape (T, H, W, C), channels-last. Anything ``np.asarray``
        accepts works (e.g. a CPU torch tensor). Each frame is encoded with
        ``_frame_to_png_bytes``.
    scores : array-like
        GT score per frame; flattened, must have T elements.
    scores_pred : array-like
        Predicted score per frame; flattened, must have T elements.
    fps : float, keyword-only
        Playback speed of the Play widget, in frames per second.
    title : str, keyword-only
        Plot title.
    show_filename : bool, keyword-only
        If True, show ``filenames[i]`` in the metadata panel.
    filenames : array-like, keyword-only
        Per-frame filenames (T elements); required when ``show_filename`` is True.

    Returns
    -------
    ipywidgets.VBox
        The assembled UI; display it (e.g. as the last expression of a cell).

    Raises
    ------
    ValueError
        If ``images`` is not 4-D or has no frames, if a per-frame array's length
        differs from T, or if ``show_filename`` is set without ``filenames``.
    """
    images = np.asarray(images)
    if images.ndim != 4:
        raise ValueError(f"`images` must have shape (T,H,W,C) (per your current code). Got {images.shape}")

    # Frames are channels-last: images.shape == (T, H, W, C).
    T = images.shape[0]
    if T == 0:
        # The slider/Play widgets would get max=-1 and images[0] would fail.
        raise ValueError("`images` must contain at least one frame (T >= 1).")

    scores_np = np.asarray(scores).reshape(-1)
    if scores_np.shape[0] != T:
        raise ValueError(f"`scores` length must match T={T}. Got {scores_np.shape[0]}")

    scores_pred_np = np.asarray(scores_pred).reshape(-1)
    if scores_pred_np.shape[0] != T:
        raise ValueError(f"`scores_pred` length must match T={T}. Got {scores_pred_np.shape[0]}")

    if show_filename:
        if filenames is None:
            raise ValueError("show_filename=True requires `filenames=`")
        filenames = np.asarray(filenames).reshape(-1)
        if filenames.shape[0] != T:
            raise ValueError(f"`filenames` length must match T={T}. Got {filenames.shape[0]}")

    # --- Widgets ---
    slider = widgets.IntSlider(value=0, min=0, max=T - 1, step=1, description="Frame", continuous_update=True)
    play = widgets.Play(
        value=0, min=0, max=T - 1, step=1,
        interval=int(1000 / max(fps, 1e-6)),  # ms per frame; guard against fps <= 0
        description="Play",
    )
    # Play drives the slider through this link, so observing the slider below
    # covers both manual scrubbing and playback.
    widgets.jslink((play, "value"), (slider, "value"))

    # Image display (PNG bytes)
    img_w = widgets.Image(value=_frame_to_png_bytes(images[0]), format="png")

    # Metadata panel
    meta = widgets.HTML()

    def _meta_html(i: int) -> str:
        """HTML summary (index, GT/pred score, optional filename) for frame i."""
        base = (
            f"<b>frame:</b> {i}"
            f"<br><b>score (gt):</b> {scores_np[i]:.3f}"
            f"<br><b>score (pred):</b> {scores_pred_np[i]:.3f}"
        )
        if show_filename:
            base += f"<br><b>file:</b> {filenames[i]}"
        return base

    meta.value = _meta_html(0)

    # Plotly FigureWidget
    x = np.arange(T)
    fig = go.FigureWidget()

    # Full score curves: trace indices 0 (gt) and 1 (pred).
    fig.add_scatter(x=x, y=scores_np,      mode="lines", name="score (gt)")
    fig.add_scatter(x=x, y=scores_pred_np, mode="lines", name="score (pred)")

    # Current-frame markers: trace indices 2 (gt) and 3 (pred).
    fig.add_scatter(x=[0], y=[float(scores_np[0])],      mode="markers", name="current (gt)")
    fig.add_scatter(x=[0], y=[float(scores_pred_np[0])], mode="markers", name="current (pred)")

    fig.update_layout(
        title=title,
        xaxis_title="Frame",
        yaxis_title="Score",
        margin=dict(l=40, r=10, t=40, b=40),
        shapes=[
            dict(
                type="line",
                x0=0, x1=0,
                # yref="paper" spans the full plot height (0 = bottom, 1 = top),
                # so the line stays visible even when both series are constant
                # or contain NaN, where a data-range line would collapse to zero
                # height or get NaN endpoints.
                y0=0, y1=1,
                xref="x", yref="paper",
                line=dict(width=2, dash="dash"),
            )
        ],
    )

    def _update(i: int):
        """Sync image, plot markers, vertical line and metadata to frame i."""
        img_w.value = _frame_to_png_bytes(images[i])

        # Batch the trace/layout edits into a single front-end update.
        with fig.batch_update():
            fig.data[2].x = (i,)
            fig.data[2].y = (float(scores_np[i]),)

            fig.data[3].x = (i,)
            fig.data[3].y = (float(scores_pred_np[i]),)

            fig.layout.shapes[0].update(x0=i, x1=i)

        meta.value = _meta_html(i)

    slider.observe(lambda ch: _update(ch["new"]), names="value")

    controls = widgets.HBox([play, slider])
    left = widgets.VBox([meta, img_w])
    right = widgets.VBox([fig])
    return widgets.VBox([controls, widgets.HBox([left, right])])


In [None]:
# Load the RopeEffnet v0.6 classification checkpoint, move it to the GPU and
# put it in evaluation mode for inference.
mount_point = '/mnt/raid/C1_ML_Analysis'
model = RopeEffnetV2s.load_from_checkpoint(
    os.path.join(
        mount_point,
        'train_output/classification/RopeEffnet/v0.6/',
        'epoch=41-val_loss=0.03.ckpt',
    )
)
model.cuda()
model.eval()

In [None]:
# Build the data module from the model's hparams (presumably the hyper-parameters
# saved with the checkpoint, so data settings match training — TODO confirm),
# set it up, and keep its test dataset for the cells below.
dm = USAnnotatedBlindSweepDataModule(**model.hparams)
dm.setup()
test_ds = dm.test_ds

In [None]:

# Pick one random file_path among df_frames rows labelled "high_measurable"
# whose tag is not "AC". NOTE: `.sample` is unseeded, so every run may pick a
# different sweep; pass random_state=<int> for a reproducible choice.
file_path = (
    test_ds.df_frames
    .query('annotation_label == "high_measurable" and tag != "AC"')['file_path']
    .drop_duplicates()
    .sample(n=1)
    .iloc[0]
)
# file_path = test_ds.df_frames.query('annotation_label == "low_visible"')['file_path'].drop_duplicates().sample(n=1).values[0]

# Boolean mask instead of an f-string query: a path containing a double quote
# would otherwise break (or silently alter) the query expression.
# NOTE(review): `idx` is an index *label* of test_ds.df; test_ds[idx] presumably
# treats it as a position, which only agrees for a default RangeIndex — confirm.
idx = test_ds.df.index[test_ds.df['file_path'] == file_path][0]

X_d = test_ds[idx]

with torch.no_grad():
    # Swap axes 1 and 2 and add a batch axis (presumably (C,T,H,W) -> (1,T,C,H,W) — TODO confirm).
    logits = model(X_d['img'].unsqueeze(0).permute(0, 2, 1, 3, 4).cuda())
    probs = torch.softmax(logits, dim=-1)
    # Expected score = sum_k p_k * level_k, with levels evenly spaced on [0, 1].
    # For 5 classes this is [0, 0.25, 0.5, 0.75, 1.0]; linspace keeps the
    # mapping correct if the number of classes changes.
    levels = torch.linspace(0.0, 1.0, probs.shape[-1], device=probs.device)
    score = (probs * levels).sum(dim=-1)  # [B, N]


visualize_sequence(X_d['img'].permute(1,2,3,0), X_d['scalar'], score.squeeze(0).cpu().numpy(), title='/'.join(file_path.split('/')[-2:]))

In [None]:
# Per-bin weights growing linearly with the bin index (1..5), rescaled so
# that they average to 1.
w_bins = torch.arange(1, 6, dtype=torch.float32)
w_bins = w_bins.div(w_bins.mean())

In [None]:
# Display the mean-normalised per-bin weights.
w_bins

In [None]:
# Inverse-frequency weights from per-class counts: the rarer a class, the
# larger its weight. Rescaled so the weights average to 1.
c = np.array([369458, 92519, 15579, 7674, 5970])
w = np.reciprocal(c, dtype=np.float64)
w /= w.mean()
w


In [None]:
# Hand-picked, increasing per-bin weights, rescaled so they average to 1.
wb = np.array([0.2, 0.4, 0.6, 1.0, 2.0])
wb = np.true_divide(wb, np.mean(wb))
wb