In [None]:
import os

from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from utils import BrainDataset

# Root directory of the preprocessed MSSEG-1 data (presumably NIfTI files —
# how files are discovered is defined by utils.BrainDataset, not here).
dataset_path = "../data/MSSEG-1-preprocessed-2/"


def collate_fn(batch):
    """Transpose a list of (image, target) samples into (images, targets) tuples.

    Samples are kept as individual objects instead of being stacked, because the
    detection model below is called with a list of images and a list of targets.
    """
    return tuple(zip(*batch))


dataset = BrainDataset(root_dir=dataset_path)
data_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)


In [None]:
# Fetch one batch and visualize its first sample as a sanity check of the data pipeline.
images, targets = next(iter(data_loader))

# Drop the leading channel axis for display (assumes a single-channel image — TODO confirm).
image = images[0].squeeze(0).cpu().numpy()

# Merge every instance mask of the first sample into one binary map.
# Indexing masks[0] displayed only the first instance and raised IndexError on
# slices with no instances; summing over the instance axis yields an all-zero
# map in that case. Assumes masks are shaped (num_instances, H, W) — TODO confirm.
instance_masks = targets[0]['masks']
mask = (instance_masks.sum(dim=0) > 0).float().cpu().numpy()

# Plot the image slice and the corresponding mask slice side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(image, cmap='gray')
axes[0].set_title("Flair Image Slice")
axes[0].axis("off")

# Pin the colour range so an empty mask renders black instead of being auto-scaled.
axes[1].imshow(mask, cmap='gray', vmin=0, vmax=1)
axes[1].set_title("Consensus Mask Slice")
axes[1].axis("off")

plt.tight_layout()
plt.show()

In [None]:
import torchvision

num_classes = 2  # Background + 1 class for segmentation

# COCO-pretrained Mask R-CNN. `weights="DEFAULT"` replaces the deprecated
# `pretrained=True` flag (torchvision >= 0.13) and loads the same COCO weights.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

# Replace the box classifier/regressor head so it predicts `num_classes` classes.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Replace the mask head likewise; `hidden_layer` is the channel width of the new head.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
    in_features_mask, hidden_layer, num_classes
)

In [None]:
import torch

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to moves the parameters in place; the trailing `;` keeps the cell
# from dumping the full (very long) model repr into the notebook output.
model.to(device);

In [None]:
# Hand the optimizer only parameters that will actually be trained; torchvision
# may freeze some pretrained backbone layers (requires_grad=False).
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
num_epochs = 10

In [None]:
# NOTE(review): "best" is judged on *training* loss — there is no validation set
# here, so this checkpoint tracks fit, not generalization.
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0   # batches whose finite loss was backpropagated
    num_skipped = 0   # batches dropped because of a NaN/inf loss

    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # In train mode the model returns a dict of named loss tensors.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Skip NaN *and* inf losses so one bad batch cannot corrupt the weights.
        if not torch.isfinite(losses):
            bad_terms = [key for key, value in loss_dict.items() if not torch.isfinite(value).all()]
            print(f"Non-finite loss in {bad_terms}. Skipping batch.")
            num_skipped += 1
            continue

        losses.backward()
        optimizer.step()
        epoch_loss += losses.item()
        num_batches += 1

    if num_batches == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}: all batches skipped; not saving.")
        continue

    # Average over contributing batches. A raw sum shrinks whenever batches are
    # skipped, which made degraded epochs look like improvements.
    mean_loss = epoch_loss / num_batches
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss:.4f} ({num_skipped} batches skipped)")

    # Save the model if the current epoch's mean loss is better than the best so far.
    if mean_loss < best_loss:
        best_loss = mean_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Model saved with loss: {best_loss:.4f}")