In [None]:
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the latest validation loss. The
    first call only records a baseline. After that, every call that fails to
    lower the best loss by strictly more than ``min_delta`` increments
    ``counter``. Once ``counter`` reaches ``patience``, ``early_stop`` becomes
    True. An improving call resets ``counter`` to 0.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: Number of epochs with no improvement after which training will be stopped.
        :param min_delta: Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without a sufficient improvement
        self.best_loss = None  # lowest loss seen so far; None before the first call
        self.early_stop = False  # becomes True once counter reaches patience

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``.

        :param val_loss: Loss value for the current epoch (lower is better).
        """
        # Identity check: `== None` can be hijacked by types that overload
        # __eq__ (e.g. numpy arrays, custom numerics).
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


In [None]:
# Gold standard with report and center cropping

import os
import torch
from monai.networks.nets import UNet
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ToTensord, RandRotate90d, CenterSpatialCropd
)
from monai.data import DataLoader, Dataset
from monai.losses import DiceLoss
from torch.optim import Adam
import matplotlib.pyplot as plt
from monai.handlers import StatsHandler, TensorBoardImageHandler
from monai.utils import set_determinism
from torch.utils.tensorboard import SummaryWriter

# Seed MONAI's determinism helper (seed 0) so that repeated runs are reproducible.
set_determinism(seed=0)

# MONAI's dictionary-based transforms carry a trailing "d" (e.g. LoadImaged).
# EnsureChannelFirstd is used here in place of AddChannel; this is assumed to
# be an equivalent replacement (TODO confirm for these NRRD volumes).
print("Setting up directories and initial configurations...")


def create_dataset(data_dir, image_suffix="_Vx3.nrrd", label_suffix="_Label.nrrd"):
    """Pair every image file in ``data_dir`` with its label file.

    An image is any file whose name ends with ``image_suffix``. Its label is
    expected in the same directory, with that trailing suffix replaced by
    ``label_suffix`` (e.g. ``case1_Vx3.nrrd`` -> ``case1_Label.nrrd``).

    :param data_dir: Directory to scan (non-recursive).
    :param image_suffix: Filename ending that identifies image volumes.
    :param label_suffix: Filename ending used for the matching label volumes.
    :return: List of ``{'image': path, 'label': path}`` dicts, ordered by image
        filename so the dataset order is reproducible across runs and machines.
    :raises FileNotFoundError: If an image has no matching label file. This
        fails here rather than later inside the data loader.
    """
    data_dicts = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(image_suffix):
            continue
        # Replace only the trailing suffix (str.replace would rewrite every occurrence).
        stem = filename[: len(filename) - len(image_suffix)]
        image_path = os.path.join(data_dir, filename)
        label_path = os.path.join(data_dir, stem + label_suffix)
        if not os.path.isfile(label_path):
            raise FileNotFoundError(
                f"No label file found for image {image_path!r} (expected {label_path!r})"
            )
        data_dicts.append({'image': image_path, 'label': label_path})
    return data_dicts

# Directories holding the training and validation volumes.
train_data_dir, val_data_dir = (
    "z:/W-People/Nate/Deep_Learning_Data/Train",
    "z:/W-People/Nate/Deep_Learning_Data/Validation",
)

# Edge length (in voxels) of the centred crop; these three values are combined
# into roi_size when the transforms are defined.
desired_height = desired_width = desired_depth = 128

print("Creating datasets...")
# Build the image/label pair lists, training directory first.
train_files, val_files = (create_dataset(folder) for folder in (train_data_dir, val_data_dir))

# Data Transformations
print("Defining transformations...")
roi_size = (desired_depth, desired_height, desired_width)  # size of the centred crop region


def _load_and_center_crop(keys, crop_size):
    """Return the deterministic preprocessing shared by training and validation.

    Loads each volume, moves the channel axis to the front, and keeps the
    central ``crop_size`` region (``CenterSpatialCropd`` is not random).
    """
    return [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        CenterSpatialCropd(keys=keys, roi_size=crop_size),
    ]


data_keys = ['image', 'label']

# Training additionally applies a random 90-degree rotation (probability 0.5)
# for augmentation; validation uses only the deterministic steps.
train_transforms = Compose(
    _load_and_center_crop(data_keys, roi_size)
    + [RandRotate90d(keys=data_keys, prob=0.5), ToTensord(keys=data_keys)]
)
val_transforms = Compose(
    _load_and_center_crop(data_keys, roi_size) + [ToTensord(keys=data_keys)]
)

train_ds = Dataset(data=train_files, transform=train_transforms)
val_ds = Dataset(data=val_files, transform=val_transforms)
# batch_size=1 throughout; pass collate_fn=pad_list_data_collate if larger
# batches of differently sized volumes are ever needed.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

# UNet Model Initialization
print("Initializing 3D U-Net model...")

# Device Configuration -- chosen BEFORE any checkpoint is restored. The model
# must already be on its final device when the optimizer is built and its state
# loaded; otherwise Adam's moment tensors stay on the CPU while the parameters
# move to CUDA, and optimizer.step() fails with a device mismatch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(
    spatial_dims=3,  # This specifies that the network should be 3D
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)

# Loss Function and Optimizer
# to_onehot_y expands the single-channel integer label to match the 3 output channels.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), 1e-3)  # learning rate 1e-3
model_save_path = "z:/W-People/Nate/Deep_Learning_Data/Deep Learning Model/Nate_Unet.pth"
# Filename spelling ("optimzer") kept as-is so previously saved checkpoints still load.
optimizer_save_path = "z:/W-People/Nate/Deep_Learning_Data/Deep Learning Model/Nate_Unet_optimzer.pth"

# Resume from saved states if both exist. map_location lets a checkpoint saved
# on a GPU be restored on a CPU-only machine (and vice versa).
if os.path.exists(model_save_path) and os.path.exists(optimizer_save_path):
    model.load_state_dict(torch.load(model_save_path, map_location=device))
    optimizer.load_state_dict(torch.load(optimizer_save_path, map_location=device))
    print("Loaded saved model and optimizer.")
else:
    print("No saved model or optimizer state found. Starting from scratch.")

# Training and Validation Functions
def train_epoch(model, loader, optimizer, loss_function, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch must be a mapping with ``'image'`` and ``'label'`` tensors; both
    are moved to ``device`` before the forward pass. The model is put in
    training mode and updated once per batch. The result is the sum of the
    per-batch losses divided by ``len(loader)`` (an empty loader therefore
    raises ZeroDivisionError).
    """
    model.train()
    batch_losses = []
    for batch in loader:
        images = batch['image'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        loss = loss_function(model(images), labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)


def validate_epoch(model, loader, loss_function, device):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

    Puts the model in eval mode. Each batch mapping supplies ``'image'`` and
    ``'label'`` tensors, which are moved to ``device``. The result is the
    summed loss divided by ``len(loader)``.
    """
    model.eval()
    running_total = 0
    with torch.no_grad():
        for batch in loader:
            predictions = model(batch['image'].to(device))
            running_total += loss_function(predictions, batch['label'].to(device)).item()
    return running_total / len(loader)

# Main Training Loop
train_losses = []
val_losses = []

# NOTE(review): with patience=20 and num_epochs=20 early stopping can never
# fire. The first call only sets a baseline, so at most 19 non-improving
# epochs can follow. Raise num_epochs or lower patience for it to take effect.
early_stopping = EarlyStopping(patience=20, min_delta=0.001)  # You can adjust these parameters

# Snapshot of the weights with the lowest validation loss seen so far. The
# checkpoint written below should be the best model, not whatever the last
# (possibly overfitted) epoch produced.
best_val_loss = float("inf")
best_model_state = None

print("Starting training process...")
num_epochs = 20
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, loss_function, device)
    val_loss = validate_epoch(model, val_loader, loss_function, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Validation Loss: {val_loss}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # clone() so later optimizer steps do not overwrite the snapshot in place.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Restored model weights from the epoch with the lowest validation loss ({best_val_loss}).")

# The optimizer state saved here is from the last epoch actually trained; it is
# still usable for resuming, but it does not correspond exactly to the restored weights.
torch.save(model.state_dict(), model_save_path)
torch.save(optimizer.state_dict(), optimizer_save_path)
print(f"Model and optimizer states saved to {model_save_path} and {optimizer_save_path} respectively.")


# Plot per-epoch training and validation Dice loss on shared axes.
print("Plotting loss curves...")
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    ax.plot(series, label=series_label)
ax.set_title('Loss Over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Dice Loss')
ax.legend()
plt.show()


In [None]:
import numpy as np
def visualize_predictions(loader, model, device, num_images=3, slice_index=32):
    """Plot image, ground-truth label and predicted label side by side for one slice.

    Only the first batch of ``loader`` is used, so at most
    ``min(num_images, batch size)`` figures are drawn. ``slice_index`` indexes
    the first spatial axis of channel 0 of each volume.

    The two label panels share a fixed gray scale (0 to number of classes - 1),
    so a given class is drawn with the same shade in both. imshow would
    otherwise rescale each panel to its own min/max, making them incomparable.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(loader))
        inputs, targets = batch['image'], batch['label']
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        # Per-voxel class index (argmax over the class channel). The model is
        # multi-class, so this is a label map, not a binary mask.
        predicted_labels = outputs.argmax(dim=1, keepdim=True)
        # Highest class index; pins the colour scale of both label panels.
        max_class = max(outputs.shape[1] - 1, 1)

        for i in range(min(num_images, len(inputs))):
            plt.figure(figsize=(18, 6))

            # Channel 0, slice `slice_index` along the first spatial axis.
            input_slice = np.squeeze(inputs[i].cpu().numpy()[0, slice_index, :, :])
            target_slice = np.squeeze(targets[i].cpu().numpy()[0, slice_index, :, :])
            predicted_slice = np.squeeze(predicted_labels[i].cpu().numpy()[0, slice_index, :, :])

            plt.subplot(1, 3, 1)
            plt.imshow(input_slice, cmap='gray')
            plt.title('Original Image Slice')

            plt.subplot(1, 3, 2)
            plt.imshow(target_slice, cmap='gray', vmin=0, vmax=max_class)
            plt.title('True Label Slice')

            plt.subplot(1, 3, 3)
            plt.imshow(predicted_slice, cmap='gray', vmin=0, vmax=max_class)
            plt.title('Predicted Label Slice')

            plt.show()

