# RetFound Test

In [4]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from neural_network.vmamba import VisualMamba

def test_forward():
    batch_size = 2
    img_size = 224
    num_classes = 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VisualMamba(img_size=img_size, num_classes=num_classes, mask_ratio=0.75).to(device)
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)

    logits = model(x)
    print("Logits shape:", logits.shape)
    assert logits.shape == (batch_size, num_classes)

def test_training_step():
    if not torch.cuda.is_available():
        print("‚ö†Ô∏è Skipping training test: Mamba requires CUDA")
        return

    batch_size = 4
    img_size = 224
    num_classes = 10

    x = torch.randn(16, 3, img_size, img_size, device="cuda")
    y = torch.randint(0, num_classes, (16,), device="cuda")
    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model = VisualMamba(img_size=img_size, num_classes=num_classes, mask_ratio=0.75).cuda()

    trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1, fast_dev_run=True)
    trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)

    print("Training loop ran successfully ‚úÖ")

if __name__ == "__main__":
    test_forward()
    test_training_step()


üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RAN

Logits shape: torch.Size([2, 10])



  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | model       | Sequential | 26.0 K | train
1 | patch_embed | Conv2d     | 147 K  | train
2 | backbone    | Sequential | 6.0 M  | train
3 | norm        | LayerNorm  | 384    | train
4 | head        | Linear     | 1.9 K  | train
---------------------------------------------------
6.2 M     Trainable params
0         Non-trainable params
6.2 M     Total params
24.850    Total estimated model params size (MB)
176       Modules in train mode
0         Modules in eval mode
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/d

Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00,  1.87it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00,  1.86it/s]
Training loop ran successfully ‚úÖ


# Distillation Test

In [None]:
# main.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

from distillation.dist import EmbeddingDataset, distill_embeddings
from models.retfound import RETFoundClassifier
from models.vmamba import VisualMamba

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher = RETFoundClassifier(
    checkpoint_path="checkpoints/RETFound_cfp_weights.pth"
).to(device)

student = VisualMamba(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=192,
    depth=24,
    learning_rate=1e-4
).to(device)

projector = nn.Linear(192, 1024).to(device)

dummy_data = torch.randn(1000, 3, 224, 224)
embed_loader = DataLoader(EmbeddingDataset(dummy_data), batch_size=32, shuffle=True)

# ---------------- phase 1 ----------------
distill_embeddings(
    teacher=teacher,
    student=student,
    projector=projector,
    dataloader=embed_loader,
    device=device,
    epochs=1,
)

torch.save(student.state_dict(), "vmamba_distilled.pth")

# ================================
# phase 2 lightning finetuning
# ================================

# # freeze backbone ‚Äî only train heads
# for name, p in student.named_parameters():
#     if not name.startswith("heads."):
#         p.requires_grad_(False)

# # your real labeled dataset
# train_loader = DataLoader(real_dataset, batch_size=16, shuffle=True)
# val_loader = DataLoader(real_val, batch_size=16)

# trainer = Trainer(
#     max_epochs=10,
#     accelerator="gpu" if torch.cuda.is_available() else "cpu",
# )

# trainer.fit(student, train_loader, val_loader)

# torch.save(student.state_dict(), "vmamba_heads_trained.pth")

Missing keys after loading checkpoint: ['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias']
[Distill] epoch 1  loss=0.996899


In [None]:
import torch
import torch.nn as nn
from distillation.dist import distill_embeddings, train_head
from models.retfound import RETFoundClassifier
from models.vmamba import VisualMamba
from dataloader.idrid import IDRiDModule
from torchvision import transforms


# ---------------- setup ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher = RETFoundClassifier(
    checkpoint_path="checkpoints/RETFound_cfp_weights.pth"
).to(device)

student = VisualMamba(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=192,
    depth=24,
    learning_rate=1e-4
).to(device)

projector = nn.Linear(192, 1024).to(device)


# ---------------- data ----------------
tfm = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

dm = IDRiDModule(
    root="./data/aaryapatel98/indian-diabetic-retinopathy-image-dataset/versions/1/B.%20Disease%20Grading/B. Disease Grading",
    transform=tfm,
    batch_size=2
)
dm.setup()
train_loader = dm.train_dataloader()
val_loader   = dm.val_dataloader()


