In [3]:
import torch
from pathlib import Path
from torch.utils.data import DataLoader, random_split

from vesuvius.data.vesuvius_train import VesuviusTrainDataset
from vesuvius.data.patch25d_dataset import Patch25DDataset
from vesuvius.transforms.volume25d import RandomCrop25D, NormalizePerPatch
from vesuvius.models.unet2d import UNet2D
from vesuvius.losses.dice_bce import DiceBCE

def main():
    ROOT = Path("~/vesuvius-scroll-detection/data/raw/vesuvius").expanduser()

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Device:", device)

    base_ds = VesuviusTrainDataset(ROOT)

    transform = lambda img, mask: NormalizePerPatch()(*RandomCrop25D(crop_hw=(256,256), num_slices=32)(img, mask))
    ds = Patch25DDataset(base_ds, transform=transform)

    n = len(ds)
    n_train = int(0.9 * n)
    n_val = n - n_train
    train_ds, val_ds = random_split(ds, [n_train, n_val])

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2, pin_memory=False)

    model = UNet2D(in_channels=32, base=32, out_channels=1).to(device)
    loss_fn = DiceBCE(bce_weight=0.5)
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

    for epoch in range(1, 6):
        model.train()
        tr_loss = 0.0

        for x, y, _sid in train_loader:
            x = x.to(device)              # (B,C,H,W)
            y = y.to(device)              # (B,1,H,W)

            opt.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            opt.step()

            tr_loss += loss.item()

        tr_loss /= max(1, len(train_loader))

        model.eval()
        va_loss = 0.0
        with torch.no_grad():
            for x, y, _sid in val_loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                va_loss += loss_fn(logits, y).item()
        va_loss /= max(1, len(val_loader))

        print(f"Epoch {epoch:02d} | train {tr_loss:.4f} | val {va_loss:.4f}")

    torch.save(model.state_dict(), "unet25d_baseline.pt")
    print("Saved: unet25d_baseline.pt")

if __name__ == "__main__":
    main()

Device: mps


PicklingError: Can't pickle local object <function main.<locals>.<lambda> at 0x10c899bc0>