visualize_predictions(val_loader, model, device)


In [None]:
def overlay_predictions(batch, predictions, alpha=0.3, num_images=3, slice_index=32):
    """Draw predicted labels on top of the input image for one slice per volume.

    :param batch: Mapping with an ``'image'`` tensor. Presumably CPU-resident
        and shaped (batch, channel, spatial...) -- TODO confirm against caller.
    :param predictions: Per-voxel class indices with the same leading layout.
    :param alpha: Opacity of the prediction overlay.
    :param num_images: Maximum number of volumes from the batch to show.
    :param slice_index: Index along the first spatial axis of channel 0.

    Background voxels (class 0) are masked out, so only foreground predictions
    are tinted. Previously the background was blended in too and covered the
    whole slice.
    """
    images = batch['image']
    for i in range(min(num_images, len(images))):
        plt.figure(figsize=(8, 8))

        # Channel 0, slice `slice_index` along the first spatial axis.
        image_slice = np.squeeze(np.asarray(images[i][0, slice_index, :, :]))
        prediction_slice = np.squeeze(np.asarray(predictions[i][0, slice_index, :, :]))

        # Hide class-0 voxels so the grayscale image stays visible underneath.
        foreground = np.ma.masked_equal(prediction_slice, 0)

        plt.imshow(image_slice, cmap='gray')
        plt.imshow(foreground, cmap='winter', alpha=alpha)
        plt.title('Overlay of Prediction on Original Image Slice')
        plt.show()