# ---------------- phase 1 ----------------
distill_embeddings(
    teacher=teacher,
    student=student,
    projector=projector,
    dataloader=train_loader,
    device=device,
    epochs=1,
)

# ---------------- phase 2 ----------------
# train_head(
#     student=student,
#     dataloader=train_loader,
#     device=device,
#     epochs=1,
# )


# # save student
# student.eval()
# torch.save(student.state_dict(), "./checkpoints/vmamba_distilled.pth")
# print("saved ./checkpoints/vmamba_distilled.pth")

Missing keys after loading checkpoint: ['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias']
Found 413 images for split=train
Found 103 images for split=test
[Distill] epoch 1  loss=0.545846


In [13]:
import pandas as pd
from config.constants import IDRID_PATH
from dataloader.idrid import IDRiDDataset
from torchvision import transforms

# --- CONFIG ---

tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# --- Load dataset ---
ds = IDRiDDataset(IDRID_PATH, split="train", transform=tfm)

# --- Build DataFrame ---
records = []
for p in ds.img_paths:
    img_id = p.split("/")[-1].replace(".jpg", "")
    label = ds.grade_map[img_id]
    records.append({"path": p, "label": label})

df = pd.DataFrame(records)

df.head(), df.label.value_counts()


Found 413 images for split=train


(                                                path  label
 0  ./data/aaryapatel98/indian-diabetic-retinopath...      3
 1  ./data/aaryapatel98/indian-diabetic-retinopath...      3
 2  ./data/aaryapatel98/indian-diabetic-retinopath...      2
 3  ./data/aaryapatel98/indian-diabetic-retinopath...      3
 4  ./data/aaryapatel98/indian-diabetic-retinopath...      4,
 label
 2    136
 0    134
 3     74
 4     49
 1     20
 Name: count, dtype: int64)

In [4]:
import torch
from torchvision import transforms
from config.constants import *
from dataloader.idrid import IDRiDDataset
from models.vmamba_backbone import VisualMamba
import pandas as pd

# ------------------------
# Set device
# ------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------------
# 1. Load Dataset
# ------------------------
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

ds = IDRiDDataset(IDRID_PATH, split="train", transform=tfm)
x, y = ds[0]    # one sample
x = x.unsqueeze(0).to(device)   # (1, 3, 224, 224)

print("Image shape:", x.shape)

# ------------------------
# 2. Create VisualMamba
# ------------------------
model = VisualMamba(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=IN_CHANS,
    embed_dim=VMAMBA_EMBED_DIM,
    depth=VMAMBA_DEPTH,
    learning_rate=0.0,
    mask_ratio=0.6,
)
model = model.to(device)
model.eval()

# ------------------------
# 3. Forward with masking
# ------------------------
seq, mask, ids_keep, ids_restore = model.forward_features(
    x, 
    return_pooled=False,
    apply_mask=True
)
print("\n--- VISUAL MAMBA MASKING TEST ---")
print("Patch seq shape:", seq.shape)
print("Mask shape:", mask.shape)
print("ids_keep shape:", ids_keep.shape)
print("ids_restore shape:", ids_restore.shape)

# ------------------------
# 4. Verify Mask Ratio
# ------------------------
B, N = mask.shape
num_visible = mask.sum().item()
num_masked = N - num_visible
print(f"\nTotal tokens: {N}")
print(f"Visible tokens: {num_visible}")
print(f"Masked tokens: {num_masked}")
print(f"Mask ratio observed: {num_masked/N:.2f} (target: 0.60)")

# ------------------------
# 5. Check ordering correctness
# ------------------------
seq_reconstructed = torch.gather(
    seq,
    dim=1,
    index=ids_restore.unsqueeze(-1).expand(-1, -1, model.embed_dim)
)

print("\nReconstructed seq shape:", seq_reconstructed.shape)

# ------------------------
# 6. Check pooled output correctness
# ------------------------
pooled = model.forward_features(
    x,
    return_pooled=True,
    apply_mask=True
)

print("Pooled masked feature shape:", pooled.shape)


Using device: cuda
Found 413 images for split=train
Image shape: torch.Size([1, 3, 224, 224])

--- VISUAL MAMBA MASKING TEST ---
Patch seq shape: torch.Size([1, 196, 192])
Mask shape: torch.Size([1, 196])
ids_keep shape: torch.Size([1, 79])
ids_restore shape: torch.Size([1, 196])

