# We take the pre-trained STU-Net (Large, pre-trained on the TotalSegmentator cases) and further pre-train it on the scrolls and fragments.

In [1]:
# --- Optimisation -----------------------------------------------------------
LEARNING_RATE = 0.0001
NUM_EPOCHS = 1000
LR_SCHEDULER = "CosineAnnealingLR"  # recorded in the wandb config

# --- Loss / patches ---------------------------------------------------------
CRITERION = "L1"
MASK_WEIGHT = 10  # CriterionL1 divides the global L1 term by this value
PATCH_SIZE = (128, 128, 128)

# --- Runtime ----------------------------------------------------------------
DEVICE = 'cuda'
BATCH_SIZE, NUM_WORKERS = 2, 8
DEBUG = False  # when True, the smoke-test cells below are executed

# --- Paths ------------------------------------------------------------------
SAVE_DIR = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/5_scrolls_pretrain"
CHECKPOINT_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/pre-trained/Independent/binary_large_ep4k.pth"
DATA_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training"
VAL_DATA_PATH = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Challenge_dataset"

# Checkpoint to resume from; set to None to start a fresh run from CHECKPOINT_PATH.
RESUME = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/5_scrolls_pretrain/wandb/run-20251219_094054-w862cvd3/files/model/model_epoch_100.pth"

## Building the data loader (DONE)


In [2]:
# Standard library
from os import listdir, makedirs
from os.path import join


# Third-party libraries
import numpy as np
import zarr
import random
from tqdm import tqdm
import wandb
import nibabel as nib

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler

# MONAI
import monai
from monai.transforms import (
    Compose,
    EnsureType,
    RandCoarseDropout,
    RandCropByPosNegLabel,
    RandFlip,
    RandGaussianNoise,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
    ScaleIntensityRange,
    LoadImage,
    CopyItemsd,
    LoadImaged, 
    ScaleIntensityRanged, 
    ResizeWithPadOrCropd, 
    RandCoarseDropoutd,
    EnsureTyped,
    EnsureChannelFirstd
)
from monai.data import DataLoader, CacheDataset

from stunet_model import STUNetReconstruction

In [3]:
def save_vol(tensor, path):
    """Write the first sample of a batched tensor to ``path`` as a NIfTI file.

    Index 0 drops the batch axis. A following axis of size 1 (the channel) is
    squeezed away. The data is stored as float32 with an identity affine,
    because the patches carry no spatial metadata.
    """
    first_sample = tensor.detach().cpu()[0].squeeze(0)
    volume = first_sample.numpy().astype(np.float32)
    nib.save(nib.Nifti1Image(volume, np.eye(4)), path)

