In [1]:
# Imports
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Subset, random_split

In [2]:
# Data
from vesuvius.data.vesuvius_train import VesuviusTrainDataset
from vesuvius.data.subset_with_transform import SubsetWithTransform

# Transforms
from vesuvius.transforms.get_transforms import get_transformations

# Model / training
from vesuvius.models.unet3d_small import SmallUNet3D
from vesuvius.losses import DiceBCELoss
from vesuvius.training.train_loop import train_loop

# Visualization
from vesuvius.visualization.prediction_viz import plot_prediction_triplet

In [3]:
# Root folder of the raw Vesuvius data.
# Set the VESUVIUS_ROOT environment variable to point elsewhere instead of editing
# the notebook; otherwise the original per-user default under $HOME is used.
ROOT = Path(
    os.environ.get("VESUVIUS_ROOT", "~/vesuvius-scroll-detection/data/raw/vesuvius")
).expanduser()

print("ROOT:", ROOT)

ROOT: /Users/chamu/vesuvius-scroll-detection/data/raw/vesuvius


In [4]:
# Device selection: prefer a CUDA GPU, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)

Device: mps


In [5]:
# Load dataset + train/val split
ds = VesuviusTrainDataset(ROOT)
print("Dataset size:", len(ds))

val_fraction = 0.15
seed = 0

# Seed torch's global RNG so model initialisation, any DataLoader shuffling without
# an explicit generator, and any augmentation drawing from torch's global RNG are
# reproducible on Restart & Run All. (The split below uses its own generator.)
# NOTE(review): if the transforms use Python `random` or NumPy RNGs, seed those too.
torch.manual_seed(seed)

assert 0.0 < val_fraction < 1.0, f"val_fraction must be in (0, 1), got {val_fraction}"

n_total = len(ds)
n_val = max(1, int(val_fraction * n_total))  # always hold out at least one sample
n_train = n_total - n_val
# With a tiny dataset the split can leave nothing to train on; fail here with a
# clear message rather than later inside random_split or the DataLoader.
assert n_train > 0, f"Need at least 2 samples to make a train/val split, got {n_total}"

g = torch.Generator().manual_seed(seed)
train_subset, val_subset = random_split(ds, [n_train, n_val], generator=g)

print("Train subset:", len(train_subset))
print("Val subset:", len(val_subset))

Dataset size: 806
Train subset: 686
Val subset: 120


In [6]:
# Patch size passed as crop_size below.
# NOTE(review): axis order presumably (depth, height, width); the sample batch
# printed later has spatial shape (32, 96, 96) — confirm against get_transformations.
PATCH = (32, 96, 96)
print("PATCH:", PATCH)

# Build the validation and training pipelines in a single call.
# NOTE(review): the unpacking assumes the return order is (val, train), and the
# augmentation flags presumably affect only the training pipeline — confirm.
# mean/std are None, so no normalisation statistics are supplied here.
val_tf, train_tf = get_transformations(
    crop_size=PATCH,
    mean=None,
    std=None,
    use_random_crop_for_train=True,
    use_flips=True,
    use_intensity_jitter=True,
    use_gaussian_noise=True,
)

PATCH: (32, 96, 96)


In [7]:
# Wrap the split subsets with their transforms and build the loaders.
train_ds = SubsetWithTransform(train_subset, transform=train_tf)
val_ds   = SubsetWithTransform(val_subset, transform=val_tf)

# A dedicated seeded generator makes the training shuffle order reproducible from
# `seed`, regardless of how much global RNG earlier cells have consumed.
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(seed),
)

val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

# Sanity check one training batch: tensor shapes, sample id, and the mask values present.
imgs, masks, sids = next(iter(train_loader))
print("Train batch:", imgs.shape, masks.shape, "sid:", sids)
print("Mask unique:", torch.unique(masks))

Train batch: torch.Size([1, 1, 32, 96, 96]) torch.Size([1, 1, 32, 96, 96]) sid: ('1820528268',)
Mask unique: tensor([0., 1.])


In [8]:
# Model, loss, and optimizer
model = SmallUNet3D(
    in_channels=1,      # single-channel input volumes (batch shape [B, 1, D, H, W] above)
    base_channels=16,
).to(device)

# Equal weighting of the BCE and Dice terms.
loss_fn = DiceBCELoss(w_dice=0.5, w_bce=0.5)

# AdamW over all model parameters with light decoupled weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

In [9]:
# Train the model for 40 epochs; log_every and threshold are interpreted by train_loop
# (threshold presumably binarises predictions for the validation Dice — confirm).
# NOTE: the saved output shows this cell was interrupted (KeyboardInterrupt). In that
# case `history` is never assigned and the plotting cells below raise NameError.
history = train_loop(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=40,
    threshold=0.5,
    log_every=5,
)

KeyboardInterrupt: 

In [None]:
# Training curves: loss (train vs val), Dice, and a zoom on the tail of the train loss.
# NOTE(review): x-axis labels assume one history entry per epoch — confirm against train_loop.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].plot(history.train_loss, label="train")
axes[0].plot(history.val_loss, label="val")
axes[0].set(title="Loss", xlabel="epoch", ylabel="loss")
axes[0].legend()

axes[1].plot(history.dice)
axes[1].set(title="Dice", xlabel="epoch", ylabel="Dice")

# Actually zoom in: drop the first ~20% of entries, whose large early losses would
# otherwise dominate the autoscaled y-range. x-values keep their original indices.
n_loss = len(history.train_loss)
skip = n_loss // 5
axes[2].plot(range(skip, n_loss), history.train_loss[skip:])
axes[2].set(title="Train loss (zoomed)", xlabel="epoch", ylabel="loss")

fig.tight_layout()
plt.show()

In [None]:
# Visual sanity check of predictions on the (unshuffled) validation loader.
# NOTE: threshold 0.3 here is lower than the 0.5 passed to train_loop.
plot_prediction_triplet(
    model=model,
    device=device,
    loader=val_loader,
    title_prefix="Augmented model",
    threshold=0.3,
)

In [None]:
# Threshold sensitivity: one prediction plot per candidate binarisation threshold.
for thr in (0.2, 0.3, 0.5, 0.7):
    plot_prediction_triplet(
        model=model,
        device=device,
        loader=val_loader,
        title_prefix=f"thr={thr}",
        threshold=thr,
    )