Total tokens: 196
Visible tokens: 79.0
Masked tokens: 117.0
Mask ratio observed: 0.60 (target: 0.60)

Reconstructed seq shape: torch.Size([1, 196, 192])
Pooled masked feature shape: torch.Size([1, 192])


In [10]:
# predict_vmamba_idrid.py

import os
import torch
import pandas as pd
import pytorch_lightning as pl
from torchvision import transforms
from config.constants import *
from utils.vmamba_idrid import VmambaClassifier
from dataloader.idrid import IDRiDModule


MODEL_PATH = "vmamba_full_supervised/gg6er7wp/checkpoints/vmamba_supervised_best.ckpt"
CSV_OUT = "vmamba_idrid_predictions.csv"


def run_prediction():
    pl.seed_everything(42)

    # Transform ‚Äî MUST match training
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Test datamodule
    dm = IDRiDModule(
        root=IDRID_PATH,
        transform=tfm,
        batch_size=1,
    )
    dm.setup(stage="test")

    # --------------------------------------------------------
    # LOAD MODEL SAFELY
    # --------------------------------------------------------
    print(f"Loading checkpoint: {MODEL_PATH}")

    model = VmambaClassifier.load_from_checkpoint(
        MODEL_PATH,
        lr=1e-4,
        class_weights=None,
        strict=False
    )

    model.eval()
    model.freeze()

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
    )

    print("Running predictions...")
    outputs = trainer.predict(model, datamodule=dm)

    # --------------------------------------------------------
    # FLATTEN PREDICTIONS
    # --------------------------------------------------------
    paths, labels, preds, all_probs = [], [], [], []

    for batch in outputs:
        paths.extend(batch["paths"])
        labels.extend(batch["labels"].tolist())
        preds.extend(batch["preds"].tolist())
        all_probs.extend(batch["probs"].tolist())

    # --------------------------------------------------------
    # SAVE TO CSV
    # --------------------------------------------------------
    df = pd.DataFrame({
        "image_path": paths,
        "true_label": labels,
        "pred_label": preds,
    })

    probs_df = pd.DataFrame(
        all_probs,
        columns=[f"prob_{i}" for i in range(len(all_probs[0]))],
    )

    df = pd.concat([df, probs_df], axis=1)
    df.to_csv(CSV_OUT, index=False)

    print(f"[‚úì] Saved predictions to {CSV_OUT}")


if __name__ == "__main__":
    run_prediction()

Seed set to 42


Found 413 images for split=train
Found 103 images for split=test
Loading checkpoint: vmamba_full_supervised/gg6er7wp/checkpoints/vmamba_supervised_best.ckpt


/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['class_weights', 'loss_fn.weight']
üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Running predictions...
Found 413 images for split=train
Found 103 images for split=test


Predicting: |          | 0/? [00:00<?, ?it/s]

[‚úì] Saved predictions to vmamba_idrid_predictions.csv


In [3]:
import os
import pandas as pd
from dataloader.idrid import IDRiDDataset, IDRiDModule
from config.constants import IDRID_PATH, IMG_SIZE
from torchvision import transforms
from collections import Counter

print("\n=== RUNNING IDRiD DIAGNOSTICS ===\n")

# Same transform used in training
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# Load your datamodule exactly as before
dm = IDRiDModule(root=IDRID_PATH, transform=tfm, batch_size=16)
dm.setup()

train_ds = dm.train
val_ds   = dm.val

# ------------------------------------
# 1. PRINT DATASET SIZES
# ------------------------------------
print(f"Train set size: {len(train_ds)}")
print(f"Val/Test set size: {len(val_ds)}")

# ------------------------------------
# 2. CHECK LABEL DISTRIBUTION
# ------------------------------------
def count_labels(ds):
    labels = []
    for i in range(len(ds)):
        _, y, _ = ds[i]
        labels.append(int(y))
    return Counter(labels)

print("\nTrain label distribution:", count_labels(train_ds))
print("Val/Test label distribution:", count_labels(val_ds))

# ------------------------------------
# 3. PRINT FIRST 10 VAL LABELS
# ------------------------------------
print("\nFirst 10 validation labels:")
for i in range(min(10, len(val_ds))):
    _, y, p = val_ds[i]
    print(os.path.basename(p), "‚Üí", y)