# Create overlay visualizations for the first validation batch.
batch_data = next(iter(val_loader))
model.eval()
with torch.no_grad():
    inputs = batch_data['image'].to(device)
    outputs = model(inputs)
    predicted_labels = outputs.argmax(dim=1, keepdim=True).cpu()
    overlay_predictions(batch_data, predicted_labels)


In [None]:
from mayavi import mlab
import numpy as np

def visualize_predictions_3d(loader, model, device, num_volumes=1):
    """Render the ground-truth label of the first batch as a green 3D contour.

    Only the target contour is drawn; the input-image and prediction contours
    are left disabled below. One Mayavi window is opened per volume, up to
    ``num_volumes``. The contour level is half of the volume's maximum value.
    """
    model.eval()
    with torch.no_grad():
        first_batch = next(iter(loader))
        images, labels = first_batch['image'], first_batch['label']
        images, labels = images.to(device), labels.to(device)
        logits = model(images)

        # Class index per voxel (argmax over the class channel).
        class_map = logits.argmax(dim=1, keepdim=True)

        shown = min(num_volumes, len(images))
        for idx in range(shown):
            image_vol = images[idx].cpu().squeeze().numpy()
            label_vol = labels[idx].cpu().squeeze().numpy()
            pred_vol = class_map[idx].cpu().squeeze().numpy()

            fig = mlab.figure(size=(800, 800), bgcolor=(0, 0, 0))

            # Disabled: contour of the input image.
            # mlab.contour3d(image_vol, contours=[image_vol.max()/2], color=(0, 0, 1), transparent=True, figure=fig)

            # Ground-truth label surface, in green.
            mlab.contour3d(
                label_vol,
                contours=[label_vol.max() / 2],
                color=(0, 1, 0),
                transparent=True,
                figure=fig,
            )

            # Disabled: contour of the predicted labels.
            # mlab.contour3d(pred_vol, contours=[pred_vol.max()/2], color=(1, 0, 0), transparent=True, figure=fig)

            mlab.show()