In [4]:
class ZarrVolumeDataset(Dataset):
    """Random 3D patches from a folder of ``.zarr`` volumes, for masked-reconstruction pre-training.

    Every item is generated on the fly:

    1. A volume is picked at random.
    2. A ``patch_size`` crop is taken at a random position, rejecting crops
       whose maximum does not exceed ``threshold``.
    3. The crop is normalised with ``transform_input`` and corrupted with
       ``transform_deform``.

    The dataset has no natural length, so ``epoch_length`` defines how many
    samples make up one epoch.
    """

    def __init__(self, zarr_path, transform_input, transform_deform, patch_size=(128, 128, 128), threshold=0.0,
                 epoch_length=1000, max_attempts=100):
        """
        Args:
            zarr_path: folder that contains the ``*.zarr`` volumes.
            transform_input: MONAI dict transform applied to ``{"image": patch}``
                (normalisation).
            transform_deform: MONAI dict transform applied to
                ``{"image", "tracking_mask"}`` to corrupt the patch. It is
                expected to zero the same voxels in both keys, so that
                ``1 - tracking_mask`` marks the removed region.
            patch_size: (Z, Y, X) size of the returned patches. It must match
                what the (pre-trained) network expects.
            threshold: a crop is accepted only if its max value is greater
                than this.
            epoch_length: number of samples per epoch (returned by ``__len__``).
            max_attempts: number of random crops to try before giving up and
                returning the last (possibly empty) one.

        Raises:
            ValueError: if ``zarr_path`` contains no ``.zarr`` entries.
        """
        self.zarr_path = zarr_path
        self.transform_input = transform_input
        self.transform_deform = transform_deform
        self.patch_size = tuple(patch_size)
        self.threshold = threshold  # values <= threshold are treated as background
        self.epoch_length = epoch_length
        self.max_attempts = max(1, int(max_attempts))

        print(f"Loading data from -> {zarr_path}")

        # Only the paths are stored. The volumes are opened lazily in
        # __getitem__ because they are too large to keep in memory. Sorting
        # makes the list (and hence seeded sampling) independent of listdir order.
        self.zarr_vols_paths = [
            join(zarr_path, name)
            for name in sorted(listdir(zarr_path))
            if name.endswith(".zarr")
        ]
        if not self.zarr_vols_paths:
            raise ValueError(f"No '.zarr' volumes found in {zarr_path}")

    def __len__(self):
        # Number of random patches that make up one epoch.
        return self.epoch_length

    @staticmethod
    def _open_volume(path):
        """Lazily open a zarr store and return its ``'volume'`` array, falling back to ``'0'``."""
        root = zarr.open(path, mode='r')
        return root['volume'] if 'volume' in root else root['0']

    def _sample_patch(self, vol):
        """Crop a random ``patch_size`` block from ``vol`` by rejection sampling.

        A crop that extends past the volume (the volume is smaller than the
        patch along some axis) is zero-padded at the end to ``patch_size``,
        so every returned patch has the same shape.
        """
        # Largest valid start index per axis (inclusive). It is 0 when the
        # volume is not larger than the patch along that axis.
        max_starts = [max(0, dim - size) for dim, size in zip(vol.shape, self.patch_size)]

        for _ in range(self.max_attempts):
            # randint's upper bound is exclusive; +1 lets the crop reach the far edge.
            starts = [np.random.randint(0, m + 1) for m in max_starts]
            patch = vol[tuple(slice(s, s + size) for s, size in zip(starts, self.patch_size))]
            if np.max(patch) > self.threshold:
                break
        # If every attempt was rejected, the last (empty) crop is returned so a
        # truly empty volume cannot hang the loader.

        patch = np.asarray(patch)
        spatial = len(self.patch_size)
        pad_width = [(0, max(0, size - dim)) for dim, size in zip(patch.shape[:spatial], self.patch_size)]
        pad_width += [(0, 0)] * (patch.ndim - len(pad_width))
        if any(after for _, after in pad_width):
            patch = np.pad(patch, pad_width, mode="constant")
        return patch

    def __getitem__(self, index):
        # `index` is ignored: every call draws a fresh random patch.
        vol = self._open_volume(random.choice(self.zarr_vols_paths))
        patch = self._sample_patch(vol)

        # MONAI formatting: float32 with a leading channel axis -> (1, Z, Y, X).
        patch = patch.astype(np.float32)[np.newaxis, ...]

        # All-ones tracker that goes through transform_deform alongside the
        # image; wherever the corruption zeroes it, 1 - tracker marks the hole.
        tracking_mask = np.ones_like(patch)

        # Normalisation
        patch_dict = self.transform_input({"image": patch})

        # Keep the clean target before corruption.
        clean_patch = patch_dict['image'].clone()

        deform_patch = self.transform_deform(
            {
                "image": patch_dict['image'],
                "tracking_mask": tracking_mask
            }
        )

        dropout_mask = 1 - deform_patch['tracking_mask']

        return {
            'clean_patch': clean_patch,
            'deform_patch': deform_patch["image"],
            'dropout_mask': dropout_mask
        }

def describe_tensor(name, t):
    """Print the type, dtype, device and shape of ``t`` under a ``name`` header."""
    report = [
        f"{name}:",
        f"  type:  {type(t)}",
        f"  dtype: {t.dtype}",
        f"  device:{t.device}",
        f"  shape: {tuple(t.shape)}",
        "",
    ]
    for line in report:
        print(line)
        

In [5]:
if DEBUG:
    # Smoke test for ZarrVolumeDataset on a local folder: build the
    # normalisation / corruption transforms, fetch a few batches and dump
    # their dropout masks to NIfTI for visual inspection.
    transform_input = Compose([
        # Loading is done lazily by ZarrVolumeDataset, so only normalise here.
        ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
        # RandFlip / RandRotate90 augmentations are intentionally left out:
        # there is plenty of pre-training data.
        EnsureTyped(keys=["image"])
    ])

    # The corruption transforms force the model to repair heavy defects.
    transform_deform = Compose([
        # 10 holes of 32x48x48 voxels remove up to ~35% of a 128^3 patch
        # (less when holes overlap). "tracking_mask" is listed so the same
        # holes are recorded in it.
        RandCoarseDropoutd(
            keys=["image", "tracking_mask"],
            holes=10,
            spatial_size=(32, 48, 48),
            fill_value=0,
            prob=1.0  # Always apply
        ),
        # Gaussian noise is intentionally not added.
        EnsureTyped(keys=["image", "tracking_mask"])
    ])

    print("Initializing Dataset...")
    ds = ZarrVolumeDataset(
        "/mounts/disk2/Andre_Data_Augmentation/PhD/Vesuvius",
        transform_input=transform_input,
        transform_deform=transform_deform,
        patch_size=(128, 128, 128)
    )

    print("Initializing DataLoader...")
    loader = monai.data.DataLoader(ds, batch_size=1, num_workers=4)
    data_loader = iter(loader)
    print("Fetching a batch to ensure it's not empty...")
    first_batch = next(data_loader)
    clean_patch = first_batch['clean_patch']
    deform_patch = first_batch['deform_patch']

    batch_max = clean_patch.max()
    print(f"Batch Max Value: {batch_max}")
    if batch_max == 0:
        print("WARNING: The batch is still empty. Your threshold might be too high or the volume is empty.")
    else:
        print("Success! Loaded a non-empty chunk.")

    print(f"Success! Batch shape: {clean_patch.shape}")

    describe_tensor("clean_patch", clean_patch)
    describe_tensor("deform_patch", deform_patch)

    # Dump the dropout masks of the next 10 batches for inspection.
    for i in range(10):
        first_batch = next(data_loader)
        dropout_mask = first_batch['dropout_mask']

        save_vol(dropout_mask, f"/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/trash/train_mask_{i}.nii.gz")