# ------------------------------------
# 4. CHECK IF TEST CSV HAS REAL LABELS
# ------------------------------------
test_csv = os.path.join(
    IDRID_PATH, "2. Groundtruths", "b. IDRiD_Disease Grading_Testing Labels.csv"
)
if os.path.exists(test_csv):
    df_test = pd.read_csv(test_csv)
    print("\nTest CSV head:")
    print(df_test.head())
    print("\nUnique 'Retinopathy grade' values in Test CSV:", df_test["Retinopathy grade"].unique())

# ------------------------------------
# 5. CHECK IF ANY VAL LABEL IS NaN ‚Üí FORCED TO 0
# ------------------------------------
nan_count = 0
for name, grade in val_ds.grade_map.items():
    if pd.isna(grade):
        nan_count += 1

print(f"\nNaN labels detected in test CSV: {nan_count}")

if nan_count > 0:
    print("‚ö†Ô∏è  WARNING: Test CSV contains NaN labels ‚Üí all NaN become class 0")
else:
    print("No NaN labels in validation mapping.")

print("\n=== DIAGNOSTICS COMPLETE ===\n")



=== RUNNING IDRiD DIAGNOSTICS ===

Found 413 images for split=train
Found 103 images for split=test
Train set size: 413
Val/Test set size: 103

Train label distribution: Counter({2: 136, 0: 134, 3: 74, 4: 49, 1: 20})
Val/Test label distribution: Counter({0: 34, 2: 32, 3: 19, 4: 13, 1: 5})

First 10 validation labels:
IDRiD_001.jpg ‚Üí 4
IDRiD_002.jpg ‚Üí 4
IDRiD_003.jpg ‚Üí 4
IDRiD_004.jpg ‚Üí 4
IDRiD_005.jpg ‚Üí 4
IDRiD_006.jpg ‚Üí 3
IDRiD_007.jpg ‚Üí 3
IDRiD_008.jpg ‚Üí 2
IDRiD_009.jpg ‚Üí 2
IDRiD_010.jpg ‚Üí 2

Test CSV head:
  Image name  Retinopathy grade  Risk of macular edema 
0  IDRiD_001                  4                       0
1  IDRiD_002                  4                       1
2  IDRiD_003                  4                       0
3  IDRiD_004                  4                       0
4  IDRiD_005                  4                       1

Unique 'Retinopathy grade' values in Test CSV: [4 3 2 0 1]

NaN labels detected in test CSV: 0
No NaN labels in validation mapp

In [4]:
from utils.vmamba_idrid import VmambaClassifier

model = VmambaClassifier.load_from_checkpoint(MODEL_PATH)
model.eval()

print("predict_step source:\n")
import inspect
print(inspect.getsource(model.predict_step))


predict_step source:

    def predict_step(self, batch, batch_idx):
        x, y, paths = batch
        logits = self(x)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)

        return {
            "paths": paths,
            "labels": y.cpu(),
            "preds": preds.cpu(),
            "probs": probs.cpu(),
        }



In [7]:
import torch

ckpt = torch.load("vmamba_full_supervised/gg6er7wp/checkpoints/vmamba_supervised_best.ckpt", map_location="cpu")

print("Keys in checkpoint:")
print(list(ckpt.keys()))

if "state_dict" in ckpt:
    print("\nState dict keys:")
    for k in list(ckpt["state_dict"].keys())[:50]:
        print(k)
else:
    print("\nRaw top-level keys:")
    for k in list(ckpt.keys())[:50]:
        print(k)


Keys in checkpoint:
['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision', 'hparams_name', 'hyper_parameters']

State dict keys:
class_weights
backbone.mask_token
backbone.patch_embed.weight
backbone.patch_embed.bias
backbone.backbone.0.A_log
backbone.backbone.0.D
backbone.backbone.0.in_proj.weight
backbone.backbone.0.conv1d.weight
backbone.backbone.0.conv1d.bias
backbone.backbone.0.x_proj.weight
backbone.backbone.0.dt_proj.weight
backbone.backbone.0.dt_proj.bias
backbone.backbone.0.out_proj.weight
backbone.backbone.1.A_log
backbone.backbone.1.D
backbone.backbone.1.in_proj.weight
backbone.backbone.1.conv1d.weight
backbone.backbone.1.conv1d.bias
backbone.backbone.1.x_proj.weight
backbone.backbone.1.dt_proj.weight
backbone.backbone.1.dt_proj.bias
backbone.backbone.1.out_proj.weight
backbone.backbone.2.A_log
backbone.backbone.2.D
backbone.backbone.2.in_proj.weight
backbone.backbone.2.conv1d.weight
ba

