In [1]:
# Imports
import os
from pathlib import Path

import torch

In [2]:
# Import own helpers
from vesuvius.training.train import make_dataloaders
from vesuvius.models.unet3d_small import SmallUNet3D

In [3]:
# Root folder path.
# Defaults to the project's raw-data folder under the home directory; set the
# VESUVIUS_DATA_ROOT environment variable to use another location without
# editing the notebook.
ROOT = Path(
    os.environ.get("VESUVIUS_DATA_ROOT", "~/vesuvius-scroll-detection/data/raw/vesuvius")
).expanduser()

print("ROOT:", ROOT)

# Surface a missing data folder here rather than as an obscure error from the
# dataloader cell further down.
if not ROOT.is_dir():
    print(f"WARNING: ROOT does not exist or is not a directory: {ROOT}")

ROOT: /Users/chamu/vesuvius-scroll-detection/data/raw/vesuvius


In [4]:
# Device selection: prefer an NVIDIA GPU (CUDA), then Apple-silicon GPU (MPS),
# and fall back to CPU when neither backend is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)

Device: mps


In [5]:
# Create dataloaders.
# Loader / split settings kept together so they are easy to find and tune.
BATCH_SIZE = 1
VAL_FRACTION = 0.15
SEED = 0

# `seed` is passed to make_dataloaders (presumably for the train/val split —
# confirm in vesuvius.training.train). Also seed torch's global RNG so anything
# that draws from the default generator afterwards — e.g. a shuffling sampler
# without its own generator, or the model weight init below — is reproducible
# under Restart & Run All.
torch.manual_seed(SEED)

train_loader, val_loader, train_ds, val_ds = make_dataloaders(
    data_root=ROOT,
    batch_size=BATCH_SIZE,
    val_fraction=VAL_FRACTION,
    seed=SEED,
)

print("Train samples:", len(train_ds))
print("Val samples:", len(val_ds))

# Fail fast: an empty split would only surface later as a StopIteration or a
# confusing metric.
assert len(train_ds) > 0, f"no training samples found under {ROOT}"
assert len(val_ds) > 0, "validation split is empty; check VAL_FRACTION"

Train samples: 686
Val samples: 120


In [6]:
# Draw a single batch from a fresh, throwaway iterator over the training loader.
imgs, masks, sids = next(iter(train_loader))

# Report shape and dtype of each tensor in the batch, then the sample ids.
print(f"imgs: {imgs.shape} {imgs.dtype}")
print(f"masks: {masks.shape} {masks.dtype}")
print(f"sids: {sids}")



imgs: torch.Size([1, 1, 320, 320, 320]) torch.float32
masks: torch.Size([1, 1, 320, 320, 320]) torch.float32
sids: ('3076490891',)


In [7]:
# Smoke-test one forward pass of the small 3D U-Net on the batch pulled above.
model = SmallUNet3D(base_channels=8).to(device)
model.eval()  # evaluation-mode behaviour for any dropout / norm layers in the model

# Move the batch onto the model's device (rebinds the names from the previous cell).
imgs = imgs.to(device)
masks = masks.to(device)

with torch.no_grad():  # no autograd graph is needed for a shape/sanity check
    logits = model(imgs)

print("logits:", logits.shape, logits.dtype)

assert logits.shape == masks.shape, (
    "logits and masks shape mismatch! "
    f"logits={tuple(logits.shape)} masks={tuple(masks.shape)}"
)
# Even an untrained network should emit finite values; NaN/Inf here points at
# a numerical problem in the model or in the input volume.
assert torch.isfinite(logits).all(), "logits contain NaN or Inf values!"
print("✅ Forward pass OK")

logits: torch.Size([1, 1, 320, 320, 320]) torch.float32
✅ Forward pass OK