In [6]:
if DEBUG:
    import matplotlib.pyplot as plt
    from ipywidgets import interact, IntSlider
    import numpy as np

    # Interactive slice viewer for the corrupted patch fetched by the
    # data-loader smoke-test cell (requires `deform_patch`).
    # MONAI batches are (batch, channel, D0, D1, D2): take batch 0, channel 0.
    input_data = deform_patch[0, 0].cpu().numpy()

    print(f"Batch Shape: {deform_patch.shape}")
    print(f"Visualizing Sample Shape: {input_data.shape}")

    def view_batch_slice(slice_idx, axis):
        """Show slice ``slice_idx`` of ``input_data`` taken along ``axis`` (0, 1 or 2)."""
        plt.figure(figsize=(8, 8))

        # Slicing along each axis helps debug orientation.
        if axis == 0:
            plt.imshow(input_data[slice_idx, :, :], cmap='gray')
            plt.xlabel("Axis 2")
            plt.ylabel("Axis 1")
        elif axis == 1:
            plt.imshow(input_data[:, slice_idx, :], cmap='gray')
            plt.xlabel("Axis 2")
            plt.ylabel("Axis 0")
        else:
            plt.imshow(input_data[:, :, slice_idx], cmap='gray')
            plt.xlabel("Axis 1")
            plt.ylabel("Axis 0")

        plt.title(f"Slice {slice_idx} along Axis {axis}")
        plt.colorbar()
        plt.show()

    # The slice slider range is taken from this axis (the patches are cubic,
    # so it is valid for every view axis here).
    axis_to_scroll = 0

    interact(
        view_batch_slice,
        slice_idx=IntSlider(
            min=0,
            max=input_data.shape[axis_to_scroll] - 1,
            step=1,
            value=input_data.shape[axis_to_scroll] // 2,
            description='Slice'
        ),
        axis=IntSlider(min=0, max=2, step=1, value=0, description='View Axis')
    );

## Building the pre-training process
* The technique is simple masking (occlusion) followed by reconstruction of the masked regions.

In [7]:
import datetime

now = datetime.datetime.now()

# Day-month-year stamp (no zero padding) used in the run name.
date_string = f"{now.day}-{now.month}-{now.year}"
print(date_string)

# Initialize the wandb run; its local files are written under SAVE_DIR.
run = wandb.init(
    # The wandb entity (team) the project is logged under.
    entity="faking_it",
    # The wandb project this run belongs to.
    project="Vesuvius",
    # Track every hyperparameter that affects training, including the loss
    # weighting, the batch size and the checkpoint being resumed from.
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "STU-Net",
        "dataset": "5 scrolls",
        "epochs": NUM_EPOCHS,
        "patch_size": PATCH_SIZE,
        "criterion": CRITERION,
        "mask_weight": MASK_WEIGHT,
        "batch_size": BATCH_SIZE,
        "lr_scheduler": LR_SCHEDULER,
        "save_dir": SAVE_DIR,
        "checkpoint_path": CHECKPOINT_PATH,
        "resume": RESUME,
    },
    name=f"Pre_training_with_5_scrolls_{date_string}",
    dir=SAVE_DIR
)

# Output folders inside the wandb run directory.
MODEL_SAVE_PATH = join(run.dir, "model")
makedirs(MODEL_SAVE_PATH, exist_ok=True)

PREDS_PATH = join(run.dir, "preds")
makedirs(PREDS_PATH, exist_ok=True)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


19-12-2025