In [8]:
model = VmambaClassifier.load_from_checkpoint(MODEL_PATH, lr=1e-4, class_weights=None, strict=False)

# Count how many params successfully loaded
ckpt = torch.load(MODEL_PATH, map_location="cpu")["state_dict"]
model_state = model.state_dict()

matched = 0
for k in ckpt:
    if k in model_state and ckpt[k].shape == model_state[k].shape:
        matched += 1

print("Matched weight tensors:", matched)
print("Total in checkpoint:", len(ckpt))


Matched weight tensors: 223
Total in checkpoint: 225


/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['class_weights', 'loss_fn.weight']


In [13]:
import torch
from torch.utils.data import DataLoader
from dataloader.idrid import IDRiDDataset
from config.constants import IDRID_PATH, IMG_SIZE
from torchvision import transforms

print("\n=== RUNNING MODEL + DATALOADER DIAGNOSTICS ===\n")

# ------------------------------------
# 0. BUILD THE SAME TRANSFORM
# ------------------------------------
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# ------------------------------------
# 1. BUILD PREDICT DATASET EXACTLY LIKE trainer.predict()
# ------------------------------------
predict_ds = IDRiDDataset(
    root=IDRID_PATH,
    split="test",       # or "val" depending on your use
    transform=tfm
)

predict_loader = DataLoader(
    predict_ds,
    batch_size=16,
    shuffle=False,
    num_workers=4
)

print(f"Predict dataset size: {len(predict_ds)}")

# ------------------------------------
# 2. LOAD ONE BATCH
# ------------------------------------
batch = next(iter(predict_loader))
x, y, paths = batch

print(f"\nLoaded batch of {x.shape[0]} images")

# ------------------------------------
# 3. MOVE BATCH TO MODEL DEVICE
# ------------------------------------
device = next(model.parameters()).device
x = x.to(device)

# ------------------------------------
# 4. RUN MODEL ON BATCH
# ------------------------------------
model.eval()
with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    preds = torch.argmax(probs, dim=1)

print("\n=== SAMPLE OUTPUT ===")
print("True labels:", y.tolist())
print("Pred labels:", preds.cpu().tolist())
print("Sample probs:", probs[0].cpu().tolist())

# ------------------------------------
# 5. CHECK IF ALL PROBS ARE IDENTICAL
# ------------------------------------
all_same = torch.allclose(
    probs.cpu(), 
    probs[0].cpu().unsqueeze(0).expand_as(probs.cpu())
)
print("\nAll probs identical across batch?:", all_same)

# ------------------------------------
# 6. CHECK CLASSIFIER HEAD PARAMETERS
# ------------------------------------
print("\n=== CLASSIFIER HEAD PARAM STATS ===")
for name, p in model.named_parameters():
    if "head" in name.lower() or "fc" in name.lower():
        print(name, " | mean:", float(p.data.mean()),
              " | std:", float(p.data.std()),
              " | grad:", p.requires_grad)

# ------------------------------------
# 7. CHECK CHECKPOINT ACTUALLY LOADED NON-ZERO WEIGHTS
# ------------------------------------
sd = model.state_dict()
total_nonzero = sum((v != 0).sum().item() for v in sd.values())
print("\nNon-zero parameters in model:", total_nonzero)

# ------------------------------------
# 8. COUNT TRAINABLE VS FROZEN PARAMS
# ------------------------------------
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"\nTrainable params: {trainable}")
print(f"Frozen params:    {frozen}")

# ------------------------------------
# 9. SHOW CLASSIFIER WEIGHT SAMPLE
# ------------------------------------
for name, p in model.named_parameters():
    if "head" in name.lower():
        print("\nHead weight sample:", p.flatten()[:20])
        break

print("\n=== DIAGNOSTICS COMPLETE ===\n")



=== RUNNING MODEL + DATALOADER DIAGNOSTICS ===

Found 103 images for split=test
Predict dataset size: 103



Loaded batch of 16 images

=== SAMPLE OUTPUT ===
True labels: [4, 4, 4, 4, 4, 3, 3, 2, 2, 2, 2, 2, 3, 3, 2, 3]
Pred labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Sample probs: [0.20288433134555817, 0.19833576679229736, 0.19640283286571503, 0.20164941251277924, 0.20072759687900543]

