In [4]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from neural_network.vmamba import VisualMamba

def test_forward():
    """Smoke-test a single forward pass of ``VisualMamba``.

    Builds the model on CUDA when available (CPU otherwise), feeds it one
    random batch of ``batch_size`` RGB images of size ``img_size`` and checks
    that the returned logits have shape ``(batch_size, num_classes)``.

    Raises:
        AssertionError: If the logits do not have the expected shape.
    """
    batch_size = 2
    img_size = 224
    num_classes = 10
    # NOTE(review): test_training_step skips without CUDA ("Mamba requires CUDA");
    # this CPU fallback may fail for the same reason — confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VisualMamba(img_size=img_size, num_classes=num_classes).to(device)
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)

    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(x)

    print("Logits shape:", logits.shape)
    expected_shape = (batch_size, num_classes)
    assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}"

def test_training_step():
    """Smoke-test a Lightning fit loop of ``VisualMamba`` on synthetic data.

    Skips (prints a notice and returns ``None``) when CUDA is unavailable.
    Otherwise builds a small random classification dataset and runs
    ``Trainer.fit`` in ``fast_dev_run`` mode, using the same dataloader for
    training and validation.

    Any exception raised by the model or the Trainer propagates to the caller.
    """
    if not torch.cuda.is_available():
        print("⚠️ Skipping training test: Mamba requires CUDA")
        return

    batch_size = 4
    img_size = 224
    num_classes = 10
    num_samples = 16

    # Keep the synthetic data on the CPU. The Trainer transfers each batch to
    # the GPU itself, and CUDA-resident datasets break as soon as the
    # DataLoader uses worker processes or pinned memory.
    x = torch.randn(num_samples, 3, img_size, img_size)
    y = torch.randint(0, num_classes, (num_samples,))
    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # No manual .cuda(): the Trainer places the model on the selected accelerator.
    model = VisualMamba(img_size=img_size, num_classes=num_classes)

    # fast_dev_run limits the run to a single batch and disables logging/checkpointing.
    trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1, fast_dev_run=True)
    trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)

    print("Training loop ran successfully ✅")

# Run both VisualMamba smoke tests when this cell/script is executed directly:
# the forward-shape check first, then the Lightning training check.
if __name__ == "__main__":
    test_forward()
    test_training_step()


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 

Logits shape: torch.Size([2, 10])



  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | model       | Sequential | 26.0 K | train
1 | patch_embed | Conv2d     | 147 K  | train
2 | backbone    | Sequential | 6.0 M  | train
3 | norm        | LayerNorm  | 384    | train
4 | head        | Linear     | 1.9 K  | train
---------------------------------------------------
6.2 M     Trainable params
0         Non-trainable params
6.2 M     Total params
24.850    Total estimated model params size (MB)
176       Modules in train mode
0         Modules in eval mode
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/andre/code/StudentMAE/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/d

Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s]
Training loop ran successfully ✅


In [None]:
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from neural_network.retfound import RETFoundClassifier  # adjust as needed

def test_retfound_forward(checkpoint_path: str):
    assert os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}"
    num_classes = 5
    batch_size = 2
    img_size = 224

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = RETFoundClassifier(num_classes=num_classes, checkpoint_path=checkpoint_path).to(device)
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    y = torch.randint(0, num_classes, (batch_size,), device=device)

    logits = model(x)
    print("Output shape:", logits.shape)
    assert logits.shape == (batch_size, num_classes), f"Expected {(batch_size, num_classes)}, got {logits.shape}"

    loss = model.training_step((x, y), batch_idx=0)
    print("Training loss:", loss.item())

    model.validation_step((x, y), batch_idx=0)

def test_retfound_training_loop(checkpoint_path: str):
    # Skip full training if no GPU (depending on your model’s GPU requirements)
    if not torch.cuda.is_available():
        print("⚠️ Skipping training loop test: no CUDA available")
        return

    num_classes = 5
    img_size = 224
    batch_size = 4

    x = torch.randn(32, 3, img_size, img_size, device="cuda")
    y = torch.randint(0, num_classes, (32,), device="cuda")
    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = RETFoundClassifier(num_classes=num_classes, checkpoint_path=checkpoint_path).cuda()

    trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1, fast_dev_run=True)
    trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)

    print("Training loop completed successfully")

# Run both RETFound smoke tests against a local checkpoint when this
# cell/script is executed directly: forward/step check first, then the fit loop.
if __name__ == "__main__":
    # Path to the RETFound weights file; change it to wherever the checkpoint is stored locally.
    # NOTE(review): the saved output of this cell reports a different path
    # (models/retfound/...), so it predates this value — re-run to refresh it.
    ckpt = "models/RETFound_cfp_weights.pth"
    test_retfound_forward(ckpt)
    test_retfound_training_loop(ckpt)


  from .autonotebook import tqdm as notebook_tqdm


AssertionError: Checkpoint not found at models/retfound/RETFound_cfp_weights.pth