In [1]:
import sys
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
# Make the parent directory importable so the project-local `classes` package
# resolves (presumably the notebook sits one level below the project root).
sys.path.append('..')
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D

In [3]:
# Both halves of the split must be built with identical arguments, otherwise
# train and test could be cut differently. Share them via constants instead of
# repeating the literals.
DATASET_DIR = "./dataset/affine_transformed"
TEST_SIZE = 0.25  # presumably the fraction of cases held out for testing

# NOTE(review): train and test come from two separate constructor calls. Confirm
# that ProcessedKit23TorchDataset splits deterministically (fixed seed / sorted
# case list) so the two halves cannot overlap.
train_dataset = ProcessedKit23TorchDataset(train_data=True, test_size=TEST_SIZE, dataset_dir=DATASET_DIR)
test_dataset = ProcessedKit23TorchDataset(train_data=False, test_size=TEST_SIZE, dataset_dir=DATASET_DIR)

In [4]:
# Seed torch's global RNG so weight initialisation is repeatable. The default
# sampler of the shuffling DataLoader built later also draws from this RNG.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): positional args presumably mean (in_channels=1, out_channels=4),
# i.e. a single-channel volume segmented into 4 classes. Confirm against UNet3D.
model = UNet3D(1, 4)

In [5]:
# Loss: cross-entropy over class logits. Target entries equal to -1 are skipped
# (ignore_index), so they contribute neither loss nor gradient.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Optimiser: Adam at 1e-3 with no L2 penalty. The learning rate is multiplied by
# 0.99 each time `scheduler.step()` is called.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=0)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)

In [6]:
# Batches of 4 are loaded by 4 worker processes. Only the training split is
# shuffled, so the evaluation order is the same every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

In [None]:
# Train for `num_epochs`, reporting mean train and validation loss per epoch,
# then save the weights. Relies on `model`, `criterion`, `optimizer`,
# `scheduler`, `train_loader` and `test_loader` from the cells above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.optim docs recommend moving the model to its device before
# constructing the optimizer (done in an earlier cell). Module.to converts the
# parameters in place by default, so the optimizer still tracks them, but
# consider reordering.
model.to(device)


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        Mean of the per-batch losses, or NaN if ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    for images, masks in loader:
        images = images.to(device).float()
        # CrossEntropyLoss needs integer class-index targets. .long() is a
        # no-op if the dataset already yields int64 masks.
        masks = masks.to(device).long()

        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader) if len(loader) else float("nan")


def evaluate_loss(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` on ``loader`` without training it.

    Returns NaN if ``loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device).float()
            masks = masks.to(device).long()
            running_loss += criterion(model(images), masks).item()
    return running_loss / len(loader) if len(loader) else float("nan")


num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    average_loss = evaluate_loss(model, test_loader, criterion, device)
    # The ExponentialLR scheduler only decays the learning rate when stepped.
    # Without this call the LR stays at its initial value for the whole run.
    scheduler.step()
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {average_loss:.4f}")

# Save the trained weights and leave the model in evaluation mode.
torch.save(model.state_dict(), 'unet_model.pth')
model.eval()

In [None]:
def dice_coefficient(predicted, target, eps=1e-8):
    """Dice similarity between two binary masks of the same shape.

    Args:
        predicted: tensor of 0/1 (or bool) values.
        target: tensor of 0/1 (or bool) values, same shape as ``predicted``.
        eps: small constant keeping the division finite.

    Returns:
        Python float in [0, 1]. When both masks are empty the masks agree
        perfectly, so 1.0 is returned instead of 0.0.
    """
    # Cast so bool masks work, and so the sums are floating point.
    predicted = predicted.float()
    target = target.float()
    intersection = torch.sum(predicted * target)
    cardinality = torch.sum(predicted) + torch.sum(target)
    if cardinality.item() == 0:
        return 1.0
    dice = (2.0 * intersection) / (cardinality + eps)
    return dice.item()

# Testing loop: per-sample Dice for each foreground class, averaged over all
# (sample, class) pairs where the class appears in the prediction or the mask.
# NOTE(review): assumes class 0 is background. The class count is read from the
# model output's channel dim (dim=1, as CrossEntropyLoss in training requires).
model.eval()  # don't rely on an earlier cell having left the model in eval mode
dice_scores = []
with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device).float()  # same dtype the model was trained on
        masks = masks.to(device).long()

        outputs = model(images)
        # argmax over logits == argmax over softmax. A fixed 0.5 threshold on a
        # single channel is only meaningful for binary segmentation.
        predictions = outputs.argmax(dim=1)
        valid = masks != -1  # mirror ignore_index=-1 from the training criterion

        for pred_volume, mask_volume, valid_volume in zip(predictions, masks, valid):
            for cls in range(1, outputs.shape[1]):
                pred_cls = (pred_volume == cls) & valid_volume
                mask_cls = (mask_volume == cls) & valid_volume
                if not (pred_cls.any() or mask_cls.any()):
                    continue  # class absent from both: nothing to score
                dice_scores.append(dice_coefficient(pred_cls.float(), mask_cls.float()))

# Average Dice over all scored (sample, class) pairs.
average_dice = sum(dice_scores) / len(dice_scores) if dice_scores else float("nan")

print(f"Average Dice Coefficient on Test Set: {average_dice:.4f}")