All probs identical across batch?: True

=== CLASSIFIER HEAD PARAM STATS ===
head.weight  | mean: -0.002200536197051406  | std: 0.03709734231233597  | grad: True
head.bias  | mean: -0.025173192843794823  | std: 0.020051641389727592  | grad: True

Non-zero parameters in model: 6176453

Trainable params: 6185669
Frozen params:    0

Head weight sample: tensor([-0.0545,  0.0025,  0.0117,  0.0068, -0.0585,  0.0149, -0.0204,  0.0328,
         0.0521, -0.0157,  0.0304, -0.0163, -0.0148, -0.0668,  0.0422,  0.0294,
         0.0666,  0.0270,  0.0706,  0.0269], device='cuda:0',
       grad_fn=<SliceBackward0>)

=== DIAGNOSTICS COMPLETE ===



In [14]:
# 1) inspect checkpoint keys
ckpt = torch.load(MODEL_PATH, map_location="cpu")
print("\n--- Checkpoint root keys ---")
print(ckpt.keys())

print("\n--- First 50 keys inside state_dict ---")
for k in list(ckpt["state_dict"].keys())[:50]:
    print(k)

# 2) compare to your model keys
print("\n--- First 50 keys in current model state_dict ---")
for k in list(model.state_dict().keys())[:50]:
    print(k)



--- Checkpoint root keys ---
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision', 'hparams_name', 'hyper_parameters'])

--- First 50 keys inside state_dict ---
class_weights
backbone.mask_token
backbone.patch_embed.weight
backbone.patch_embed.bias
backbone.backbone.0.A_log
backbone.backbone.0.D
backbone.backbone.0.in_proj.weight
backbone.backbone.0.conv1d.weight
backbone.backbone.0.conv1d.bias
backbone.backbone.0.x_proj.weight
backbone.backbone.0.dt_proj.weight
backbone.backbone.0.dt_proj.bias
backbone.backbone.0.out_proj.weight
backbone.backbone.1.A_log
backbone.backbone.1.D
backbone.backbone.1.in_proj.weight
backbone.backbone.1.conv1d.weight
backbone.backbone.1.conv1d.bias
backbone.backbone.1.x_proj.weight
backbone.backbone.1.dt_proj.weight
backbone.backbone.1.dt_proj.bias
backbone.backbone.1.out_proj.weight
backbone.backbone.2.A_log
backbone.backbone.2.D
backbone.backbone.2.in_proj

In [15]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from dataloader.idrid import IDRiDDataset
from config.constants import IDRID_PATH, IMG_SIZE
from torchvision import transforms

print("=== BACKBONE INTERNAL DIAGNOSTICS ===\n")

# build dataset + loader (match what you used)
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])
ds = IDRiDDataset(root=IDRID_PATH, split="test", transform=tfm)
loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=4)

# get one batch
batch = next(iter(loader))
x_cpu, y, paths = batch
print("Loaded batch size:", x_cpu.shape)

# move to model device
device = next(model.parameters()).device
x = x_cpu.to(device)

# 1) input per-image stats (quick sanity)
pixel_means = x.view(x.shape[0], -1).mean(dim=1).cpu().tolist()
pixel_stds  = x.view(x.shape[0], -1).std(dim=1).cpu().tolist()
print("\nInput pixel means (per image):", np.round(pixel_means, 6))
print("Input pixel stds  (per image):", np.round(pixel_stds, 6))

# 2) patch_embed conv output stats
with torch.no_grad():
    pe = model.backbone.patch_embed(x)   # shape (B, D, H', W')
    B, D, H_, W_ = pe.shape
    pe_flat = pe.flatten(2).transpose(1,2)   # (B, N, D)
    pe_pool = pe_flat.mean(dim=1)            # per-sample pooled vector

print("\nPatch-embed shape:", pe.shape, "-> flattened:", pe_flat.shape)
print("Patch-embed pooled mean per sample:", np.round(pe_pool.mean(dim=1).cpu().numpy(), 6))
print("Patch-embed pooled std per sample:", np.round(pe_pool.std(dim=1).cpu().numpy(), 6))

# are pe_pool vectors identical?
identical_pe = torch.allclose(pe_pool, pe_pool[0].unsqueeze(0).expand_as(pe_pool))
print("Patch-embed pooled identical across batch?:", bool(identical_pe))

