In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from loss import VolumeAwareLoss
import torch.nn.functional as F
from monai.utils import set_determinism
import datetime

# Set deterministic behavior for reproducibility
set_determinism(seed=42)

def create_sample_data(shape=(2, 5, 16, 16, 16)):
    """Create sample 3D data for testing loss function"""
    # Create empty tensors
    batch_size, n_classes, d, h, w = shape
    ground_truth = torch.zeros((batch_size, d, h, w), dtype=torch.long)
    
    # Sample 1: Create tumor components
    # NETC (class 1): Sphere in center
    center = np.array([d//2, h//2, w//2])
    for i in range(d):
        for j in range(h):
            for k in range(w):
                if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 9:
                    ground_truth[0, i, j, k] = 1
    
    # SNFH (class 2): Region around the NETC
    for i in range(d):
        for j in range(h):
            for k in range(w):
                if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) >= 9 and \
                   np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 25:
                    ground_truth[0, i, j, k] = 2
    
    # ET (class 3): Small enhancing regions
    centers = [np.array([d//2-3, h//2-3, w//2]), np.array([d//2+3, h//2+3, w//2])]
    for center in centers:
        for i in range(d):
            for j in range(h):
                for k in range(w):
                    if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 2:
                        ground_truth[0, i, j, k] = 3
    
    # RC (class 4): Resection cavity
    center = np.array([d//2, h//2, w//2-5])
    for i in range(d):
        for j in range(h):
            for k in range(w):
                if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 4:
                    ground_truth[0, i, j, k] = 4
    
    # Sample 2: Different arrangement
    # NETC (class 1): Off-center region
    center = np.array([d//3, h//3, w//3])
    for i in range(d):
        for j in range(h):
            for k in range(w):
                if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 7:
                    ground_truth[1, i, j, k] = 1
    
    # SNFH (class 2): Larger surrounding region
    for i in range(d):
        for j in range(h):
            for k in range(w):
                dist = np.sum(np.square([i-center[0], j-center[1], k-center[2]]))
                if dist >= 7 and dist < 20:
                    ground_truth[1, i, j, k] = 2
    
    # ET (class 3): Multiple small regions
    centers = [np.array([d//4, h//4, w//4]), np.array([d//2, h//2, w//2])]
    for center in centers:
        for i in range(d):
            for j in range(h):
                for k in range(w):
                    if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 2:
                        ground_truth[1, i, j, k] = 3
    
    # RC (class 4): Small cavity
    center = np.array([2*d//3, 2*h//3, 2*w//3])
    for i in range(d):
        for j in range(h):
            for k in range(w):
                if np.sum(np.square([i-center[0], j-center[1], k-center[2]])) < 3:
                    ground_truth[1, i, j, k] = 4
    
    # Convert to one-hot encoding
    ground_truth_onehot = F.one_hot(ground_truth, n_classes).permute(0, 4, 1, 2, 3).float()
    
    return ground_truth, ground_truth_onehot

def create_predictions(ground_truth_onehot):
    """Create three types of predictions for testing"""
    batch_size, n_classes, d, h, w = ground_truth_onehot.shape
    
    # Case 1: Random predictions
    random_logits = torch.randn((batch_size, n_classes, d, h, w)) * 2
    random_pred = F.softmax(random_logits, dim=1)
    
    # Case 2: Noisy ground truth
    noise_level = 0.3
    noise = torch.randn((batch_size, n_classes, d, h, w)) * noise_level
    noisy_pred = ground_truth_onehot + noise
    noisy_pred = torch.clamp(noisy_pred, 0, 1)
    noisy_pred = noisy_pred / noisy_pred.sum(dim=1, keepdim=True)
    
    # Case 3: Identical to ground truth
    identical_pred = ground_truth_onehot.clone()
    
    return random_logits, noisy_pred, identical_pred

def visualize_slices(ground_truth_onehot, prediction, title, slice_idx=None):
    """Visualize middle slices of ground truth and prediction"""
    batch_idx = 0  # First sample in batch
    
    if slice_idx is None:
        d = ground_truth_onehot.shape[2]
        slice_idx = d // 2
    
    n_classes = ground_truth_onehot.shape[1]
    fig, axes = plt.subplots(2, n_classes, figsize=(15, 6))
    
    plt.suptitle(f"{title} - Slice {slice_idx}", fontsize=16)
    
    class_names = ["Background", "NETC", "SNFH", "ET", "RC"]
    
    for c in range(n_classes):
        gt_slice = ground_truth_onehot[batch_idx, c, slice_idx]
        pred_slice = prediction[batch_idx, c, slice_idx]
        
        axes[0, c].imshow(gt_slice, cmap='viridis', vmin=0, vmax=1)
        axes[0, c].set_title(f"GT: {class_names[c]}")
        axes[0, c].axis('off')
        
        im = axes[1, c].imshow(pred_slice, cmap='viridis', vmin=0, vmax=1)
        axes[1, c].set_title(f"Pred: {class_names[c]}")
        axes[1, c].axis('off')
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    
    plt.show()

def compute_and_print_losses(loss_fn, pred, target, case_name):
    """Compute losses and print results"""
    print(f"\n{case_name} Prediction Results:")
    print("-" * 50)
    
    loss_dict = loss_fn(pred, target)
    
    print(f"Composite Loss: {loss_dict['loss'].item():.6f}")
    print(f"DiceCE Loss: {loss_dict['dice_ce_loss'].item():.6f}")
    print(f"Tversky Loss: {loss_dict['tversky_loss'].item():.6f}")
    print(f"Surface Loss: {loss_dict['surface_loss'].item():.6f}")
    
    print("\nPer-sample losses:")
    for i, sample_loss in enumerate(loss_dict['per_sample_loss']):
        print(f"  Sample {i} Loss: {sample_loss.item():.6f}")
    
    print("\nNormalized Class Weights:")
    class_names = ["Background", "NETC", "SNFH", "ET", "RC"]
    for i, sample_weights in enumerate(loss_dict['normalized_weights']):
        print(f"  Sample {i}:")
        for c, weight in enumerate(sample_weights):
            print(f"    {class_names[c]}: {weight.item():.4f}")
    
    return loss_dict

def main():
    print(f"Trial Run Time: 2025-07-12 04:07:21")
    print(f"User: twi-exe")
    print("-" * 50)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create sample data
    ground_truth, ground_truth_onehot = create_sample_data()
    print(f"Created sample data with shape: {ground_truth.shape}")
    print(f"One-hot ground truth shape: {ground_truth_onehot.shape}")
    
    # Move data to device
    ground_truth = ground_truth.to(device)
    ground_truth_onehot = ground_truth_onehot.to(device)
    
    # Create predictions
    random_pred, noisy_pred, identical_pred = create_predictions(ground_truth_onehot)
    random_pred = random_pred.to(device)
    noisy_pred = noisy_pred.to(device)
    identical_pred = identical_pred.to(device)
    
    # Create loss function
    loss_fn = VolumeAwareLoss(
        include_background=True,
        to_onehot_y=True,
        softmax=True,
        tversky_alpha=0.3,
        tversky_beta=0.7,
        class_weights=[1.0, 2.0, 1.5, 2.5, 1.5],  # BG, NETC, SNFH, ET, RC
        baseline_volumes=[0.0, 500.0, 2000.0, 250.0, 300.0],
    )
    
    
    # Apply softmax to random predictions for visualization
    random_pred_softmax = F.softmax(random_pred, dim=1)
    
    print("\n1. Random Prediction Test")
    compute_and_print_losses(loss_fn, random_pred, ground_truth, "Random")
    visualize_slices(ground_truth_onehot, random_pred_softmax, "Random Predictions")
    visualize_slices(ground_truth_onehot, random_pred_softmax, "Random Predictions", slice_idx=10)

    print("\n2. Noisy Prediction Test")
    compute_and_print_losses(loss_fn, noisy_pred, ground_truth, "Noisy")    
    visualize_slices(ground_truth_onehot, noisy_pred, "Noisy Predictions")
    visualize_slices(ground_truth_onehot, noisy_pred, "Noisy Predictions", slice_idx=10)

    print("\n3. Identical Prediction Test")
    compute_and_print_losses(loss_fn, identical_pred, ground_truth, "Identical")
    visualize_slices(ground_truth_onehot, identical_pred, "Identical Predictions")
    visualize_slices(ground_truth_onehot, identical_pred, "Identical Predictions", slice_idx=10)

if __name__ == "__main__":
    main()

Trial Run Time: 2025-07-12 04:07:21
User: twi-exe
--------------------------------------------------
Using device: cuda
Created sample data with shape: torch.Size([2, 16, 16, 16])
One-hot ground truth shape: torch.Size([2, 5, 16, 16, 16])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!