In [13]:
import matplotlib.pyplot as plt
import monai
import numpy as np
import torch
from monai.data import DataLoader, Dataset
from pathlib import Path
from monai.transforms.utils import allow_missing_keys_mode
from monai.transforms import BatchInverseTransform
from monai.networks.nets import ResNet
import nibabel as nib
from tqdm import tqdm
import medpy.metric as metric
import os
import dxchange

In [48]:
def getBugData(dataset_path: Path):
    dataset = []
    for idx, item in enumerate(os.listdir(dataset_path)):
        one_hot_v = np.zeros(12)
        one_hot_v[idx] = 1
        if item == "BC" or item == "BF" or item == "BL":
            for file in os.listdir(str(dataset_path) + "/"+ item):
                dataset.append({'image':str(dataset_path) + "/"+ item + "/" + file,
                                 "class": str(item),
                                 "label": one_hot_v})
    return dataset

DATA_PATH = "/dtu/3d-imaging-center/courses/02510/data/Bugs/bugnist_128/"

# 1. Data. Make a 70-10-20% train-validation-test split here
Files = getBugData(dataset_path=Path(DATA_PATH))
print(Files[0])
#valFiles = getBugData(dataset_path=Path(DATA_PATH))  
#testFiles = getBugData(dataset_path=Path(DATA_PATH))

{'image': '/dtu/3d-imaging-center/courses/02510/data/Bugs/bugnist_128/BC/sfaar_16_003.tif', 'class': 'BC', 'label': array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}


In [59]:

train_transforms = monai.transforms.Compose([
    monai.transforms.LoadImage(keys=['image'], reader="nibabelreader"),
    monai.transforms.EnsureChannelFirstd(keys=['image']),
])


BATCH_SIZE = 2
train_dataset = Dataset(data=Files, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
#val_dataset = Dataset(data=valFiles, transform=val_transforms)
#val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# 2. Model. Now, use your model to do inference in a few images
for i, dummy_data in enumerate(train_loader):
    print(dummy_data['image'].shape)
    break

RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7fad50bf3e50>

In [None]:
# Hyper-parameters (next three lines) #
NUM_EPOCHS = 10
EVAL_EVERY = 1
BATCH_SIZE = 2

train_dataset = Dataset(data=Files, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = None
val_loader = None

model = None


# More design decisions (model, loss, optimizer) #
loss_fn = torch.nn.CrossEntropyLoss() # Apply "softmax" to the output of the network and don't convert to onehot because this is done already by the transforms.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 

inferer = monai.inferers.SliceInferer(roi_size=[-1, -1], spatial_dim=2, sw_batch_size=1)

train_losses = []
val_losses = []

for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch + 1}')

    model.train()
    epoch_loss = 0
    step = 0
    for tr_data in train_loader:
        inputs = tr_data['image'].cuda()
        targets = tr_data['label'].cuda()

        # Forward -> Backward -> Step
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        step += 1
        
    # Log and store average epoch loss
    epoch_loss = epoch_loss.item() / step
    train_losses.append(epoch_loss)
    print(f'Mean training loss: {epoch_loss}')

    epoch_loss = 0
    step = 0
    if epoch % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():  # Do not need gradients for this part
            for val_data in val_loader:
                inputs = val_data['image'].cuda()
                targets = val_data['label'].cuda()

                outputs = model(inputs)

                loss = loss_fn(outputs, targets)
                epoch_loss += loss.detach()
                step += 1
        
        # Log and store average epoch loss
        epoch_loss = epoch_loss.item() / step
        val_losses.append(epoch_loss)
        print(f'Mean validation loss: {epoch_loss}')


In [None]:
# Code for the task here
# Plot the training loss over time
plt.plot(train_losses, label='Training')
plt.plot(val_losses, label='Validation')
plt.title('Training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
