# __Import & config__

In [1]:
# IPython magics: automatically reload edited modules (e.g. src.*) before executing code.
%load_ext autoreload
%autoreload 2
import os
# Work from the repository root so the relative paths used in later cells
# (configs/, experiments/, data/, test_predicter/) resolve.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
os.chdir('C:\\Users\\Usuario\\TFG\\digipanca\\')

In [2]:
import os
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from celluloid import Camera
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.animation as animation

from src.utils.config import load_config
from src.utils.evaluation import load_trained_model
from src.inference.predicter import Predicter2D, Predicter3D

# __Functions__

In [3]:
def predict_and_visualize(patient_id, predicter, output_dir, mode="2D"):
    """
    Predicts a patient's data, visualizes the results, and exports them.

    Writes into ``output_dir`` a slice-by-slice animation, a 3D scatter of
    prediction channel 0, and NIfTI files for the predictions and masks.

    Parameters
    ----------
    patient_id : str
        ID of the patient to predict.
    predicter : Predicter2D or Predicter3D
        The predicter object to use for predictions. ``predict_patient`` must
        return ``(predictions, masks)`` tensors whose first axis is a batch
        axis of size 1.
    output_dir : str
        Directory to save the outputs; created if it does not exist.
    mode : str, optional
        Mode of prediction ("2D" or "3D"), by default "2D". Only used to name
        the animation file.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"🔍 Predicting patient {patient_id}...")
    predictions, masks = predicter.predict_patient(patient_id)

    # Drop the batch axis: (1, C, D, H, W) -> (C, D, H, W).
    predictions = predictions.squeeze(0).cpu().numpy()
    masks = masks.squeeze(0).cpu().numpy()
    # Predicter2D returns masks as (1, D, H, W), with no channel axis (see the
    # shapes printed in the test cell). After dropping the batch axis they are
    # therefore (D, H, W). The helpers below read masks[0] as the channel, so
    # restore a leading channel axis to obtain (1, D, H, W).
    if masks.ndim == 3:
        masks = masks[np.newaxis]

    # Animation of prediction channel 0 next to the ground-truth mask.
    visualize_animation(predictions, masks, output_dir, mode)

    # 3D scatter of prediction channel 0.
    visualize_3d(predictions, output_dir)

    # NIfTI export of prediction channel 0 and of the mask volume.
    export_to_nifti(predictions, masks, output_dir, patient_id)

def visualize_animation(predictions, masks, output_dir, mode):
    """
    Creates an animation of predictions and masks side by side.

    Produces one frame per depth slice: channel 0 of ``predictions`` on the
    left and the ground-truth mask on the right. The result is saved as
    ``{mode}_animation.mp4`` in ``output_dir``.

    Parameters
    ----------
    predictions : np.ndarray
        Predicted volumes (C, D, H, W).
    masks : np.ndarray
        Ground truth masks, (1, D, H, W) or (D, H, W).
    output_dir : str
        Directory to save the animation.
    mode : str
        Mode of prediction ("2D" or "3D"); used as the file-name prefix.

    Notes
    -----
    Saving ``.mp4`` goes through matplotlib's default movie writer, which
    typically requires ffmpeg to be installed.
    """
    print("🎥 Creating animation...")
    # Accept masks without a channel axis, e.g. after squeezing the batch axis
    # off a (1, D, H, W) tensor; indexing masks[0, i] below needs 4-D input.
    if masks.ndim == 3:
        masks = masks[np.newaxis]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    camera = Camera(fig)

    for i in range(predictions.shape[1]):  # one frame per depth slice (D)
        ax[0].imshow(predictions[0, i], cmap="viridis")  # channel 0 only
        ax[0].set_title("Prediction")
        ax[1].imshow(masks[0, i], cmap="gray")
        ax[1].set_title("Ground Truth Mask")
        camera.snap()

    # Named `anim` rather than `animation` so that, inside this function, the
    # module-level `matplotlib.animation as animation` import is not shadowed.
    anim = camera.animate()
    try:
        anim.save(os.path.join(output_dir, f"{mode}_animation.mp4"))
    finally:
        plt.close(fig)  # release the figure even if saving fails
    print(f"✅ Animation saved to {output_dir}")

def visualize_3d(predictions, output_dir, max_points=None):
    """
    Creates a 3D visualization of the predicted volume.

    Each voxel of channel 0 is drawn as a point at (x, y, z) = (W, H, D)
    index, coloured by its value. The figure is saved as
    ``3d_visualization.png`` in ``output_dir``.

    Parameters
    ----------
    predictions : np.ndarray
        Predicted volumes (C, D, H, W).
    output_dir : str
        Directory to save the 3D visualization.
    max_points : int, optional
        If given and positive, keep only every k-th voxel (in memory order)
        so that at most ``max_points`` points are drawn. ``None`` (default)
        draws every voxel.
    """
    print("📊 Creating 3D visualization...")
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")

    volume = predictions[0]  # channel 0, shape (D, H, W)
    depth, height, width = volume.shape

    # indexing="ij" over (D, H, W) yields coordinate grids with the same layout
    # as `volume`, so ravel() keeps coordinates and values aligned voxel for
    # voxel. The default "xy" indexing on (W, H, D) yields (H, W, D) grids,
    # which pairs each point with a different voxel's value.
    z, y, x = np.meshgrid(
        np.arange(depth), np.arange(height), np.arange(width), indexing="ij"
    )
    x, y, z, values = x.ravel(), y.ravel(), z.ravel(), volume.ravel()

    if max_points is not None and max_points > 0 and values.size > max_points:
        step = -(-values.size // max_points)  # ceiling division
        x, y, z, values = x[::step], y[::step], z[::step], values[::step]

    ax.scatter(x, y, z, c=values, cmap="viridis", alpha=0.5)
    ax.set_title("3D Volume Visualization")
    try:
        fig.savefig(os.path.join(output_dir, "3d_visualization.png"))
    finally:
        plt.close(fig)  # free the figure, like the other plotting helpers do
    print(f"✅ 3D visualization saved to {output_dir}")

def export_to_nifti(predictions, masks, output_dir, patient_id):
    """
    Exports predictions and masks to NIfTI format.

    Writes channel 0 of ``predictions`` to ``{patient_id}_predictions.nii.gz``
    and the mask volume to ``{patient_id}_masks.nii.gz``. Both files use an
    identity affine.

    Parameters
    ----------
    predictions : np.ndarray
        Predicted volumes (C, D, H, W).
    masks : np.ndarray
        Ground truth masks, (1, D, H, W) or (D, H, W).
    output_dir : str
        Directory to save the NIfTI files; created if missing.
    patient_id : str
        ID of the patient.

    Returns
    -------
    tuple of str
        Paths of the predictions file and of the masks file.
    """
    print("💾 Exporting to NIfTI...")
    os.makedirs(output_dir, exist_ok=True)

    pred_volume = predictions[0]
    # masks[0] on a (D, H, W) array would export only the first slice.
    mask_volume = masks[0] if masks.ndim == 4 else masks
    # NOTE(review): nibabel warns about (or, in newer versions, rejects)
    # 64-bit integer image data unless an explicit dtype is given. The label
    # indices used in this notebook are small (0-4), so int32 is lossless.
    if np.issubdtype(mask_volume.dtype, np.integer) and mask_volume.dtype.itemsize == 8:
        mask_volume = mask_volume.astype(np.int32)

    # NOTE(review): arrays are written in (D, H, W) order with an identity
    # affine, so the source scan's voxel spacing/orientation is not preserved.
    pred_nifti = nib.Nifti1Image(pred_volume, affine=np.eye(4))
    mask_nifti = nib.Nifti1Image(mask_volume, affine=np.eye(4))

    pred_path = os.path.join(output_dir, f"{patient_id}_predictions.nii.gz")
    mask_path = os.path.join(output_dir, f"{patient_id}_masks.nii.gz")

    nib.save(pred_nifti, pred_path)
    nib.save(mask_nifti, mask_path)
    print(f"✅ NIfTI files saved to {output_dir}")
    return pred_path, mask_path

# __Test__

In [3]:
# Load the experiment config and the trained checkpoint to evaluate.
config = load_config('configs/experiments/deep_aug_5.yaml')
model_path = 'experiments/deep_aug/deep_aug_20250415_215856/checkpoints/best_model_epoch60.pth'
model = load_trained_model(config, model_path)
# Use the device named in the config, but fall back to CPU whenever CUDA is unavailable.
config_device = config['training']['device']
device = torch.device(config_device if torch.cuda.is_available() else "cpu")
# NOTE(review): named test_dir but points at the *train* split. Confirm this is
# intended; results on training patients would likely look optimistic.
test_dir = 'data/processed/2d/train/'
# Candidate patient IDs. Not referenced by the later cells in this notebook,
# which use a single hard-coded patient_id.
patient_ids = ["rtum79", "rtum1", "rtum33", "rtum3", "rtum20", "rtum70", "rtum19", "rtum26", "rtum13", "rtum71", "rtum87", "rtum69", "rtum58", "rtum82", "rtum86", "rtum68", "rtum4", "rtum81"]

In [4]:
# Build the 2D predicter for `model` on `device`; presumably it reads patient
# slices from test_dir (see src.inference.predicter to confirm).
predicter = Predicter2D(model, config, device, test_dir)

In [5]:
# Predict one patient and inspect the shapes. The output below shows
# predictions as (1, C, D, H, W) and masks as (1, D, H, W), with no channel axis.
patient_id = 'rtum79'
predictions, masks = predicter.predict_patient(patient_id)
print(f"predictions: {predictions.shape}")
print(f"masks: {masks.shape}")

  0%|          | 0/26 [00:00<?, ?it/s]

predictions: torch.Size([1, 5, 103, 256, 256])
masks: torch.Size([1, 103, 256, 256])


In [13]:
def create_2d_animation(predictions, ground_truth, patient_id, output_dir):
    """
    Animate ground-truth vs. predicted label maps, one frame per slice, as a GIF.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions (B, C, D, H, W). Batch item 0 is reduced to a
        (D, H, W) label map with argmax over C.
    ground_truth : torch.Tensor
        Ground truth masks (B, D, H, W); batch item 0 is used.
    patient_id : str
        Patient ID, shown in the title and used in the file name.
    output_dir : str
        Output directory, created if missing.

    Returns
    -------
    str
        Path of the saved ``patient_{patient_id}_2d_animation.gif``.
    """
    label_cmap = ListedColormap(['green', 'purple', 'red', 'blue'])
    label_norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5, 4.5], label_cmap.N, clip=True)

    def labelled_only(label_slice):
        # Label 0 becomes NaN, so only labelled pixels are coloured.
        return np.where(label_slice > 0, label_slice, np.nan)

    pred_labels = torch.argmax(predictions[0], dim=0).cpu().numpy()
    gt_labels = ground_truth[0].cpu().numpy()
    n_slices, height, width = pred_labels.shape

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    def add_panel(ax, labels, heading):
        # Black canvas underneath a semi-transparent overlay of the first slice.
        base = ax.imshow(np.zeros((height, width)), cmap='gray')
        overlay = ax.imshow(
            labelled_only(labels[0]), cmap=label_cmap, norm=label_norm, alpha=0.8
        )
        ax.set_title(heading, fontsize=12)
        ax.axis('off')
        return base, overlay

    gt_base, gt_overlay = add_panel(axes[0], gt_labels, 'Ground Truth')
    pred_base, pred_overlay = add_panel(axes[1], pred_labels, 'Prediction')

    title = fig.suptitle(f"Patient {patient_id} - Slice 0", fontsize=14)

    def update(frame):
        gt_overlay.set_array(labelled_only(gt_labels[frame]))
        pred_overlay.set_array(labelled_only(pred_labels[frame]))
        title.set_text(f"Patient {patient_id} - Slice {frame}")
        return [gt_base, gt_overlay, pred_base, pred_overlay, title]

    anim = animation.FuncAnimation(
        fig, update, frames=n_slices, interval=200, blit=False
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"patient_{patient_id}_2d_animation.gif")
    anim.save(output_path, writer=animation.PillowWriter(fps=5))
    plt.close(fig)  # free the figure's memory

    print(f"2D animation saved to {output_path}")
    return output_path

In [None]:
def process_tensors(predictions, ground_truth, argmax=True):
    """
    Convert batched prediction / ground-truth tensors into per-volume NumPy arrays.

    Parameters
    ----------
    predictions : torch.Tensor
        Model output. When 5-D it is treated as (1, C, D, H, W): the batch
        axis is dropped and, if ``argmax`` is True, the channel axis is
        reduced with argmax to a (D, H, W) label map. Other ranks are
        converted unchanged.
    ground_truth : torch.Tensor
        Labels. When 4-D it is treated as (1, D, H, W) and the batch axis is
        dropped, giving (D, H, W). Other ranks are converted unchanged.
    argmax : bool, optional
        Reduce the channel axis with argmax (default True). When False, the
        channel axis is kept, except that a single channel (C == 1) is
        squeezed away.

    Returns
    -------
    tuple of np.ndarray
        ``(pred_np, gt_np)``.
    """
    if predictions.dim() == 5:
        predictions = predictions.squeeze(0)  # (C, D, H, W)
        if argmax:
            predictions = torch.argmax(predictions, dim=0)  # (D, H, W)
        else:
            # Collapses only a single-channel output; a no-op when C > 1.
            predictions = predictions.squeeze(0)
    if ground_truth.dim() == 4:
        # (1, D, H, W) -> (D, H, W), matching the prediction label map.
        ground_truth = ground_truth.squeeze(0)

    pred_np = predictions.cpu().numpy()
    gt_np = ground_truth.cpu().numpy()

    return pred_np, gt_np

In [12]:
# Save a GIF for the patient predicted above (black background, no CT).
# NOTE(review): the execution counters show this cell (In [12]) ran before the
# definition cell (In [13]). The "shapes before/after" output below presumably
# comes from an earlier version of create_2d_animation; re-run cells in order.
output_dir = 'test_predicter'
create_2d_animation(predictions, masks, 'rtum79_test', output_dir)

shapes before
pred: torch.Size([1, 5, 103, 256, 256])
gt: torch.Size([1, 103, 256, 256])
shapes after
pred: (103, 256, 256)
gt: (103, 256, 256)


In [22]:
def _scatter_label_points(ax, labels, class_idx, color, name, max_points, rng):
    """
    Scatter the voxels where ``labels == class_idx`` on a 3D axis.

    At most ``max_points`` voxels are drawn; larger sets are randomly
    subsampled without replacement via ``rng.choice``. Returns True when at
    least one point was drawn.
    """
    z, y, x = np.nonzero(labels == class_idx)
    if z.size == 0:
        return False
    if z.size > max_points:
        keep = rng.choice(z.size, max_points, replace=False)
        z, y, x = z[keep], y[keep], x[keep]
    ax.scatter(x, y, z, c=color, alpha=0.7, s=5, label=f"{name}")
    return True


def create_3d_visualization(predictions, ground_truth, patient_id, output_dir,
                            max_points=10000, seed=None):
    """
    Creates 3D visualizations of predictions and ground truth.

    Draws side-by-side 3D scatter plots of labels 1-4 (ground truth on the
    left, argmax prediction on the right). The figure is saved as
    ``patient_{patient_id}_3d_visualization.png`` at 300 dpi.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions (B, C, D, H, W); batch item 0 is reduced with argmax over C.
    ground_truth : torch.Tensor
        Ground truth masks (B, D, H, W); batch item 0 is used.
    patient_id : str
        Patient ID, used in the title and file name.
    output_dir : str
        Output directory, created if missing.
    max_points : int, optional
        Maximum number of points drawn per class and panel (default 10000).
    seed : int, optional
        Seed for the point subsampling. ``None`` (default) draws from NumPy's
        global random state, so results then depend on ``np.random.seed``.

    Returns
    -------
    str
        Path of the saved PNG.
    """
    pred_numpy = torch.argmax(predictions[0], dim=0).cpu().numpy()
    gt_numpy = ground_truth[0].cpu().numpy()

    os.makedirs(output_dir, exist_ok=True)

    # Without a seed, keep using the global NumPy RNG. With a seed, use an
    # independent Generator so the plot is reproducible.
    rng = np.random if seed is None else np.random.default_rng(seed)

    fig = plt.figure(figsize=(15, 7))

    ax1 = fig.add_subplot(121, projection='3d')
    ax1.set_title('Ground Truth 3D', fontsize=14)

    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title('Prediction 3D', fontsize=14)

    colors = ['green', 'purple', 'red', 'blue']
    class_names = ['Pancreas', 'Tumor', 'Arteries', 'Veins']

    gt_drawn = pred_drawn = False
    for class_idx, (color, name) in enumerate(zip(colors, class_names), start=1):
        # `|=` always evaluates the call, so both panels are drawn for every class.
        gt_drawn |= _scatter_label_points(
            ax1, gt_numpy, class_idx, color, name, max_points, rng
        )
        pred_drawn |= _scatter_label_points(
            ax2, pred_numpy, class_idx, color, name, max_points, rng
        )

    for ax, drawn in ((ax1, gt_drawn), (ax2, pred_drawn)):
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        if drawn:
            # legend() with no labelled artists only emits a warning.
            ax.legend(loc='upper right')

        # Give all three axes the same span around their centres.
        xlim, ylim, zlim = ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
        half_range = max(
            xlim[1] - xlim[0], ylim[1] - ylim[0], zlim[1] - zlim[0]
        ) / 2.0
        mid_x, mid_y, mid_z = sum(xlim) / 2, sum(ylim) / 2, sum(zlim) / 2
        ax.set_xlim(mid_x - half_range, mid_x + half_range)
        ax.set_ylim(mid_y - half_range, mid_y + half_range)
        ax.set_zlim(mid_z - half_range, mid_z + half_range)

    fig.suptitle(f"Patient {patient_id} - 3D Visualization", fontsize=16)

    output_path = os.path.join(output_dir, f"patient_{patient_id}_3d_visualization.png")
    try:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)  # free the figure even if saving fails

    print(f"3D visualization saved to {output_path}")
    return output_path

In [24]:
# Render the 3D scatter for the same patient into output_dir.
create_3d_visualization(predictions, masks, patient_id, output_dir)

3D visualization saved to test_predicter\patient_rtum79_3d_visualization.png


'test_predicter\\patient_rtum79_3d_visualization.png'

# __Test with volume__

In [31]:
def create_2d_animation(predictions, ground_truth, patient_id, output_dir, volume=None):
    """
    Creates an animation comparing predictions and ground truth masks,
    optionally overlaid on the CT volume, and saves it as a GIF.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions (B, C, D, H, W); batch item 0 is reduced with argmax over C.
    ground_truth : torch.Tensor
        Ground truth masks (B, D, H, W); batch item 0 is used.
    patient_id : str
        Patient ID
    output_dir : str
        Output directory, created if missing.
    volume : torch.Tensor, optional
        CT volume, (B, 1, D, H, W) or (B, D, H, W). If provided, it is used as
        a grayscale background; otherwise the background is black.

    Returns
    -------
    str
        Path of the saved GIF.
    """
    # One colour per label 1-4; value k falls in the bin [k - 0.5, k + 0.5).
    cmap = ListedColormap(['green', 'purple', 'red', 'blue'])
    norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5, 4.5], cmap.N, clip=True)

    pred_np = torch.argmax(predictions[0], dim=0).cpu().numpy()      # (D, H, W)
    gt_np = ground_truth[0].cpu().numpy()                            # (D, H, W)
    D, H, W = pred_np.shape

    volume_np = None
    bg_vmin = bg_vmax = None
    if volume is not None:
        volume_np = volume[0].cpu().numpy()
        # ds.get_patient_volume returns (1, 1, D, H, W) in this notebook. Drop a
        # leading channel axis so both (B, 1, D, H, W) and (B, D, H, W) work.
        if volume_np.ndim == 4:
            volume_np = volume_np[0]
        # Fix the grey-level window to the whole volume. Otherwise imshow
        # autoscales from the first slice only, and set_array() in update()
        # does not rescale.
        bg_vmin, bg_vmax = float(volume_np.min()), float(volume_np.max())
    blank = np.zeros((H, W))

    def background(frame):
        return volume_np[frame] if volume_np is not None else blank

    def labelled(label_slice):
        # Label 0 becomes NaN so it is drawn transparent.
        return np.where(label_slice > 0, label_slice, np.nan)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    gt_img = axes[0].imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    gt_overlay = axes[0].imshow(labelled(gt_np[0]), cmap=cmap, norm=norm, alpha=0.5)
    axes[0].set_title('Ground Truth')
    axes[0].axis('off')

    pred_img = axes[1].imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    pred_overlay = axes[1].imshow(labelled(pred_np[0]), cmap=cmap, norm=norm, alpha=0.5)
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    title = fig.suptitle(f"Patient {patient_id} - Slice 0", fontsize=14)

    def update(frame):
        bg = background(frame)
        gt_img.set_array(bg)
        pred_img.set_array(bg)

        gt_overlay.set_array(labelled(gt_np[frame]))
        pred_overlay.set_array(labelled(pred_np[frame]))

        title.set_text(f"Patient {patient_id} - Slice {frame}")
        return [gt_img, gt_overlay, pred_img, pred_overlay, title]

    anim = animation.FuncAnimation(fig, update, frames=D, interval=200, blit=False)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"patient_{patient_id}_2d_animation.gif")

    writer = animation.PillowWriter(fps=5)
    anim.save(output_path, writer=writer)

    plt.close(fig)  # Close the figure to free memory

    print(f"2D animation saved to {output_path}")
    return output_path

In [6]:
from src.data.dataset2d import PancreasDataset2D
from src.training.setup import get_transforms

In [7]:
# Dataset used below only to fetch the patient's CT volume as animation background.
# NOTE(review): get_transforms(config) may return training-time augmentations;
# if so, the background would not be the raw CT. Confirm the transform set.
ds = PancreasDataset2D(
    data_dir='data/processed/2d/train/',
    transform=get_transforms(config)
)

📊 Loading dataset... 8834 slices found.


In [8]:
# CT volume for the patient (the second return value is discarded).
# The printed output shows shape (1, 1, D, H, W).
volume, _ = ds.get_patient_volume(patient_id)
print(f"volume: {volume.shape}")

volume: torch.Size([1, 1, 103, 256, 256])


In [32]:
# GIF with the CT slices as grayscale background under the label overlays.
output_path = create_2d_animation(predictions, masks, 'rtum79_vol', output_dir='test_predicter', volume=volume)

2D animation saved to test_predicter\patient_rtum79_vol_2d_animation.gif


# __Test legend__

In [33]:
import matplotlib.patches as mpatches

# Overlay colour and legend label for each foreground class index (1-4).
LABEL_COLORS = dict(
    zip(
        range(1, 5),
        [("green", "Class 1"), ("purple", "Class 2"), ("red", "Class 3"), ("blue", "Class 4")],
    )
)

In [55]:
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import numpy as np
import os
import torch

# Overlay colour and legend label for each foreground class index (1-4).
LABEL_COLORS = dict(
    zip(
        range(1, 5),
        [("green", "Class 1"), ("purple", "Class 2"), ("red", "Class 3"), ("blue", "Class 4")],
    )
)
# BoundaryNorm puts value v in bin i when BOUNDARIES[i] <= v < BOUNDARIES[i + 1],
# so integer label k maps to colour k - 1 of CMAP.
BOUNDARIES = [0.5, 1.5, 2.5, 3.5, 4.5]
CMAP = ListedColormap([color for color, _label in LABEL_COLORS.values()])
NORM = BoundaryNorm(BOUNDARIES, CMAP.N, clip=True)

def create_2d_animation(predictions, ground_truth, patient_id, output_dir, volume=None, alpha=0.6, filename=None):
    """
    Save a GIF comparing ground-truth and predicted label maps slice by slice,
    with a colour legend (built from LABEL_COLORS) below the two panels.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions (B, C, D, H, W); batch item 0 is reduced with argmax over C.
    ground_truth : torch.Tensor
        Ground truth masks (B, D, H, W); batch item 0 is used.
    patient_id : str
        Shown in the title and used in the default file name.
    output_dir : str
        Output directory, created if missing.
    volume : torch.Tensor, optional
        CT volume, (B, 1, D, H, W) or (B, D, H, W), drawn in grayscale under
        the overlays. The background is black when None.
    alpha : float, optional
        Opacity of the label overlays (default 0.6).
    filename : str, optional
        Output file name; defaults to ``patient_{patient_id}_2d_animation.gif``.

    Returns
    -------
    str
        Path of the saved GIF.
    """
    pred_np = torch.argmax(predictions[0], dim=0).cpu().numpy()
    gt_np = ground_truth[0].cpu().numpy()
    D, H, W = pred_np.shape

    volume_np = None
    bg_vmin = bg_vmax = None
    if volume is not None:
        volume_np = volume[0].cpu().numpy()
        # For the (1, 1, D, H, W) volume returned by get_patient_volume here,
        # volume[0] is still 4-D. Without dropping that axis the first
        # background would be a whole (D, H, W) stack, which imshow rejects.
        if volume_np.ndim == 4:
            volume_np = volume_np[0]
        # Fixed grey-level window for all slices (imshow would otherwise
        # autoscale from the first slice only).
        bg_vmin, bg_vmax = float(volume_np.min()), float(volume_np.max())
    blank = np.zeros((H, W))

    def background(frame):
        return volume_np[frame] if volume_np is not None else blank

    def labelled(label_slice):
        # Label 0 becomes NaN so it is drawn transparent.
        return np.where(label_slice > 0, label_slice, np.nan)

    # Two image panels on top, a thin row for the legend underneath.
    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(2, 2, height_ratios=[6, 0.3])

    ax_gt = fig.add_subplot(gs[0, 0])
    ax_pred = fig.add_subplot(gs[0, 1])
    ax_legend = fig.add_subplot(gs[1, :])  # spans both columns
    ax_legend.axis('off')

    legend_patches = [
        mpatches.Patch(color=color, label=label) for color, label in LABEL_COLORS.values()
    ]
    ax_legend.legend(handles=legend_patches, loc='center', ncol=len(LABEL_COLORS), fontsize=10)

    gt_img = ax_gt.imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    gt_overlay = ax_gt.imshow(labelled(gt_np[0]), cmap=CMAP, norm=NORM, alpha=alpha)
    ax_gt.set_title("Ground Truth", fontsize=14)
    ax_gt.axis("off")

    pred_img = ax_pred.imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    pred_overlay = ax_pred.imshow(labelled(pred_np[0]), cmap=CMAP, norm=NORM, alpha=alpha)
    ax_pred.set_title("Prediction", fontsize=14)
    ax_pred.axis("off")

    title = fig.suptitle(f"Patient {patient_id} - Slice 0", fontsize=16, fontweight="bold")

    def update(frame):
        bg = background(frame)
        gt_img.set_array(bg)
        pred_img.set_array(bg)
        gt_overlay.set_array(labelled(gt_np[frame]))
        pred_overlay.set_array(labelled(pred_np[frame]))
        title.set_text(f"Patient {patient_id} - Slice {frame}")
        return [gt_img, gt_overlay, pred_img, pred_overlay, title]

    anim = animation.FuncAnimation(fig, update, frames=D, interval=200, blit=False)

    os.makedirs(output_dir, exist_ok=True)
    if filename is None:
        filename = f"patient_{patient_id}_2d_animation.gif"
    output_path = os.path.join(output_dir, filename)

    writer = animation.PillowWriter(fps=5)
    anim.save(output_path, writer=writer)

    plt.close(fig)
    print(f"2D animation saved to {output_path}")
    return output_path

In [56]:
# Legend layout without a CT background ("sin volumen" = "without volume"),
# saved under a custom file name with more opaque overlays.
output_path = create_2d_animation(
        predictions,
        masks,
        'rtum79',
        output_dir='test_predicter',
        volume=None,
        filename='test_sin_volumen.gif',
        alpha=0.8
    )

2D animation saved to test_predicter\test_sin_volumen.gif


# __Test legend with class names__

In [67]:
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import numpy as np
import os
import torch

# Legend name for each foreground label index.
CLASS_NAMES = {
    1: 'Pancreas',
    2: 'Tumor',
    3: 'Arteries',
    4: 'Veins'
}
# Overlay colour for each label, keyed like CLASS_NAMES (same pairing as in
# create_3d_visualization). Defined here rather than read from LABEL_COLORS,
# which is set in an earlier cell, so this cell runs on its own and the colours
# stay tied to the class names.
CLASS_COLORS = {
    1: 'green',
    2: 'purple',
    3: 'red',
    4: 'blue'
}
CMAP = ListedColormap([CLASS_COLORS[idx] for idx in sorted(CLASS_COLORS)])
# Integer label k falls in [k - 0.5, k + 0.5) and gets colour k - 1 of CMAP.
BOUNDARIES = [0.5, 1.5, 2.5, 3.5, 4.5]
NORM = BoundaryNorm(BOUNDARIES, CMAP.N, clip=True)

def create_2d_animation(predictions, ground_truth, patient_id, output_dir, volume=None, alpha=0.6, filename=None):
    """
    Save a GIF comparing ground-truth and predicted label maps slice by slice,
    with a figure-level legend naming each class (from CLASS_NAMES).

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions (B, C, D, H, W); batch item 0 is reduced with argmax over C.
    ground_truth : torch.Tensor
        Ground truth masks (B, D, H, W); batch item 0 is used.
    patient_id : str
        Shown in the title and used in the default file name.
    output_dir : str
        Output directory, created if missing.
    volume : torch.Tensor, optional
        CT volume, (B, 1, D, H, W) or (B, D, H, W), drawn in grayscale under
        the overlays. The background is black when None.
    alpha : float, optional
        Opacity of the label overlays (default 0.6).
    filename : str, optional
        Output file name; defaults to ``patient_{patient_id}_2d_animation.gif``.

    Returns
    -------
    str
        Path of the saved GIF.
    """
    pred_np = torch.argmax(predictions[0], dim=0).cpu().numpy()
    gt_np = ground_truth[0].cpu().numpy()
    D, H, W = pred_np.shape

    volume_np = None
    bg_vmin = bg_vmax = None
    if volume is not None:
        volume_np = volume[0].cpu().numpy()
        # For the (1, 1, D, H, W) volume returned by get_patient_volume here,
        # volume[0] is still 4-D. Without dropping that axis the first
        # background would be a whole (D, H, W) stack, which imshow rejects.
        if volume_np.ndim == 4:
            volume_np = volume_np[0]
        # Fixed grey-level window for all slices (imshow would otherwise
        # autoscale from the first slice only).
        bg_vmin, bg_vmax = float(volume_np.min()), float(volume_np.max())
    blank = np.zeros((H, W))

    def background(frame):
        return volume_np[frame] if volume_np is not None else blank

    def labelled(label_slice):
        # Label 0 becomes NaN so it is drawn transparent.
        return np.where(label_slice > 0, label_slice, np.nan)

    # Figure sized to leave room for the legend below the panels.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=150)

    gt_img = axes[0].imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    gt_overlay = axes[0].imshow(
        labelled(gt_np[0]),
        cmap=CMAP,
        norm=NORM,
        alpha=alpha
    )
    axes[0].set_title('Ground Truth')
    axes[0].axis('off')

    pred_img = axes[1].imshow(background(0), cmap='gray', vmin=bg_vmin, vmax=bg_vmax)
    pred_overlay = axes[1].imshow(
        labelled(pred_np[0]),
        cmap=CMAP,
        norm=NORM,
        alpha=alpha
    )
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    title = fig.suptitle(f"Patient {patient_id} - Slice 0", fontsize=14)

    # One colour patch per class, coloured exactly as the overlay draws that label.
    legend_elements = [
        mpatches.Patch(color=CMAP(NORM(class_idx)), label=class_name)
        for class_idx, class_name in CLASS_NAMES.items()
    ]

    # Figure-level legend centred below both panels.
    fig.legend(
        handles=legend_elements,
        loc='lower center',
        ncol=min(len(CLASS_NAMES), 4),  # at most 4 per row; extra classes wrap to new rows
        bbox_to_anchor=(0.5, 0.02),
        frameon=True,
        fancybox=True,
        shadow=True
    )

    # Fit the subplots into y in [0.1, 0.95] of the figure: the bottom is kept
    # free for the legend and the top for the suptitle.
    fig.tight_layout(rect=[0, 0.1, 1, 0.95])

    def update(frame):
        bg = background(frame)
        gt_img.set_array(bg)
        pred_img.set_array(bg)

        gt_overlay.set_array(labelled(gt_np[frame]))
        pred_overlay.set_array(labelled(pred_np[frame]))

        title.set_text(f"Patient {patient_id} - Slice {frame}")
        return [gt_img, gt_overlay, pred_img, pred_overlay, title]

    anim = animation.FuncAnimation(fig, update, frames=D, interval=200, blit=False)

    os.makedirs(output_dir, exist_ok=True)

    if filename is None:
        filename = f"patient_{patient_id}_2d_animation.gif"
    output_path = os.path.join(output_dir, filename)

    writer = animation.PillowWriter(fps=5)
    anim.save(output_path, writer=writer)

    plt.close(fig)  # Close the figure to free memory

    print(f"2D animation saved to {output_path}")
    return output_path

In [68]:
# Same no-background call, now with the class-name legend variant.
output_path = create_2d_animation(
        predictions,
        masks,
        'rtum79',
        output_dir='test_predicter',
        volume=None,
        filename='test_legend.gif',
        alpha=0.8
    )

2D animation saved to test_predicter\test_legend.gif