[34m[1mwandb[0m: Currently logged in as: [33mshadowtwin[0m ([33mfaking_it[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
def saving_logic(best_loss, avg_loss, epoch, optimizer, model):
    """Write checkpoints after an epoch and return the updated best loss.

    Two kinds of checkpoint are written into ``MODEL_SAVE_PATH``:

    * ``model_best.pth`` whenever ``avg_loss`` is lower than ``best_loss``;
    * ``model_epoch_{epoch}.pth`` every 10 epochs, regardless of the loss.

    Both store the epoch, model and optimizer state, the epoch's ``val_loss``
    and the running ``best_val_loss``. The last one lets a resumed run keep
    tracking the true best instead of restarting from the resumed epoch's loss.

    Args:
        best_loss: best validation loss seen so far.
        avg_loss: validation loss of the current epoch.
        epoch: current epoch index.
        optimizer: optimizer whose state is saved.
        model: model whose weights are saved.

    Returns:
        ``min(best_loss, avg_loss)`` (``best_loss`` unchanged when ``avg_loss`` is not lower).
    """
    improved = avg_loss < best_loss
    if improved:
        best_loss = avg_loss

    def _save(file_name):
        # Shared checkpoint layout for the "best" and periodic checkpoints.
        save_path = join(MODEL_SAVE_PATH, file_name)
        torch.save({
                'epoch': epoch,
                'model_weights': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_loss,
                'best_val_loss': best_loss,
            }, save_path)
        print(f"Saved checkpoint: {save_path}")

    if improved:
        _save("model_best.pth")

    # Periodic checkpoint
    if epoch % 10 == 0:
        _save(f"model_epoch_{epoch}.pth")
    return best_loss


In [9]:
def epoch_train(model, train_loader, train_criterion, optimizer, epoch, scaler):
    """Run one mixed-precision training epoch of masked-volume reconstruction.

    For every batch the model receives the corrupted patch and is optimised to
    reproduce the clean one. ``train_criterion`` also receives the dropout mask
    so it can weight the removed region. The last batch's prediction, inputs
    and mask are written to ``PREDS_PATH`` as NIfTI volumes.

    Returns:
        ``(model, optimizer, mean per-batch loss of the epoch)``.
    """
    model.train()
    running_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")
    for batch in progress:
        clean_patch, deform_patch, dropout_mask = (
            batch[key].to(DEVICE)
            for key in ('clean_patch', 'deform_patch', 'dropout_mask')
        )

        optimizer.zero_grad()
        # Forward pass and loss under autocast. The GradScaler below scales
        # the loss for the backward pass, as required for fp16 training.
        with autocast(device_type=DEVICE):
            prediction = model(deform_patch)
            train_loss = train_criterion(prediction, clean_patch, dropout_mask)

        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step_loss = train_loss.item()
        running_loss += step_loss
        progress.set_postfix({"Loss": step_loss})

    # Keep the final batch on disk for visual inspection.
    save_vol(prediction, join(PREDS_PATH, f"{epoch}_pred_train.nii.gz"))
    save_vol(deform_patch, join(PREDS_PATH, f"{epoch}_deform_train.nii.gz"))
    save_vol(clean_patch, join(PREDS_PATH, f"{epoch}_clean_train.nii.gz"))
    save_vol(dropout_mask, join(PREDS_PATH, f"{epoch}_mask_train.nii.gz"))

    train_avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch} Finished. Avg Loss: {train_avg_loss:.6f}")
    return model, optimizer, train_avg_loss

def val(model, val_loader, val_criterion, epoch):
    """Compute the mean validation loss for one epoch.

    Each batch provides the clean ``image``, the corrupted ``deform_patch`` and
    a ``tracking_mask``. ``1 - tracking_mask`` is passed to the criterion as the
    dropout mask. The forward pass runs in full precision (no autocast) with
    gradients disabled. The last batch's prediction, input, target and mask are
    written to ``PREDS_PATH`` as NIfTI volumes.

    Returns:
        Mean of the per-batch validation losses.
    """
    model.eval()
    loss_total = 0
    progress = tqdm(val_loader, desc=f"Val epoch {epoch}/{NUM_EPOCHS}")
    for batch in progress:
        clean_patch = batch['image'].to(DEVICE)
        deform_patch = batch['deform_patch'].to(DEVICE)
        dropout_mask = (1 - batch['tracking_mask']).to(DEVICE)

        with torch.no_grad():
            # The model reconstructs the clean volume from the corrupted input.
            prediction = model(deform_patch)
            val_loss = val_criterion(prediction, clean_patch, dropout_mask)

        batch_loss = val_loss.item()
        loss_total += batch_loss
        progress.set_postfix({"Loss": batch_loss})

    # Keep the final batch on disk for visual inspection.
    save_vol(prediction, join(PREDS_PATH, f"{epoch}_pred.nii.gz"))
    save_vol(deform_patch, join(PREDS_PATH, f"{epoch}_deform.nii.gz"))
    save_vol(clean_patch, join(PREDS_PATH, f"{epoch}_clean.nii.gz"))
    save_vol(dropout_mask, join(PREDS_PATH, f"{epoch}_mask.nii.gz"))

    val_avg_loss = loss_total / len(val_loader)
    print(f"Epoch {epoch} with validation avg Loss: {val_avg_loss:.6f}")
    return val_avg_loss

In [10]:
def train_loop(model, optimizer, train_loader, val_loader, train_criterion, val_criterion, scheduler, scaler, epoch, val_loss):
    """Alternate training and validation from ``epoch`` up to and including ``NUM_EPOCHS``.

    After each epoch the losses are logged to the wandb ``run``, checkpoints
    are handled by ``saving_logic`` (seeded with ``val_loss`` as the best loss
    so far), and ``scheduler`` is stepped once.

    NOTE(review): the range is inclusive of NUM_EPOCHS, i.e. a fresh run
    performs NUM_EPOCHS + 1 epochs.
    """
    best_val_loss = val_loss
    first_epoch = epoch

    for epoch in range(first_epoch, NUM_EPOCHS + 1):
        model, optimizer, train_avg_loss = epoch_train(
            model=model,
            train_loader=train_loader,
            train_criterion=train_criterion,
            optimizer=optimizer,
            epoch=epoch,
            scaler=scaler,
        )

        val_avg_loss = val(
            model=model,
            val_loader=val_loader,
            val_criterion=val_criterion,
            epoch=epoch,
        )

        epoch_metrics = {
            "epoch": epoch,
            "train_avg_loss": train_avg_loss,
            "val_avg_loss": val_avg_loss,
        }
        run.log(epoch_metrics)

        best_val_loss = saving_logic(
            best_loss=best_val_loss,
            avg_loss=val_avg_loss,
            epoch=epoch,
            optimizer=optimizer,
            model=model,
        )

        # One scheduler step per epoch.
        scheduler.step()

In [11]:
def get_zarr_dataloader():
    """Build the training DataLoader over the zarr volumes in ``DATA_PATH``.

    Patches are scaled from [0, 255] to [0, 1] and then corrupted with 10
    coarse-dropout holes of 32x48x48 voxels, which removes up to ~35% of a
    128^3 patch (less when holes overlap). ``tracking_mask`` is listed in the
    dropout keys so the removed region can be recovered from it.

    Returns:
        A MONAI DataLoader with ``BATCH_SIZE`` samples per batch and ``NUM_WORKERS`` workers.
    """
    # Loading is done lazily by ZarrVolumeDataset, so only normalise here.
    normalise_steps = [
        ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
        EnsureTyped(keys=["image"]),
    ]
    # Heavy corruption forces the model to repair large defects (no noise is added).
    corrupt_steps = [
        RandCoarseDropoutd(
            keys=["image", "tracking_mask"],
            holes=10,
            spatial_size=(32, 48, 48),
            fill_value=0,
            prob=1.0,  # always apply
        ),
        EnsureTyped(keys=["image", "tracking_mask"]),
    ]
    transform_input = Compose(normalise_steps)
    transform_deform = Compose(corrupt_steps)

    print("Initializing Dataset...")
    dataset = ZarrVolumeDataset(
        zarr_path=DATA_PATH,
        transform_input=transform_input,
        transform_deform=transform_deform,
        patch_size=PATCH_SIZE,
    )

    print("Initializing DataLoader...")
    return monai.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)


In [12]:

import numpy as np
import nibabel as nib

# Write the all-ones "tracking mask" that get_nii_dataloader loads for every
# validation sample. RandCoarseDropoutd zeroes it inside the holes, so
# 1 - mask marks the dropped voxels. Its shape must match the validation
# patches, which are padded/cropped to PATCH_SIZE, and its path must match
# the one get_nii_dataloader reads (VAL_DATA_PATH/tracking_mask.nii.gz).
volume = np.ones(PATCH_SIZE, dtype=np.float32)

# Identity voxel-to-world transform (no spatial metadata is needed).
affine = np.eye(4)

nib.save(nib.Nifti1Image(volume, affine), join(VAL_DATA_PATH, "tracking_mask.nii.gz"))


In [13]:
def get_nii_dataloader():
    """Build the validation DataLoader from the challenge NIfTI images.

    Data preparation:

    * The first 100 files of ``VAL_DATA_PATH/train_images_nii`` are used,
      sorted by name so the subset is stable across runs.
    * Each image is scaled from [0, 255] to [0, 1] and padded/cropped to
      ``PATCH_SIZE``.
    * A copy (``deform_patch``) is corrupted with the same coarse dropout used
      for training. The shared ``tracking_mask`` volume is listed in the same
      dropout keys, so ``1 - tracking_mask`` marks the dropped voxels.

    ``CacheDataset`` only caches the transforms that precede the first random
    one, so the dropout holes would otherwise be re-drawn every epoch. To keep
    validation losses comparable across epochs (and best-checkpoint selection
    meaningful), the dropout is applied exactly once per sample with a fixed
    seed and the resulting tensors are kept in memory.

    Returns:
        A MONAI DataLoader with batch size 1 yielding ``image``,
        ``deform_patch`` and ``tracking_mask``.
    """
    max_items = 100
    val_seed = 0

    train_images_nii = join(VAL_DATA_PATH, 'train_images_nii')
    tracking_mask_path = join(VAL_DATA_PATH, 'tracking_mask.nii.gz')
    data_list = [
        {"image": join(train_images_nii, file_name), "tracking_mask": tracking_mask_path}
        for file_name in sorted(listdir(train_images_nii))[:max_items]
    ]

    transforms = Compose(
        [
            # Load image
            LoadImaged(keys=["image", 'tracking_mask']),
            EnsureChannelFirstd(keys=["image", 'tracking_mask']),
            # Normalize uint8 input
            ScaleIntensityRanged(keys=["image"], a_min=0, a_max=255, b_min=0, b_max=1, clip=True),
            ResizeWithPadOrCropd(keys=["image"], spatial_size=PATCH_SIZE),
            # Keep "image" clean: the corruption is applied to a copy.
            CopyItemsd(keys=["image"], times=1, names=["deform_patch"]),
            # 10 holes of 32x48x48 voxels: up to ~35% of a 128^3 patch
            # (less when holes overlap).
            RandCoarseDropoutd(
                keys=["deform_patch", 'tracking_mask'],
                holes=10,
                spatial_size=(32, 48, 48),
                fill_value=0,
                prob=1.0  # Always apply
            ),
            EnsureTyped(keys=["image", "deform_patch", 'tracking_mask'], track_meta=False)
        ]
    )

    print("Initializing Dataset...")
    val_ds = CacheDataset(
        data=data_list,
        transform=transforms,
        cache_rate=1.0,
        num_workers=NUM_WORKERS,
        progress=True,
        # Each cached item is transformed exactly once below, so no defensive copy is needed.
        copy_cache=False,
    )

    # Draw the dropout holes once, deterministically, and freeze the results.
    val_ds.transform.set_random_state(seed=val_seed)
    val_items = [val_ds[i] for i in range(len(val_ds))]

    print("Initializing DataLoader...")
    # The samples are already materialised in memory, so worker processes
    # would only add overhead.
    val_loader = monai.data.DataLoader(val_items, batch_size=1, num_workers=0)
    return val_loader


# Test every step before continuing

In [14]:
if DEBUG:
    # Smoke test: the pre-trained STU-Net weights must load strictly into
    # STUNetReconstruction, and the model must move to the GPU.
    try:
        model = STUNetReconstruction()
        state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.train().cuda()
    except Exception as e:
        print(f"Error with model loading: {e}")


In [15]:
if DEBUG:
    # Smoke test: pull one batch from the zarr training loader and describe
    # its clean and corrupted tensors.
    try:
        train_loader = get_zarr_dataloader()
        temp_zarr_iterator = iter(train_loader)
        one_batch = next(temp_zarr_iterator)
        clean_patch, deform_patch = one_batch['clean_patch'], one_batch['deform_patch']
        for label, tensor in (("clean_patch", clean_patch), ("deform_patch", deform_patch)):
            describe_tensor(label, tensor)
    except Exception as e:
        print(f"Error with zarr data loading: {e}")

In [16]:
if DEBUG:
    # Smoke test: pull one batch from the NIfTI validation loader and describe
    # the clean image, then the corrupted copy.
    try:
        val_loader = get_nii_dataloader()
        temp_nii_iterator = iter(val_loader)
        one_batch = next(temp_nii_iterator)

        clean_patch = one_batch['image']
        describe_tensor("clean_patch", clean_patch)

        deform_patch = one_batch['deform_patch']
        describe_tensor("deform_patch", deform_patch)
    except Exception as e:
        print(f"Error with nii data loading: {e}")

In [17]:
if DEBUG:
    # Smoke test: write a checkpoint dict (epoch, model/optimizer state and a
    # dummy val_loss) into MODEL_SAVE_PATH. Relies on `model` from the
    # model-loading cell.
    epoch = 2
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    avg_loss = 2.0
    save_path = join(MODEL_SAVE_PATH, f"model_epoch_{epoch}.pth")

    torch.save(
        {
            'epoch': epoch,
            'model_weights': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_loss,
        },
        save_path,
    )
    print(f"Saved checkpoint: {save_path}")

    

In [18]:
class CriterionL1(nn.Module):
    """L1 reconstruction loss that emphasises the dropped-out region.

        loss = mean_L1(masked voxels) + mean_L1(all voxels) / mask_weight

    The first term averages the absolute error only over voxels where
    ``dropout_mask`` is 1 (the holes the model must in-paint). The second term
    is the plain L1 over the whole volume, down-weighted by ``mask_weight`` so
    it acts as a stabiliser. With ``mask_weight=10`` the masked term carries
    10x the weight of the global term.
    """

    def __init__(self, mask_weight=10.0):
        """
        Args:
            mask_weight: ratio between the weight of the masked term and that
                of the global term. Must be positive.

        Raises:
            ValueError: if ``mask_weight`` is not positive (it is used as a divisor).
        """
        super().__init__()
        if mask_weight <= 0:
            raise ValueError(f"mask_weight must be positive, got {mask_weight}")
        self.l1_loss = nn.L1Loss()
        self.mask_weight = mask_weight

    def forward(self, pred, target, dropout_mask):
        """Return the combined masked + global L1 loss.

        Args:
            pred: reconstruction produced by the model.
            target: clean volume, broadcast-compatible with ``pred``.
            dropout_mask: 1 where voxels were removed, 0 where they were kept.
        """
        # Absolute error per voxel.
        l1_diff = torch.abs(pred - target)

        # Masked loss (the "hard" task): sum of errors inside the holes divided
        # by the number of hole voxels. The epsilon keeps an all-zero mask from
        # dividing by zero (the term is then 0).
        loss_masked = (l1_diff * dropout_mask).sum() / (dropout_mask.sum() + 1e-8)

        # Global loss (the "stabiliser"): mean over the entire volume.
        loss_global = self.l1_loss(pred, target)

        return loss_masked + loss_global / self.mask_weight

In [None]:
def __main__():
    """Entry point: build (or resume) the model, data, losses and LR schedule, then train.

    Fresh run (``RESUME is None``): the weights come from the STU-Net
    checkpoint at ``CHECKPOINT_PATH``, and training starts at epoch 0 with the
    best validation loss set to infinity.

    Resumed run: model and optimizer state come from the ``RESUME`` checkpoint,
    and training continues at the epoch after the saved one. The pre-trained
    weights are not loaded, because they would be overwritten anyway. Two
    pieces of state are also restored:

    * The cosine LR schedule is positioned at the resumed epoch. A scheduler
      created from scratch would otherwise anneal from the loaded LR as if it
      were epoch 0.
    * The best validation loss so far is taken from ``best_val_loss`` when the
      checkpoint records it, falling back to the checkpoint's ``val_loss``.
    """
    import math

    eta_min = 0.0
    model = STUNetReconstruction()

    if RESUME is None:
        # Fresh run: start from the pre-trained STU-Net weights.
        state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model = model.to(DEVICE)

        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        epoch = 0
        val_loss = float("inf")
        scheduler = CosineAnnealingLR(optimizer, NUM_EPOCHS, eta_min=eta_min, last_epoch=-1)
    else:
        checkpoint = torch.load(RESUME, map_location="cpu", weights_only=False)

        # Continue with the next epoch instead of repeating the saved one.
        epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_weights'], strict=True)
        val_loss = checkpoint.get('best_val_loss', checkpoint['val_loss'])

        model = model.to(DEVICE)

        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Constructing a scheduler with last_epoch >= 0 requires 'initial_lr'
        # in every param group. It is normally restored with the optimizer
        # state; default it to LEARNING_RATE just in case.
        for group in optimizer.param_groups:
            group.setdefault('initial_lr', LEARNING_RATE)
        scheduler = CosineAnnealingLR(optimizer, NUM_EPOCHS, eta_min=eta_min, last_epoch=epoch - 1)

        # Put the LR exactly on the closed-form cosine curve for `epoch`; the
        # scheduler's per-epoch recursive update then keeps following it.
        for group in optimizer.param_groups:
            group['lr'] = eta_min + (group['initial_lr'] - eta_min) * (
                1 + math.cos(math.pi * epoch / NUM_EPOCHS)
            ) / 2

        print("Model loaded correctly. Resuming training...")

    # Data loaders
    train_loader = get_zarr_dataloader()
    val_loader = get_nii_dataloader()

    # L1-based losses (sharper reconstructions); the same metric is used for validation.
    train_criterion = CriterionL1(mask_weight=MASK_WEIGHT)
    val_criterion = CriterionL1(mask_weight=MASK_WEIGHT)

    # Mixed-precision loss scaling
    scaler = GradScaler()

    train_loop(
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        train_criterion=train_criterion,
        val_criterion=val_criterion,
        scheduler=scheduler,
        scaler=scaler,
        epoch=epoch,
        val_loss=val_loss
    )


if __name__ == '__main__':
    __main__()

Model loaded correctly. Resuming training...
Initializing Dataset...
Loading data from -> /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training
Initializing DataLoader...
Initializing Dataset...


Loading dataset: 100%|██████████| 100/100 [00:32<00:00,  3.05it/s]


Initializing DataLoader...


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 51/1000: 100%|██████████| 500/500 [24:32<00:00,  2.94s/it, Loss=0.0259]


Epoch 51 Finished. Avg Loss: 0.091144


Val epoch 51/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0591]


Epoch 51 with validation avg Loss: 0.106297


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 52/1000: 100%|██████████| 500/500 [24:30<00:00,  2.94s/it, Loss=0.129]  


Epoch 52 Finished. Avg Loss: 0.084838


Val epoch 52/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0498]


Epoch 52 with validation avg Loss: 0.101735


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 53/1000:  65%|██████▌   | 326/500 [16:00<08:30,  2.93s/it, Loss=0.0703]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 87/1000: 100%|██████████| 500/500 [24:33<00:00,  2.95s/it, Loss=0.0889] 


Epoch 87 Finished. Avg Loss: 0.078391


Val epoch 87/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0508]


Epoch 87 with validation avg Loss: 0.096554


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 88/1000: 100%|██████████| 500/500 [24:31<00:00,  2.94s/it, Loss=0.072]  


Epoch 88 Finished. Avg Loss: 0.078413


Val epoch 88/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0653]


Epoch 88 with validation avg Loss: 0.099044


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 89/1000: 100%|██████████| 500/500 [24:34<00:00,  2.95s/it, Loss=0.137]  


Epoch 89 Finished. Avg Loss: 0.084267


Epoch 90/1000: 100%|██████████| 500/500 [24:31<00:00,  2.94s/it, Loss=0.0911] ] 


Epoch 90 Finished. Avg Loss: 0.079921


Val epoch 90/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0923]