# Render the first validation volume.
visualize_predictions_3d(val_loader, model, device)



# Problems to be fixed
# - The input volume should be the CT image. When it is rendered in Mayavi, what exactly is being shown?
# - Do use contour3d.
# - Add early stopping and increase the number of epochs to 200.
# - First experiment: put 5 images in validation and train with two images.
# - Try to get the best result possible from just two training images first.
# - Use exactly the same 5 images for validation.
# - Ask ChatGPT about creating a report (capabilities in MONAI).
# - Change the learning rate; maybe add an adaptive learning-rate schedule.


In [None]:
from mayavi import mlab
import torch
import numpy as np

def visualize_predictions_3d(loader, model, device, num_volumes=1):
    """Overlay ground-truth and predicted label iso-surfaces in one Mayavi scene.

    For the first batch of ``loader``, up to ``num_volumes`` volumes are shown,
    one window each. The target is drawn as a semi-transparent blue surface and
    the prediction as an opaque red surface, matching the colours of the
    on-screen legend text.

    Both surfaces are extracted at the same iso-level: half of the larger of
    the two volumes' maxima. Per-volume levels would put the two surfaces at
    different thresholds whenever the maxima differ, making them incomparable.
    If both volumes are entirely zero there is nothing to draw, and that
    volume is skipped with a message.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(loader))
        inputs, targets = batch['image'], batch['label']
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        # Per-voxel class index (argmax over the class channel).
        predicted_labels = outputs.argmax(dim=1, keepdim=True)

        target_color = (0, 0, 1)     # blue -- same colour as the "Target Volume" text
        predicted_color = (1, 0, 0)  # red  -- same colour as the "Predicted Volume" text

        for i in range(min(num_volumes, len(inputs))):
            target_volume = targets[i].cpu().squeeze().numpy()
            predicted_volume = predicted_labels[i].cpu().squeeze().numpy()

            iso_level = 0.5 * max(float(target_volume.max()), float(predicted_volume.max()))
            if iso_level <= 0:
                print(f"Volume {i}: target and prediction are both empty; nothing to render.")
                continue

            fig = mlab.figure(size=(800, 800), bgcolor=(0, 0, 0))

            # Ground-truth label surface.
            target_src = mlab.pipeline.scalar_field(target_volume)
            mlab.pipeline.iso_surface(target_src, contours=[iso_level],
                                      color=target_color, opacity=.5, figure=fig)

            # Predicted label surface.
            predicted_src = mlab.pipeline.scalar_field(predicted_volume)
            mlab.pipeline.iso_surface(predicted_src, contours=[iso_level],
                                      color=predicted_color, opacity=1, figure=fig)

            # Legend text, coloured like the surfaces it describes.
            mlab.text(0.01, 0.01, "Target Volume", color=target_color, width=0.15, figure=fig)
            mlab.text(0.01, 0.95, "Predicted Volume", color=predicted_color, width=0.15, figure=fig)

            mlab.show()

# Assuming val_loader, model, and device have been defined elsewhere in your script:
visualize_predictions_3d(val_loader, model, device)


In [None]:
from mayavi import mlab
import numpy as np

def visualize_predictions_3d(loader, model, device, num_volumes=1):
    """Volume-render the ground-truth and predicted label maps of the first batch.

    Both label maps are cast to float32 and rendered into the same Mayavi
    figure with ``mlab.pipeline.volume``. Each map is scaled from 0 to its own
    maximum. One window is opened per volume, up to ``num_volumes``.
    """
    model.eval()
    with torch.no_grad():
        sample = next(iter(loader))
        images, labels = sample['image'].to(device), sample['label'].to(device)
        # Class index per voxel (argmax over the class channel).
        class_map = model(images).argmax(dim=1, keepdim=True)

        for idx in range(min(num_volumes, len(images))):
            image_vol = images[idx].cpu().squeeze().numpy()  # used only by the disabled rendering below
            truth = labels[idx].cpu().squeeze().numpy().astype(np.float32)
            guess = class_map[idx].cpu().squeeze().numpy().astype(np.float32)

            fig = mlab.figure(size=(800, 800), bgcolor=(0, 0, 0))

            # Disabled: rendering of the input image. It would likely need
            # thresholding / intensity windowing first, e.g.
            # processed = preprocess_volume(image_vol)
            # mlab.pipeline.volume(mlab.pipeline.scalar_field(processed), figure=fig)

            # Ground truth first, then prediction, both in the same scene.
            for vol in (truth, guess):
                mlab.pipeline.volume(mlab.pipeline.scalar_field(vol), vmin=0, vmax=vol.max(), figure=fig)

            mlab.show()

# Render the first validation volume.
visualize_predictions_3d(val_loader, model, device)
