In [18]:
!pip install monai nibabel




In [22]:
from transforms import get_transforms


In [23]:
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd,
    ScaleIntensityd, RandCropByPosNegLabeld, ToTensord
)

def get_transforms(spatial_size=(64, 64, 64), num_samples=1):
    """Build the MONAI dictionary-transform pipeline for image/label samples.

    Each sample is a dict with "image" and "label" entries (file paths in this
    notebook). The pipeline:
      1. loads both volumes,
      2. moves the channel axis first,
      3. rescales the image intensities (the label is left untouched),
      4. randomly crops patches centred on positive/negative label voxels
         with a pos:neg ratio of 1:1,
      5. converts the results to tensors.

    Args:
        spatial_size: Size of each random crop. Presumably it must not exceed
            the spatial size of the input volumes. Keep it in sync with the
            (dummy) sample data; the default matches the original value.
        num_samples: Number of random crops drawn per input sample.

    Returns:
        A ``Compose`` of the transforms listed above.
    """
    return Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityd(keys="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            # NOTE: must match the dummy data, i.e. fit inside the sample volume.
            spatial_size=spatial_size,
            pos=1,
            neg=1,
            num_samples=num_samples,
        ),
        ToTensord(keys=["image", "label"]),
    ])


In [19]:
import torch
from monai.data import CacheDataset, DataLoader
from monai.losses import DiceLoss
from tqdm import tqdm

from model import get_model
from utils import device

# NOTE: `get_transforms` is deliberately NOT imported from `transforms` here.
# The notebook defines its own `get_transforms` in an earlier cell (crop size
# matched to the dummy data). Re-importing it at this point would silently
# shadow that definition on a top-to-bottom "Restart & Run All", so training
# would use a different pipeline than the one shown above.

print("Running on:", device)


Running on: cpu


In [20]:
# Training samples as MONAI-style dicts: one image volume plus its segmentation.
# A single pair, presumably dummy data for a quick CPU smoke test.
# Paths are relative to the notebook's working directory.
data = [
    dict(
        image="sample.nii.gz",
        label="sample_seg.nii.gz",
    ),
]



In [24]:
# Run configuration.
SEED = 0
NUM_EPOCHS = 2  # small for CPU
LEARNING_RATE = 1e-4
CHECKPOINT_PATH = "model.pth"

# Seed before building the loader and the model. This makes model init and
# shuffling repeatable. It should also fix the random crops: MONAI's DataLoader
# derives transform seeds from torch's RNG (verify for your MONAI version).
torch.manual_seed(SEED)

dataset = CacheDataset(data, transform=get_transforms(), cache_rate=1.0)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

model = get_model().to(device)
loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of dict batches with "image" and "label" tensors.
        loss_fn: Callable ``loss_fn(prediction, label)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        The average of ``loss.item()`` over all batches, or NaN if the loader
        yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_steps = 0
    for batch in tqdm(loader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1
    # Report a mean rather than a sum so the value does not scale with the
    # number of batches.
    return total_loss / num_steps if num_steps else float("nan")


for epoch in range(NUM_EPOCHS):
    epoch_loss = train_one_epoch(model, loader, loss_fn, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

torch.save(model.state_dict(), CHECKPOINT_PATH)


Loading dataset: 100%|██████████| 1/1 [00:00<00:00, 27.78it/s]
100%|██████████| 1/1 [00:01<00:00,  1.71s/it]


Epoch 1, Loss: 0.6626


100%|██████████| 1/1 [00:01<00:00,  1.78s/it]

Epoch 2, Loss: 0.6596