Epoch 90 with validation avg Loss: 0.095263
Saved checkpoint: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/checkpoints/5_scrolls_pretrain/wandb/run-20251219_094054-w862cvd3/files/model/model_epoch_90.pth


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 91/1000: 100%|██████████| 500/500 [24:31<00:00,  2.94s/it, Loss=0.0569] 


Epoch 91 Finished. Avg Loss: 0.083373


Val epoch 91/1000: 100%|██████████| 100/100 [00:32<00:00,  3.09it/s, Loss=0.0559]


Epoch 91 with validation avg Loss: 0.106192


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 94/1000: 100%|██████████| 500/500 [24:32<00:00,  2.95s/it, Loss=0.124]  


Epoch 94 Finished. Avg Loss: 0.081412


Val epoch 94/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0796]


Epoch 94 with validation avg Loss: 0.095691


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 95/1000: 100%|██████████| 500/500 [24:33<00:00,  2.95s/it, Loss=0.148]  


Epoch 95 Finished. Avg Loss: 0.081699


Val epoch 95/1000: 100%|██████████| 100/100 [00:32<00:00,  3.09it/s, Loss=0.0593]


Epoch 95 with validation avg Loss: 0.093750


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 96/1000: 100%|██████████| 500/500 [24:32<00:00,  2.95s/it, Loss=0.111]  