# 3) backbone output BEFORE norm
with torch.no_grad():
    # forward through backbone modules manually to inspect intermediate
    x_seq = pe_flat  # (B, N, D)
    # if backbone is nn.Sequential of Mamba blocks:
    for i, block in enumerate(model.backbone.backbone if hasattr(model.backbone, "backbone") else model.backbone):
        x_seq = block(x_seq)
        if i == 0:
            first_block_out = x_seq.detach().cpu()
    # after full backbone, run norm
    x_seq_after = model.backbone.norm(x_seq)  # (B, N, D)
    pooled = x_seq_after.mean(dim=1)          # final pooled features

print("\nBackbone output shape:", x_seq_after.shape)
print("Pooled features mean per sample:", np.round(pooled.mean(dim=1).cpu().numpy(), 6))
print("Pooled features std  per sample:", np.round(pooled.std(dim=1).cpu().numpy(), 6))
identical_pooled = torch.allclose(pooled, pooled[0].unsqueeze(0).expand_as(pooled))
print("Pooled features identical across batch?:", bool(identical_pooled))

# 4) Check first-block outputs variance
print("\nFirst backbone block output stats (sample):")
print(" mean per sample:", np.round(first_block_out.mean(dim=(1,2)).numpy(), 6))
print(" std per sample: ", np.round(first_block_out.std(dim=(1,2)).numpy(), 6))
identical_first_block = bool(torch.allclose(first_block_out, first_block_out[0].unsqueeze(0).expand_as(first_block_out)))
print("First block output identical across batch?:", identical_first_block)

# 5) Parameter sanity for patch_embed and first Mamba block
print("\nParameter stats:")
pe_w = model.backbone.patch_embed.weight.detach().cpu()
print(" patch_embed.weight mean:", float(pe_w.mean()), " std:", float(pe_w.std()))
print(" patch_embed.weight any NaN?", bool(torch.isnan(pe_w).any()))
print(" patch_embed.bias any NaN?", bool(torch.isnan(model.backbone.patch_embed.bias.detach().cpu()).any()))

# try printing a sample param from first block (if exists)
first_param_found = False
for n, p in model.named_parameters():
    if n.startswith("backbone.backbone.0"):
        print(f" param {n} | mean {float(p.data.mean()):.6f} std {float(p.data.std()):.6f}")
        first_param_found = True
        break
if not first_param_found:
    print(" No param path backbone.backbone.0.* found (structure differs).")

# 6) detect any NaNs/Infs in pooled features
has_nan = torch.isnan(pooled).any().item()
has_inf = torch.isinf(pooled).any().item()
print("\npooled has NaN?", has_nan, " pooled has Inf?", has_inf)

# 7) quick compare two different images individually
with torch.no_grad():
    f0 = pooled[0].cpu().numpy()
    f1 = pooled[1].cpu().numpy()
    cos = np.dot(f0, f1) / (np.linalg.norm(f0) * np.linalg.norm(f1) + 1e-12)
print("\ncosine between pooled feat[0] and feat[1]:", float(cos))

print("\n=== DIAGNOSTICS DONE ===")


=== BACKBONE INTERNAL DIAGNOSTICS ===

Found 103 images for split=test
Loaded batch size: torch.Size([8, 3, 224, 224])

Input pixel means (per image): [0.097906 0.147802 0.184873 0.240009 0.158269 0.223339 0.224439 0.263004]
Input pixel stds  (per image): [0.129837 0.193107 0.231126 0.215861 0.160772 0.218617 0.268931 0.258152]

Patch-embed shape: torch.Size([8, 192, 14, 14]) -> flattened: torch.Size([8, 196, 192])
Patch-embed pooled mean per sample: [0.00121  0.001919 0.002437 0.007141 0.004412 0.005794 0.003175 0.005639]
Patch-embed pooled std per sample: [0.072899 0.106935 0.132436 0.141808 0.097445 0.138957 0.157091 0.16712 ]
Patch-embed pooled identical across batch?: False

Backbone output shape: torch.Size([8, 196, 192])
Pooled features mean per sample: [-0.000123 -0.000123 -0.000123 -0.000123 -0.000123 -0.000123 -0.000123
 -0.000123]
Pooled features std  per sample: [0.008024 0.008024 0.008024 0.008024 0.008024 0.008024 0.008024 0.008024]
Pooled features identical across batch?