# Creating a new loss function that takes edges into account so that neighbouring sheets do not merge


* We want each sheet (line) to stay separate from the others, i.e. sheets must not touch.

In [None]:
import os
from os import listdir
from os.path import join

import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np

from scipy import ndimage

def load_nii(file_path, integer=False):
    """
    Loads a NIfTI file and returns the data array and the affine matrix.

    Args:
        file_path (str): Path to the .nii or .nii.gz file.
        integer (bool): If True, round every voxel to the nearest integer and
            return the array as int16 (use this for label volumes). If False
            (default), return the floating-point array from get_fdata().

    Returns:
        tuple: (data_array (np.ndarray), affine_matrix (np.ndarray))

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f" The file at {file_path} was not found.")

    # Load the NIfTI object
    nii_img = nib.load(file_path)

    # get_fdata() applies the header's scale factors and returns floats.
    data = nii_img.get_fdata()
    if integer:
        # Label files stored with scaling decode to values such as
        # 1.00001526; rounding restores exact integer labels.
        data = np.rint(data).astype(np.int16)

    # Get the affine (position/orientation in space)
    affine = nii_img.affine

    print(f"Loaded: {os.path.basename(file_path)}")
    print(f"Shape: {data.shape}")

    return data, affine

def save_nii_with_metadata(data_to_save, original_nii_file_path, output_filepath, dtype=np.int16):
    """
    Saves a NumPy array as a NIfTI file, inheriting the affine matrix and
    header information from an original NIfTI file.

    Args:
        data_to_save (np.ndarray): The 3D NumPy array to save (e.g. a mask).
        original_nii_file_path (str): The file path to the original NIfTI file.
        output_filepath (str): The desired path and filename for the new NIfTI file.
        dtype: On-disk data type recorded in the header. Defaults to np.int16,
            which suits masks. Pass e.g. np.float32 for continuous data such as
            weight maps: with an integer on-disk dtype nibabel has to rescale
            float data, which quantizes it.

    Returns:
        str: The path to the newly saved file.

    Raises:
        FileNotFoundError: If ``original_nii_file_path`` does not exist.
    """
    import warnings

    if not os.path.exists(original_nii_file_path):
        raise FileNotFoundError(f"Original file not found: {original_nii_file_path}")

    # We only need the original image's header and affine, not its data.
    original_img = nib.load(original_nii_file_path)
    affine_matrix = original_img.affine
    header_data = original_img.header

    # The inherited affine is only meaningful if the new array lies on the
    # same voxel grid as the original image.
    if tuple(np.shape(data_to_save)[:3]) != tuple(original_img.shape[:3]):
        warnings.warn(
            f"Shape {np.shape(data_to_save)} does not match the original image "
            f"shape {original_img.shape}; the inherited affine may be wrong."
        )

    try:
        new_img = nib.Nifti1Image(data_to_save, affine_matrix, header_data)
    except Exception:
        # The original header could not be reused: fall back to a fresh one.
        new_img = nib.Nifti1Image(data_to_save, affine_matrix)

    # Avoid saving masks as float64 by default.
    new_img.set_data_dtype(dtype)

    nib.save(new_img, output_filepath)

    print(f"Successfully saved mask to: {output_filepath}")
    return output_filepath

def save_nii(data_to_save, output_filepath, dtype=np.int16):
    """
    Saves a NumPy array as a NIfTI file with an identity affine.

    Intended for synthetic/test volumes that have no source image to inherit
    geometry from, so voxel indices map directly to world coordinates.

    Args:
        data_to_save (np.ndarray): The 3D NumPy array to save.
        output_filepath (str): The desired path and filename for the new NIfTI file.
        dtype: On-disk data type recorded in the header. Defaults to np.int16
            (suitable for masks); pass e.g. np.float32 to keep continuous values.

    Returns:
        str: The path to the newly saved file.
    """
    # There is no reference image, so use the identity as the affine.
    affine_matrix = np.eye(4)

    new_img = nib.Nifti1Image(data_to_save, affine_matrix)

    # Avoid saving masks as float64 by default.
    new_img.set_data_dtype(dtype)

    nib.save(new_img, output_filepath)

    print(f"Successfully saved mask to: {output_filepath}")
    return output_filepath

* Make sure the lines (sheets) are separated into distinct instances

In [None]:
# We need to consider each "line" as an individual label
from skimage.measure import label

# Example label volume and the scratch directory for inspection outputs.
example_label_path = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz"
trash_dir = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash"

data, affine = load_nii(example_label_path, integer=True)

print(np.unique(data))

print(f"data dtype: {data.dtype}, shape: {data.shape}")
save_nii_with_metadata(
    data,
    example_label_path,
    os.path.join(trash_dir, "original_labels.nii.gz"),
)

# Ignore label 2 (set it to background) so only label-1 voxels remain.
data[data == 2] = 0

save_nii_with_metadata(
    data,
    example_label_path,
    os.path.join(trash_dir, "binary_labels.nii.gz"),
)

# Connected-component labelling: each connected foreground component gets
# its own integer id.
# NOTE(review): the file is written as int16, so more than 32767 components
# would not fit - fine for this example.
labeled_mask = label(data)

save_nii_with_metadata(
    labeled_mask,
    example_label_path,
    os.path.join(trash_dir, "several_labels.nii.gz"),
)

[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 2290837.nii.gz
Shape: (314, 314, 320)
[0 1 2]
data dtype: int16, shape: (314, 314, 320)
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/original_labels.nii.gz
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/binary_labels.nii.gz
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/several_labels.nii.gz


'/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/several_labels.nii.gz'

In [43]:
import numpy as np
from torch.nn.functional import sigmoid, binary_cross_entropy
from scipy.ndimage import gaussian_filter, binary_dilation, generate_binary_structure
from skimage.measure import label


# Hyper-parameters of the edge weight map: sigma is the standard deviation
# (in voxels) of the Gaussian blur applied to each halo, w0 is the peak extra
# weight given to a single halo.
sigma=1.0
w0=10.0
# Load one example label volume (rounded to int16 labels by load_nii).
gt, affine = load_nii(
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz",
    integer=True) 

# Copy, then treat label 2 as background so only label-1 voxels form sheets.
# NOTE(review): label 2 is presumably an ignore/unlabelled class - confirm.
gt = np.array(gt)
gt[gt == 2] = 0 

# Label 3D Instances
# skimage.measure.label gives each connected foreground component its own id
instance_mask = label(gt)

# Accumulates the blurred halos of all instances (float32, same shape as gt)
weight_map = np.zeros_like(instance_mask, dtype=np.float32)

# Define 3D Structure for Dilation
# generate_binary_structure(3, 2) is the 3x3x3 neighbourhood made of face- and
# edge-adjacent voxels (18-connectivity). Connectivity 1 would be faces only
# (the 6-neighbour "cross"); connectivity 3 would be the full 26-neighbour cube.
# This replaces the 2D cv2 kernel
struct = generate_binary_structure(3, 2) 

# Iterate over instances (label 0 is background and is skipped)
obj_ids = np.unique(instance_mask)
obj_ids = obj_ids[obj_ids != 0]

for obj_id in obj_ids:
    obj_mask = (instance_mask == obj_id)

    # 3D Dilation
    # Grow the object by 2 steps of the 18-connected structure in X, Y and Z
    dilated = binary_dilation(obj_mask, structure=struct, iterations=2)
    
    # Create Halo: the 2-voxel shell just outside the object
    halo = dilated ^ obj_mask # XOR operation (same as dilated - original)
    
    # 3D Gaussian Blur
    # The blur also spreads weight back into the object itself, so voxels
    # near the surface of thin objects end up above the base weight of 1.
    blurred_halo = gaussian_filter(halo.astype(float), sigma=sigma)
    
    # Normalize each halo so its peak equals w0, then accumulate: voxels
    # between two nearby sheets receive contributions from both halos.
    if blurred_halo.max() > 0:
        blurred_halo = blurred_halo / blurred_halo.max() * w0
        
    weight_map += blurred_halo

# Final Base Weight: every voxel gets at least weight 1
weight_map += 1.0

# NOTE(review): save_nii_with_metadata sets an int16 on-disk dtype, so this
# float map is not stored as float; nibabel presumably rescales it via
# scl_slope/scl_inter - check that the precision is acceptable for inspection.
save_nii_with_metadata(
    weight_map, 
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz", 
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/weight_map.nii.gz")

[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 2290837.nii.gz
Shape: (314, 314, 320)
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/weight_map.nii.gz


'/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/weight_map.nii.gz'

### Create the function that builds the weight map

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import gaussian_filter, binary_dilation, generate_binary_structure
from skimage.measure import label

class AntiBridgeLoss(nn.Module):
    """Weighted binary cross-entropy that up-weights voxels around each sheet.

    Every connected foreground component of the label volume gets a "halo":
    its 2-step binary dilation minus the component itself. The halo is
    Gaussian-blurred, rescaled so that its peak equals ``w0``, and added on top
    of a base weight of 1. Voxels between two nearby sheets receive weight
    from both halos, so predicting foreground there (a "bridge") costs much
    more than an ordinary error.

    Args:
        sigma: Standard deviation (voxels) of the Gaussian blur of each halo.
        w0: Peak extra weight contributed by a single blurred halo.
    """

    # Number of binary-dilation iterations used to build each halo.
    _HALO_ITERATIONS = 2
    # scipy.ndimage.gaussian_filter's default ``truncate`` (kernel radius in sigmas).
    _GAUSS_TRUNCATE = 4.0

    def __init__(self, sigma=1.0, w0=10.0):
        super().__init__()
        # Added to the ROI voxel count to avoid dividing by zero.
        self.smooth = 1e-5
        self.sigma = sigma
        self.w0 = w0

    def _get_edge_weight_map(self, gt_vol):
        """Build the anti-bridge weight map for a single label volume.

        Args:
            gt_vol: Array-like label volume of shape (D, H, W). Voxels equal
                to 2 are treated as background (2 presumably marks an
                ignore/unlabelled class - confirm against the dataset).

        Returns:
            np.ndarray, float32, shape (1, 1, D, H, W): 1.0 everywhere plus
            the normalized, blurred halo of every connected component.
        """
        from scipy.ndimage import find_objects

        # Work on a copy so the caller's array is never modified.
        gt_vol = np.array(gt_vol)
        gt_vol[gt_vol == 2] = 0

        # Connected components: one integer id per separate sheet.
        instance_mask = label(gt_vol)
        weight_map = np.zeros_like(instance_mask, dtype=np.float32)

        # 3D neighbourhood of face- and edge-adjacent voxels (18-connectivity).
        struct = generate_binary_structure(3, 2)

        # A component only affects voxels within (dilation + blur radius) of
        # its bounding box, because gaussian_filter truncates its kernel at
        # int(truncate * sigma + 0.5) voxels. Processing a padded crop instead
        # of the whole volume therefore gives the same values at a fraction of
        # the cost. The crop edges see only zeros, so 'reflect' padding there
        # changes nothing.
        blur_radius = int(self._GAUSS_TRUNCATE * float(np.max(self.sigma)) + 0.5)
        margin = self._HALO_ITERATIONS + blur_radius + 2

        for obj_id, bbox in enumerate(find_objects(instance_mask), start=1):
            if bbox is None:  # this label id does not occur
                continue
            crop = tuple(
                slice(max(0, s.start - margin), min(size, s.stop + margin))
                for s, size in zip(bbox, instance_mask.shape)
            )
            obj_mask = instance_mask[crop] == obj_id

            # Halo = dilated object minus the object (XOR).
            dilated = binary_dilation(
                obj_mask, structure=struct, iterations=self._HALO_ITERATIONS
            )
            halo = dilated ^ obj_mask

            blurred_halo = gaussian_filter(halo.astype(float), sigma=self.sigma)

            # Normalize so each halo peaks at w0, then accumulate.
            peak = blurred_halo.max()
            if peak > 0:
                blurred_halo = (blurred_halo / peak) * self.w0

            weight_map[crop] += blurred_halo

        # Base weight for every voxel.
        weight_map += 1.0

        # Add Batch and Channel dimensions: (D,H,W) -> (1, 1, D, H, W)
        return weight_map[None, None, ...]

    def forward(self, pred, gt, roi_mask, bridge_weight_map=None):
        """Compute the ROI-masked, anti-bridge-weighted BCE loss.

        Args:
            pred: (B, 1, D, H, W) raw logits.
            gt: (B, 1, D, H, W) binary float target.
            roi_mask: (B, 1, D, H, W) mask; voxels where it is 0 are ignored.
            bridge_weight_map: Optional pre-computed (B, 1, D, H, W) weight
                map. When None it is generated from ``gt`` on the CPU, which
                is slow.

        Returns:
            Scalar tensor: sum of weighted, masked voxel losses divided by
            the number of ROI voxels (plus ``self.smooth``).
        """
        if bridge_weight_map is None:
            # WARNING: slow - each sample goes GPU -> CPU -> SciPy -> GPU.
            generated_maps = []
            for i in range(gt.shape[0]):
                # Detach, move to CPU, drop the channel dim.
                gt_np = gt[i, 0].detach().cpu().numpy()
                w_map_np = self._get_edge_weight_map(gt_np)
                generated_maps.append(torch.from_numpy(w_map_np).to(pred.device))
            bridge_weight_map = torch.cat(generated_maps, dim=0)

        # BCE computed directly on logits: numerically stable for large |pred|
        # and allowed under torch.autocast (BCE on sigmoid probabilities is not).
        # reduction='none' keeps the per-voxel shape for weighting/masking.
        pixel_loss = F.binary_cross_entropy_with_logits(pred, gt, reduction='none')

        weighted_loss = pixel_loss * bridge_weight_map

        # Zero out loss in ignored regions.
        masked_loss = weighted_loss * roi_mask

        # Average over valid voxels only.
        return masked_loss.sum() / (roi_mask.sum() + self.smooth)

In [10]:
# TEST
# Build the anti-bridge weight map for one label volume and save it so it can
# be inspected next to the labels in a NIfTI viewer.
# NOTE(review): in the recorded run this path did not exist and load_nii raised
# FileNotFoundError; later cells read labels from a /home/shadowtwin/... root
# instead - confirm which data root applies on the current machine.

anti_bridge_criterio = AntiBridgeLoss(sigma=1.0, w0=10.0)

gt, affine = load_nii(
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz",
    integer=True) 

# _get_edge_weight_map returns a (1, 1, D, H, W) array; [0][0] drops the
# batch and channel axes before saving.
weight_map = anti_bridge_criterio._get_edge_weight_map(gt)
save_nii_with_metadata(
    weight_map[0][0], 
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz", 
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/weight_map.nii.gz")

FileNotFoundError:  The file at /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz was not found.

In [5]:
def test_antibridge_loss(output_dir="/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/trash"):
    """Smoke-test AntiBridgeLoss on a tiny synthetic volume with two nearby cubes.

    Checks that (1) the weight map is higher in the gap between the cubes than
    inside them, (2) filling the gap (a "bridge") raises the loss far above a
    perfect prediction, and (3) an all-zero ROI mask zeroes the loss.
    Results are printed as PASS/FAIL lines rather than asserted.

    Args:
        output_dir: Directory where the synthetic GT and predictions are
            written as NIfTI files for visual inspection. Defaults to the
            original scratch directory; pass None to skip writing files.
    """
    print("--- STARTING ANTI-BRIDGE LOSS TEST ---\n")

    # 1. Setup Mock Data (Small 3D Volume: 1x1x10x10x10)
    D, H, W = 10, 10, 10
    gt_vol = np.zeros((D, H, W), dtype=np.float32)

    # Object 1 occupies W columns 1-3, object 2 columns 6-8, leaving a
    # 2-voxel gap at columns 4 and 5.
    gt_vol[2:8, 2:8, 1:4] = 1
    gt_vol[2:8, 2:8, 6:9] = 1

    # Convert to Tensors (Batch=1, Channel=1)
    gt_tensor = torch.from_numpy(gt_vol).unsqueeze(0).unsqueeze(0).float()

    # ROI Mask: Fully valid for now
    roi_mask = torch.ones_like(gt_tensor)

    w0 = 10.0
    loss_fn = AntiBridgeLoss(sigma=1.0, w0=w0)

    # ---------------------------------------------------------
    # TEST 1: Weight Map Generation
    # ---------------------------------------------------------
    print("Test 1: checking weight map generation...")
    weight_map = torch.from_numpy(loss_fn._get_edge_weight_map(gt_vol))

    gap_val = weight_map[0, 0, 5, 5, 5].item()     # gap column 5
    center_val = weight_map[0, 0, 5, 5, 2].item()  # centre of object 1

    # Object 1 is only 3 voxels thick along W, so the blurred halo (sigma=1)
    # reaches its centre; ~1.0 is only expected deep inside thicker objects.
    # Each normalized halo peaks at w0, so two overlapping halos give at most
    # 1 + 2*w0.
    print(f"  - Weight inside object: {center_val:.4f} (~1.0 only deep inside thick objects)")
    print(f"  - Weight in the gap:    {gap_val:.4f} (Expected > 1.0, at most {1.0 + w0*2:.1f} where two halos overlap)")

    if gap_val > center_val:
        print("  ✅ PASS: Gap has higher weight than object center.")
    else:
        print("  ❌ FAIL: Gap weight is not higher.")

    # ---------------------------------------------------------
    # TEST 2: Anti-Bridge Logic (The Penalty)
    # ---------------------------------------------------------
    print("\nTest 2: Checking bridging penalty...")

    # Perfect prediction: large positive logits on FG, large negative on BG.
    pred_perfect = torch.zeros_like(gt_tensor) + (-10.0)
    pred_perfect[gt_tensor == 1] = 10.0

    # Bridging prediction: the gap columns 4 and 5 are predicted as FG.
    pred_bridge = pred_perfect.clone()
    pred_bridge[:, :, 2:8, 2:8, 4:6] = 10.0

    # Object 2 missed entirely (no bridge).
    pred_no_bridge = pred_perfect.clone()
    pred_no_bridge[:, :, 2:8, 2:8, 6:9] = -10.0

    # Object 2 missed AND the gap filled.
    pred_bridge_incomplete = pred_no_bridge.clone()
    pred_bridge_incomplete[:, :, 2:8, 2:8, 4:6] = 10.0

    print(f"gt_tensor.numpy(): {gt_tensor.numpy().shape}")
    if output_dir is not None:
        volumes = {
            "gt_tensor": gt_tensor,
            "pred_perfect": pred_perfect,
            "pred_bridge": pred_bridge,
            "pred_no_bridge": pred_no_bridge,
        }
        for name, vol in volumes.items():
            save_nii(vol.numpy()[0][0], output_filepath=os.path.join(output_dir, f"{name}.nii.gz"))

    # bridge_weight_map=None makes the loss generate the map internally.
    loss_perfect = loss_fn(pred_perfect, gt_tensor, roi_mask, bridge_weight_map=None).item()
    loss_bridge = loss_fn(pred_bridge, gt_tensor, roi_mask, bridge_weight_map=None).item()
    loss_pred = loss_fn(pred_no_bridge, gt_tensor, roi_mask, bridge_weight_map=None).item()
    loss_pred_incomplete = loss_fn(pred_bridge_incomplete, gt_tensor, roi_mask, bridge_weight_map=None).item()

    print(f"  - Loss (Perfect): {loss_perfect:.6f}")
    print(f"  - Loss (Bridged): {loss_bridge:.6f}")
    print(f"  - Loss (pred): {loss_pred:.6f}")
    print(f"  - Loss (pred_incomplete): {loss_pred_incomplete:.6f}")

    if loss_bridge > loss_perfect * 10:  # Should be significantly higher
        print("  ✅ PASS: Bridging is heavily penalized.")
    else:
        print("  ❌ FAIL: Bridging penalty is too low.")

    # ---------------------------------------------------------
    # TEST 3: ROI Masking Logic
    # ---------------------------------------------------------
    print("\nTest 3: Checking ROI masking...")

    # A terrible prediction (all foreground) with the whole volume masked out
    # must give zero loss.
    pred_terrible = torch.ones_like(gt_tensor) * 10.0
    roi_mask_zero = torch.zeros_like(gt_tensor)

    loss_masked = loss_fn(pred_terrible, gt_tensor, roi_mask_zero, bridge_weight_map=None).item()

    print(f"  - Loss (Fully Masked): {loss_masked:.6f}")

    if loss_masked < 1e-4:
        print("  ✅ PASS: ROI Mask successfully ignored the terrible prediction.")
    else:
        print("  ❌ FAIL: ROI Mask did not ignore the errors.")

    print("\n--- TEST COMPLETE ---")

# Run the test
if __name__ == "__main__":
    test_antibridge_loss()

--- STARTING ANTI-BRIDGE LOSS TEST ---

Test 1: checking weight map generation...
  - Weight inside object: 2.8780 (Expected ~1.0)
  - Weight in the gap:    14.1506 (Expected > 1.0, close to 21.0 due to overlap)
  ✅ PASS: Gap has higher weight than object center.

Test 2: Checking bridging penalty...
gt_tensor.numpy(): (1, 1, 10, 10, 10)
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/trash/gt_tensor.nii.gz
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/trash/pred_perfect.nii.gz
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/trash/pred_bridge.nii.gz
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/trash/pred_no_bridge.nii.gz
  - Loss (Perfect): 0.000484
  - Loss (Bridged): 10.741927
  - Loss (pred): 6.954436
  - Loss (pred_incomplete): 17.695879
  ✅ PASS: Bridging is heavily penalized.

Test 3: Checking ROI masking...
  - L

### Load the class to pre-compute the bridge weight maps


In [None]:
import sys
sys.path.append(os.path.abspath("../utils"))
from AntiBridgeLoss import AntiBridgeLoss
anti_bridge_obj = AntiBridgeLoss(sigma=1.0, w0=10.0)

label_data_path = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop"
weight_map_dir = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_bridge_weight_map_crop"
# Saving into a missing directory would fail, so create it up front.
os.makedirs(weight_map_dir, exist_ok=True)

# Only NIfTI volumes (other entries would crash load_nii); sorted so the
# processing order is reproducible.
for case_name in sorted(os.listdir(label_data_path)):
    if not case_name.endswith((".nii", ".nii.gz")):
        continue
    case_path = os.path.join(label_data_path, case_name)
    # Load the volume from the NIfTI file
    volume, affine = load_nii(case_path, integer=True)
    weight_map = anti_bridge_obj._get_edge_weight_map(gt_vol=volume)

    # weight_map is (1, 1, D, H, W); [0][0] drops the batch/channel axes.
    save_nii_with_metadata(
        data_to_save=weight_map[0][0],
        original_nii_file_path=case_path,
        output_filepath=os.path.join(weight_map_dir, case_name),
    )
    

[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 2669341205.nii.gz
Shape: (314, 297, 320)
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_bridge_weight_map_crop/2669341205.nii.gz


In [14]:
import os
import sys
from concurrent.futures import ProcessPoolExecutor
from functools import partial

# Ensure utils path is available
sys.path.append(os.path.abspath("../utils"))
from AntiBridgeLoss import AntiBridgeLoss

def process_single_case(case_name, label_data_path, output_dir, anti_bridge_obj):
    """Compute and save the anti-bridge weight map for one label volume.

    Meant to be used as a ProcessPoolExecutor worker. Per-case processing
    errors are not raised; instead a one-line status string is returned:
    "Skipped: ...", "Processed: ..." or "Error processing ...: <exception>".

    Args:
        case_name: File name of the label volume inside ``label_data_path``.
            The weight map is written under the same name in ``output_dir``.
        label_data_path: Directory containing the label volumes.
        output_dir: Directory receiving the weight maps.
        anti_bridge_obj: Object providing ``_get_edge_weight_map`` (an
            AntiBridgeLoss instance in this notebook).
    """
    src = os.path.join(label_data_path, case_name)
    dst = os.path.join(output_dir, case_name)

    # Already produced by an earlier run -> nothing to do.
    if os.path.exists(dst):
        return f"Skipped: {case_name}"

    try:
        # NOTE: load_nii and save_nii_with_metadata must be available in the
        # worker process (presumably inherited from the notebook via fork).
        labels, _affine = load_nii(src, integer=True)
        edge_weights = anti_bridge_obj._get_edge_weight_map(gt_vol=labels)
        save_nii_with_metadata(
            data_to_save=edge_weights[0][0],
            original_nii_file_path=src,
            output_filepath=dst,
        )
    except Exception as err:
        return f"Error processing {case_name}: {err}"
    else:
        return f"Processed: {case_name}"

def main(max_workers=8):
    """Pre-compute anti-bridge weight maps for every label volume in parallel.

    Args:
        max_workers: Number of worker processes (defaults to 8).
    """
    label_data_path = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop"
    output_dir = "/home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_bridge_weight_map_crop"

    os.makedirs(output_dir, exist_ok=True)

    anti_bridge_obj = AntiBridgeLoss(sigma=1.0, w0=10.0)

    # Only NIfTI volumes (anything else would just produce errors); sorted so
    # runs and logs are reproducible.
    case_names = sorted(
        name for name in os.listdir(label_data_path)
        if name.endswith((".nii", ".nii.gz"))
    )

    print(f"Starting parallel processing with {max_workers} cores for {len(case_names)} files...")

    # Bind the arguments that are the same for every case.
    worker_func = partial(
        process_single_case,
        label_data_path=label_data_path,
        output_dir=output_dir,
        anti_bridge_obj=anti_bridge_obj,
    )

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in submission order while the pool is
        # still running, so progress is printed during the run rather than
        # only at the end.
        for res in executor.map(worker_func, case_names):
            print(res)

if __name__ == "__main__":
    main()

Starting parallel processing with 8 cores for 786 files...
[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 4024699648.nii.gz
Shape: (250, 250, 256)
[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 1531506078.nii.gz
Shape: (250, 250, 256)
[0.         1.00001526 2.        ]
[0.         1.00001526 2.        ]
[0.         1.00001526 2.        ]
[0.         1.00001526 2.        ]
[0.         1.00001526 2.        ]
[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 2555675774.nii.gz
Shape: (314, 314, 320)
[0 1 2]
Loaded: 3918162598.nii.gz
Shape: (314, 314, 320)
[0 1 2]
Loaded: 516292750.nii.gz
Shape: (314, 314, 320)
[0 1 2]
Loaded: 797181487.nii.gz[0 1 2]

Shape: (314, 314, 320)Loaded: 4213889683.nii.gz

Shape: (314, 314, 320)
[0 1 2]
Loaded: 572077733.nii.gz
Shape: (314, 314, 320)
Successfully saved mask to: /home/shadowtwin/Desktop/AI_work/Vesuvius_Challenge/Vesuvius/DataSet/Challenge_dataset_updated/train_bridge_weight_map_crop/1531506078.nii.gz
[0.         1.00001526 2.        ]
[0 

## Create a loss to avoid discontinuities (*clDice*) — based on these tests, it does not work

In [2]:
import sys
import os

sys.path.append(os.path.abspath("/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/utils"))

from cldice.cldice import soft_cldice



In [3]:
import torch
import torch.nn as nn
import unittest

# soft_cldice must be in scope (imported from utils/cldice in the previous cell).


class TestVesuviusLoss(unittest.TestCase):
    """Sanity checks for soft_cldice on Vesuvius-like 3D tensors.

    The loss is called as ``loss_fn(target, pred)`` and returns
    ``(loss, skeleton_pred, skel_true)``.
    """

    def setUp(self):
        # Vesuvius-like dimensions: Batch=2, Channels=1, Depth=16, H=64, W=64
        self.b, self.c, self.d, self.h, self.w = 2, 1, 16, 64, 64
        self.shape_3d = (self.b, self.c, self.d, self.h, self.w)
        self.shape_2d = (self.b, self.c, self.h, self.w)

        self.loss_fn = soft_cldice(iter_=3, smooth=1e-5, exclude_background=False)

    def test_forward_pass_3d(self):
        """Does it run without crashing on 3D volumes?"""
        print("\nTesting 3D Forward Pass...")
        pred = torch.rand(self.shape_3d, requires_grad=True)
        target = torch.randint(0, 2, self.shape_3d).float()

        loss, skeleton_pred, skel_true = self.loss_fn(target, pred)
        print(f"3D Loss Value: {loss.item()}")
        self.assertFalse(torch.isnan(loss), "Loss should not be NaN")

    def test_gradient_flow(self):
        """Is the loss differentiable?"""
        print("\nTesting Gradient Flow...")
        pred = torch.rand(self.shape_3d, requires_grad=True)
        target = torch.randint(0, 2, self.shape_3d).float()

        loss, skeleton_pred, skel_true = self.loss_fn(target, pred)
        loss.backward()

        self.assertIsNotNone(pred.grad, "Gradients should be calculated")
        # At least some gradients must be non-zero for training to work.
        self.assertTrue(torch.sum(torch.abs(pred.grad)) > 0, "Gradients should be non-zero")
        print("Gradients computed successfully.")

    def test_perfect_match(self):
        """Does perfect prediction yield near-zero loss?"""
        print("\nTesting Perfect Match...")
        # A synthetic 'tube' guarantees that a skeleton exists.
        target = torch.zeros(self.shape_3d)
        target[:, :, :, 30:34, 30:34] = 1.0

        pred = target.clone().detach().requires_grad_(True)

        loss, skeleton_pred, skel_true = self.loss_fn(target, pred)
        print(f"Perfect Match Loss: {loss.item()}")
        # Soft operations mean clDice is not always exactly 0.0, but it
        # should be very low.
        self.assertLess(loss.item(), 0.1, "Perfect match should have low loss")

    def test_empty_prediction(self):
        """Does empty prediction yield high loss?"""
        print("\nTesting Empty Prediction...")
        target = torch.zeros(self.shape_3d)
        target[:, :, :, 30:34, 30:34] = 1.0

        pred = torch.zeros_like(target, requires_grad=True)

        loss, skeleton_pred, skel_true = self.loss_fn(target, pred)
        print(f"Empty Prediction Loss: {loss.item()}")
        self.assertGreater(loss.item(), 0.5, "Empty prediction should have high loss")

    def test_iter_parameter_bug(self):
        """Check that ``iter_`` is forwarded to the internal SoftSkeletonize."""
        print("\nChecking 'iter_' parameter usage...")

        loss_short = soft_cldice(iter_=1)
        loss_long = soft_cldice(iter_=20)

        # An earlier version hard-coded SoftSkeletonize(num_iter=10), which
        # made both losses behave identically. Assert, so that this
        # regression makes the test fail instead of only printing a warning.
        self.assertNotEqual(
            loss_short.soft_skeletonize.num_iter,
            loss_long.soft_skeletonize.num_iter,
            f"soft_cldice ignores 'iter_' (both fixed to {loss_short.soft_skeletonize.num_iter})",
        )
        print("Good: 'iter_' parameter is being passed correctly.")

# --- Run the tests ---
if __name__ == '__main__':
    # Explicit runner (unittest.main() does not play well inside notebooks).
    suite = unittest.TestLoader().loadTestsFromTestCase(TestVesuviusLoss)
    unittest.TextTestRunner(verbosity=2).run(suite)

test_empty_prediction (__main__.TestVesuviusLoss.test_empty_prediction)
Does empty prediction yield high loss? ... ok
test_forward_pass_3d (__main__.TestVesuviusLoss.test_forward_pass_3d)
Does it run without crashing on 3D volumes? ... 


Testing Empty Prediction...
Empty Prediction Loss: 0.9999998211860657

Testing 3D Forward Pass...


ok
test_gradient_flow (__main__.TestVesuviusLoss.test_gradient_flow)
Is the loss differentiable? ... ok
test_iter_parameter_bug (__main__.TestVesuviusLoss.test_iter_parameter_bug)
Check if the 'iter_' parameter is actually being used. ... ok
test_perfect_match (__main__.TestVesuviusLoss.test_perfect_match)
Does perfect prediction yield near-zero loss? ... 

3D Loss Value: 0.49892014265060425

Testing Gradient Flow...
Gradients computed successfully.

Checking 'iter_' parameter usage...
Good: 'iter_' parameter is being passed correctly.

Testing Perfect Match...


ok

----------------------------------------------------------------------
Ran 5 tests in 0.495s

OK


Perfect Match Loss: 0.0


In [4]:
def soft_dice(y_true, y_pred):
    """
    Global soft Dice loss over the whole tensor (batch included).

    Returns 1 - (2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps),
    with eps = 1e-5 keeping the ratio defined for empty inputs.
    """
    eps = 1e-5
    overlap = (y_true * y_pred).sum()
    denominator = y_true.sum() + y_pred.sum() + eps
    dice_score = (overlap * 2. + eps) / denominator
    return 1. - dice_score

In [27]:
# 4. Modified Test Function
def test_dumbbell_bridge():
    """Compare Dice and clDice on two blocks joined by a one-voxel wire that
    the prediction cuts.

    Returns:
        (ground_truth, prediction, skeleton_pred, skel_true) so the tensors
        can be saved or inspected afterwards.
    """
    print("\n=== Testing The 'Dumbbell' Scenario ===")

    vol_shape = (1, 1, 64, 64, 64)
    gt = torch.zeros(vol_shape)

    # Two massive blocks along W ...
    for w_range in (slice(10, 20), slice(44, 54)):
        gt[:, :, 20:44, 20:44, w_range] = 1.0
    # ... joined by a one-voxel-wide wire.
    gt[:, :, 32, 32, 20:44] = 1.0

    # Same blocks, but with the wire snapped in the middle.
    pred = gt.clone()
    pred[:, :, 32, 32, 30:34] = 0.0
    pred.requires_grad = True

    cldice_fn = soft_cldice(iter_=3, smooth=1e-5)

    dice_value = soft_dice(gt, pred)
    cldice_value, skeleton_pred, skel_true = cldice_fn(gt, pred)

    separator = "-" * 40
    print("Scenario: Two huge blocks connected by a tiny wire.")
    print("Defect:   The wire is snapped.")
    print(separator)
    print(f"Dice Loss:    {dice_value.item():.4f}")
    print(f"clDice Loss:  {cldice_value.item():.4f}")
    print(separator)

    more_sensitive = cldice_value > dice_value * 2.0
    print(
        "PASS: clDice is significantly more sensitive to the break!"
        if more_sensitive
        else "FAIL: Scores are still too similar."
    )

    return gt, pred, skeleton_pred, skel_true

# --- Execute and unpack ---
ground_truth, prediction, skeleton_pred, skel_true = test_dumbbell_bridge()


=== Testing The 'Dumbbell' Scenario ===
Scenario: Two huge blocks connected by a tiny wire.
Defect:   The wire is snapped.
----------------------------------------
Dice Loss:    0.0002
clDice Loss:  0.1000
----------------------------------------
PASS: clDice is significantly more sensitive to the break!


In [31]:
def generate_thick_helix(shape, radius=15, turns=1.5, thickness=2):
    """
    Generate a 3D helical tube that spirals along the depth axis.

    Args:
        shape: Output shape (..., D, H, W), typically (1, 1, D, H, W).
        radius: Radius of the spiral (voxels) around the centre of the H/W plane.
        turns: Number of full turns between depth 10 and D - 10.
        thickness: Half-width of the cube stamped at each sample point, so the
            tube is about 2 * thickness + 1 voxels wide.

    Returns:
        Float tensor of the given shape: 1.0 on the tube, 0.0 elsewhere.
        Every leading (batch/channel) slice contains the same helix.
    """
    volume = torch.zeros(shape)
    D, H, W = shape[-3:]

    steps = 600  # number of sample points along the spiral
    t = np.linspace(0, turns * 2 * np.pi, steps)

    # Z advances linearly while (Y, X) move on a circle.
    z_coords = np.linspace(10, D - 10, steps)
    y_coords = (H // 2) + radius * np.sin(t)
    x_coords = (W // 2) + radius * np.cos(t)

    for zc, yc, xc in zip(z_coords, y_coords, x_coords):
        z, y, x = int(zc), int(yc), int(xc)

        # Cube of half-width `thickness`, clipped to the volume.
        z0, z1 = max(0, z - thickness), min(D, z + thickness + 1)
        y0, y1 = max(0, y - thickness), min(H, y + thickness + 1)
        x0, x1 = max(0, x - thickness), min(W, x + thickness + 1)

        # Ellipsis fills every batch/channel slice, not only [0, 0].
        volume[..., z0:z1, y0:y1, x0:x1] = 1.0

    return volume

# ==========================================
# 3. Test Function
# ==========================================

def test_broken_helix(*, radius=18, turns=1.5, thickness=2, gap=(30, 34), iter_=30):
    """
    Synthetic check that clDice penalises a topological break more than Dice.

    Builds a 64^3 helical tube with ``generate_thick_helix``, zeroes a slab of
    axis-2 slices in a copy to act as the prediction, then compares
    ``soft_dice`` with ``soft_cldice``. Prints PASS when the clDice loss exceeds
    1.5x the Dice loss, FAIL otherwise.

    Keyword Args:
        radius, turns, thickness: Forwarded to ``generate_thick_helix``.
        gap (tuple[int, int]): Half-open ``[start, stop)`` range of axis-2
            slices zeroed in the prediction.
        iter_ (int): Iteration count passed to ``soft_cldice``.

    Returns:
        tuple: ``(ground_truth, prediction, skeleton_pred, skel_true)``, where
        the last two are the extra values ``soft_cldice`` returns with the loss.
    """
    print("\n=== Testing Complex Structure: The Broken Helix ===")

    # 1. Canvas: single-sample, single-channel 64^3 volume.
    shape = (1, 1, 64, 64, 64)

    # 2. Ground truth: the complete spiral. The tube cross-section is a
    #    (2 * thickness + 1)-voxel square, i.e. about 5x5 for thickness=2.
    ground_truth = generate_thick_helix(shape, radius=radius, turns=turns, thickness=thickness)

    # 3. Prediction: the same spiral with a slab of slices removed, which cuts
    #    the tube wherever it crosses that slab.
    gap_start, gap_stop = gap
    prediction = ground_truth.clone()
    prediction[:, :, gap_start:gap_stop, :, :] = 0.0

    # Enable gradients and keep values in [0, 1].
    prediction.requires_grad = True
    ground_truth = torch.clamp(ground_truth, 0.0, 1.0)
    prediction = torch.clamp(prediction, 0.0, 1.0)

    # 4. Losses. NOTE(review): the tube is only ~5 voxels wide for thickness=2,
    #    so a large iteration count may over-thin it (depends on soft_cldice,
    #    not visible here). Default 30 reproduces the existing run; compare
    #    against a small value such as iter_=3.
    criterion_cldice = soft_cldice(iter_=iter_, smooth=1e-5)

    # 5. Scores.
    loss_dice = soft_dice(ground_truth, prediction)
    loss_cldice, skeleton_pred, skel_true = criterion_cldice(ground_truth, prediction)

    # 6. Report (derived from the parameters so it cannot drift from the setup).
    print(f"Structure: 3D Spiral (Radius={radius}, Turns={turns})")
    print(f"Defect:    Gap of {gap_stop - gap_start} slices in the middle")
    print(f"-"*40)
    print(f"Dice Loss:    {loss_dice.item():.4f}")
    print(f"clDice Loss:  {loss_cldice.item():.4f}")
    print(f"-"*40)

    if loss_cldice > loss_dice * 1.5:
        print("PASS: clDice detects the break in the complex spiral!")
    else:
        print("FAIL: Scores are too similar.")

    return ground_truth, prediction, skeleton_pred, skel_true

# Run the synthetic-helix check and keep its four returned tensors as notebook globals.
(
    ground_truth,
    prediction,
    skeleton_pred,
    skel_true,
) = test_broken_helix()


=== Testing Complex Structure: The Broken Helix ===
Structure: 3D Spiral (Radius=18, Turns=1.5)
Defect:    Gap of 4 slices in the middle
----------------------------------------
Dice Loss:    0.0449
clDice Loss:  0.0621
----------------------------------------
FAIL: Scores are too similar.


In [57]:
def test_broken_helix(
    *,
    label_path="/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz",
    slab_width=5,
    fill_value=1.0,
    iter_=10,
):
    """
    Compare soft Dice and soft clDice on a real label volume with a slab defect.

    Loads a NIfTI label map with ``load_nii``, sets every voxel labelled 2 to
    background, and uses the result as a ``(1, 1, ...)`` float ground truth.
    The prediction is a copy in which ``slab_width`` consecutive axis-2 slices,
    starting at the middle index, are overwritten with ``fill_value`` across
    the whole of axes 3 and 4.

    NOTE(review): despite the function name, the default ``fill_value=1.0``
    does not sever anything. It fills the slab with foreground and so fuses
    every sheet crossing it, which models a merge defect. Pass
    ``fill_value=0.0`` to cut a gap instead.

    Keyword Args:
        label_path (str): NIfTI label file used as ground truth.
        slab_width (int): Number of axis-2 slices overwritten in the prediction.
        fill_value (float): Value written into the slab.
        iter_ (int): Iteration count passed to ``soft_cldice``.

    Returns:
        tuple: ``(ground_truth, prediction, skeleton_pred, skel_true)``.
    """
    print("\n=== Testing Complex Structure: The Broken Helix ===")

    # Ground truth: real label crop with label 2 folded into background.
    # NOTE(review): label 2 is presumably an ignore/unlabelled class — confirm.
    labels, _affine = load_nii(label_path, integer=True)
    labels[labels == 2] = 0
    ground_truth = torch.from_numpy(labels).unsqueeze(0).unsqueeze(0).float()

    # Prediction: a slab of axis-2 slices spanning all of axes 3/4 is overwritten,
    # so the defect is independent of the sheets' in-plane shape.
    prediction = ground_truth.clone()
    mid_z = ground_truth.shape[2] // 2
    prediction[:, :, mid_z : mid_z + slab_width, :, :] = fill_value

    criterion_cldice = soft_cldice(iter_=iter_, smooth=1e-5)

    loss_dice = soft_dice(ground_truth, prediction)
    loss_cldice, skeleton_pred, skel_true = criterion_cldice(ground_truth, prediction)

    print(f"Dice Loss:    {loss_dice.item():.6f}")
    print(f"clDice Loss:  {loss_cldice.item():.6f}")

    return ground_truth, prediction, skeleton_pred, skel_true

# Run the real-label slab check and keep its four returned tensors as notebook globals.
(
    ground_truth,
    prediction,
    skeleton_pred,
    skel_true,
) = test_broken_helix()


=== Testing Complex Structure: The Broken Helix ===
[0.         1.00001526 2.        ]
[0 1 2]
Loaded: 2290837.nii.gz
Shape: (314, 314, 320)
Dice Loss:    0.129822
clDice Loss:  0.071428


In [59]:
# Export the four tensors from the last run as NIfTI volumes. Each one inherits
# the affine/header of the source label file and is written into the trash dir
# under its variable name.
_reference_nii = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset_updated/train_labels_nii_crop/2290837.nii.gz"
_trash_dir = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash"

for _name, _tensor in (
    ("ground_truth", ground_truth),
    ("prediction", prediction),
    ("skeleton_pred", skeleton_pred),
    ("skel_true", skel_true),
):
    # Drop the batch/channel axes and detach from autograd before saving.
    save_nii_with_metadata(
        data_to_save=_tensor[0][0].detach(),
        original_nii_file_path=_reference_nii,
        output_filepath=f"{_trash_dir}/{_name}.nii.gz",
    )

Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/ground_truth.nii.gz
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/prediction.nii.gz
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/skeleton_pred.nii.gz
Successfully saved mask to: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/skel_true.nii.gz


'/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/skel_true.nii.gz'