Epoch 96 Finished. Avg Loss: 0.079379


Val epoch 96/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0527]


Epoch 96 with validation avg Loss: 0.095356


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 97/1000: 100%|██████████| 500/500 [24:31<00:00,  2.94s/it, Loss=0.0359] 


Epoch 97 Finished. Avg Loss: 0.078887


Val epoch 97/1000:  77%|███████▋  | 77/100 [00:25<00:06,  3.35it/s, Loss=0.118]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 101/1000: 100%|██████████| 500/500 [24:32<00:00,  2.94s/it, Loss=0.116]  


Epoch 101 Finished. Avg Loss: 0.076591


Val epoch 101/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0495]


Epoch 101 with validation avg Loss: 0.099058


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 103/1000: 100%|██████████| 500/500 [24:32<00:00,  2.94s/it, Loss=0.0491] 


Epoch 103 Finished. Avg Loss: 0.079382


Val epoch 103/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0477]


Epoch 103 with validation avg Loss: 0.095568


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
Epoch 104/1000: 100%|██████████| 500/500 [24:32<00:00,  2.94s/it, Loss=0.0697] 


Epoch 104 Finished. Avg Loss: 0.081110


Val epoch 104/1000: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, Loss=0.0603]


Epoch 104 with validation avg Loss: 0.099062


  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
  super().__init__(**codec_config)
