In [2]:
import cv2
import torch
import mlflow
import numpy as np
from torch.utils.data import DataLoader

from model import BoneAgeModel
from dataset_eval import BoneAgeEvalDataset
from gradcam import GradCAM

# MLflow configuration: a local SQLite tracking store (path relative to the
# working directory) and the experiment this notebook's runs are filed under.
MLFLOW_TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "Male_Demo_GradCAM"

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)


2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.schemas
2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.tables
2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.types
2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.constraints
2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.defaults
2026/01/18 16:09:25 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.comments
2026/01/18 16:09:26 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/01/18 16:09:26 INFO mlflow.store.db.utils: Updating database tables
2026/01/18 16:09:26 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/01/18 16:09:26 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2026/01/18 16:09:26 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/01/18 16:09:26 INFO alembic.runtime

<Experiment: artifact_location='file:d:/pw2/pw_male/mlruns/2', creation_time=1768731280356, experiment_id='2', last_update_time=1768731280356, lifecycle_stage='active', name='Male_Demo_GradCAM', tags={}>

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network, load the trained weights onto DEVICE, and switch to eval
# mode (fixed batch-norm statistics / no dropout) before running the demo.
model = BoneAgeModel().to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute code; the file is only ever consumed as a
# state_dict, so nothing else needs to be deserialised.
model.load_state_dict(
    torch.load("male_boneage_model.pth", map_location=DEVICE, weights_only=True)
)
model.eval()

# Grad-CAM hooked on the last module of the model's `cnn` stack.
cam = GradCAM(model, model.cnn[-1])




In [4]:
# Demo split, iterated in file order one image at a time (the evaluation loop
# calls .item() on each batch's tensors, which requires batch_size=1).
DEMO_CSV, DEMO_IMAGE_DIR = "demo/demo.csv", "demo/images"
demo_ds = BoneAgeEvalDataset(DEMO_CSV, DEMO_IMAGE_DIR)
demo_loader = DataLoader(dataset=demo_ds, shuffle=False, batch_size=1)


In [None]:
def enhance(img, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Convert a float grayscale image to uint8 and apply CLAHE.

    Args:
        img: single-channel float array, expected in [0, 1]; values outside
            that range are clipped. (Presumably the de-batched dataset image —
            TODO confirm the dataset does not normalise to another range.)
        clip_limit: CLAHE contrast clip limit; default matches the original 2.0.
        tile_grid_size: CLAHE tile grid size; default matches the original (8, 8).

    Returns:
        uint8 array of the same shape with locally equalised contrast.
    """
    # Clip before scaling: casting out-of-range floats to uint8 wraps around
    # (and is platform-dependent for negatives), producing speckled artefacts.
    img_u8 = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(img_u8)

# One MLflow run for the whole demo set: per image, predict group/age, build a
# Grad-CAM overlay for the predicted group, and log the image plus metrics.
with mlflow.start_run(run_name="male_demo_gradcam"):
    for i, (img, gt_grp, gt_age) in enumerate(demo_loader):
        img = img.to(DEVICE)

        logits, unc = model(img)
        # .item() calls below assume batch_size=1.
        pred_grp = torch.argmax(logits, 1).item()
        # `unc` is split into a mean and a log-variance half along dim 1.
        mu, lv = unc.chunk(2, 1)
        pred_age = mu.item()
        pred_sigma = torch.exp(0.5 * lv).item()  # sigma = exp(log_var / 2)

        cam_map = cam.generate(img, pred_grp)
        base = enhance(img.squeeze().cpu().numpy())

        # Clip before the uint8 cast so out-of-range CAM values cannot wrap around.
        cam_u8 = (np.clip(cam_map, 0.0, 1.0) * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
        # addWeighted needs identical shapes; cv2.resize takes (width, height).
        heatmap = cv2.resize(heatmap, (base.shape[1], base.shape[0]))

        overlay = cv2.addWeighted(
            cv2.cvtColor(base, cv2.COLOR_GRAY2BGR),
            0.6, heatmap, 0.4, 0
        )

        out = f"demo_{i}.png"
        # imwrite reports failure via its return value, not an exception; without
        # this check a stale demo_{i}.png from an earlier run could be logged.
        if not cv2.imwrite(out, overlay):
            raise OSError(f"cv2.imwrite failed to write {out}")
        mlflow.log_artifact(out)

        # step=i gives one point per demo image instead of piling every value on step 0.
        mlflow.log_metric("predicted_age", pred_age, step=i)
        mlflow.log_metric("predicted_sigma", pred_sigma, step=i)
        mlflow.log_metric("predicted_group", pred_grp, step=i)
        mlflow.log_metric("ground_truth_group", gt_grp.item(), step=i)
        mlflow.log_metric("ground_truth_age", gt_age.item(), step=i)

        print(
            f"GT age: {gt_age.item():.1f} | "
            f"Pred age: {pred_age:.1f} ± {pred_sigma:.1f} | "
            f"GT grp: {gt_grp.item()} | Pred grp: {pred_grp}"
        )


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


GT age: 60.0 | Pred age: 85.7 ± 13.7 | GT grp: 1 | Pred grp: 2
GT age: 168.0 | Pred age: 100.9 ± 69.9 | GT grp: 3 | Pred grp: 3
GT age: 174.0 | Pred age: 101.7 ± 75.2 | GT grp: 3 | Pred grp: 3
GT age: 144.0 | Pred age: 96.0 ± 54.5 | GT grp: 3 | Pred grp: 3
GT age: 108.0 | Pred age: 96.0 ± 30.8 | GT grp: 2 | Pred grp: 2
GT age: 28.0 | Pred age: 87.7 ± 69.5 | GT grp: 1 | Pred grp: 0
GT age: 138.0 | Pred age: 99.2 ± 33.2 | GT grp: 3 | Pred grp: 2
GT age: 108.0 | Pred age: 89.3 ± 13.5 | GT grp: 2 | Pred grp: 2
GT age: 165.0 | Pred age: 97.5 ± 48.9 | GT grp: 3 | Pred grp: 2
GT age: 138.0 | Pred age: 97.2 ± 31.4 | GT grp: 3 | Pred grp: 2
GT age: 138.0 | Pred age: 97.7 ± 45.9 | GT grp: 3 | Pred grp: 3
GT age: 150.0 | Pred age: 95.0 ± 30.7 | GT grp: 3 | Pred grp: 2
GT age: 162.0 | Pred age: 97.7 ± 56.7 | GT grp: 3 | Pred grp: 3
GT age: 66.0 | Pred age: 83.5 ± 22.1 | GT grp: 1 | Pred grp: 1
GT age: 168.0 | Pred age: 101.0 ± 73.2 | GT grp: 3 | Pred grp: 3
GT age: 162.0 | Pred age: 101.2 ± 84.2 |

: 