In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from monai.networks.nets import UNet

from utils.loader import DicomDataset3D

# --- Hyperparameters ---------------------------------------------------------
BATCH_SIZE = 2
EPOCHS = 60

# Prefer the GPU when torch can see one; otherwise everything stays on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Network -----------------------------------------------------------------
# MONAI's 3D UNet with a single input and a single output channel. Its output
# is passed to BCEWithLogitsLoss below, i.e. it is treated as logits.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# --- Data --------------------------------------------------------------------
# Both loaders are built the same way (train first, then test); only the CSV
# handed to DicomDataset3D differs.
train_dataloader, test_dataloader = (
    DataLoader(DicomDataset3D(csv_path), batch_size=BATCH_SIZE)
    for csv_path in ("data/train.csv", "data/test.csv")
)

# --- Optimisation ------------------------------------------------------------
# NOTE(review): lr=0.04 is far above Adam's default of 1e-3 -- confirm that
# this value is intentional.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.04)

# pos_weight multiplies the loss contribution of positive targets (here by 70).
pos_weight = torch.tensor([70], device=device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion = criterion.to(device)


In [None]:
# Train for EPOCHS passes over train_dataloader, printing the mean batch loss
# of each epoch, then store the trained weights on disk.
model.train()  # ensure training mode even if another cell switched the model to eval
for epoch in range(EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        output = model(inputs)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the running sum does not
        # hold on to the autograd graph of every batch.
        running_loss += loss.item()
        num_batches += 1

    # Fail with a clear message instead of a NameError on `loss` when the
    # loader produced nothing.
    if num_batches == 0:
        raise RuntimeError("train_dataloader yielded no batches; nothing to train on")
    print(f"epoch {epoch + 1}/{EPOCHS} - mean train loss: {running_loss / num_batches:.4f}")

# torch.save does not create missing directories; make sure the target folder
# exists so a finished training run is not lost at the very last step.
checkpoint_path = './state_dicts/monai_unet.pk'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
import utils.notebooks as nb

# Qualitative check: run the network on one batch and display the input, the
# predicted probabilities and the label for the first sample of that batch.
#
# NOTE(review): the batch is drawn from train_dataloader; test_dataloader is
# created in the setup cell but is not used in the cells visible here --
# confirm which split this preview is meant to show.
was_training = model.training
model.eval()  # turn off train-only layer behaviour (dropout, batch-norm updates) if the net has any
try:
    with torch.no_grad():
        x, y = next(iter(train_dataloader))
        # Use the same `device` the model was moved to, rather than
        # re-deriving it from torch.cuda.is_available().
        x, y = x.to(device), y.to(device)

        # The model is trained with BCEWithLogitsLoss, so a sigmoid is needed
        # to turn its raw output into values in [0, 1].
        pred = torch.sigmoid(model(x))
        print(pred.shape)

        # Second argument is forwarded unchanged to nb.show
        # (presumably a slice index -- TODO confirm against utils.notebooks).
        nb.show(x[0], 45)
        nb.show(pred[0], 45)
        nb.show(y[0], 45)

        # Range of the predicted probabilities over the whole batch.
        print(torch.min(pred))
        print(torch.max(pred))
finally:
    # Put the model back into whichever mode it was in before this cell ran.
    model.train(was_training)