# We will take the pre-trained STU-Net (Large -> Pre-trained on the TotalSegmentator cases) and further pre-train on the scrolls and fragments 

## Convert tif into zarr

In [None]:
import zarr
import tifffile
import numpy as np
from tqdm import tqdm
import os
import glob

def convert_tiff_stack_to_zarr(tiff_folder, output_zarr_path, chunk_size=(128, 128, 128)):
    # 1. Get file list (assuming sorted filenames correspond to Z-slices)
    tiff_files = sorted(glob.glob(os.path.join(tiff_folder, "*.tif")))
    
    # 2. Read first slice to get dimensions and dtype
    sample = tifffile.imread(tiff_files[0])
    z, y, x = len(tiff_files), sample.shape[0], sample.shape[1]
    dtype = sample.dtype
    
    print(f"Volume shape: ({z}, {y}, {x}), Dtype: {dtype}")

    # 3. Create Zarr v3 store on disk
    #    Using mode="w" to overwrite
    root = zarr.open(output_zarr_path, mode="w")

    # 4. Create array (v3 uses create_array instead of create_dataset)
    dset = root.create_array(
        name="volume",
        shape=(z, y, x),
        chunks=chunk_size,
        dtype=dtype
    )

    # 5. Write slices
    print("Converting...")
    for i, fname in tqdm(enumerate(tiff_files), total=z):
        img_data = tifffile.imread(fname)
        dset[i, :, :] = img_data

    print(f"Conversion complete! Saved to {output_zarr_path}")

# Example usage:
#convert_tiff_stack_to_zarr(
#    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training/example",
#    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training/data.zarr"
#)


In [None]:
import tifffile
import numpy as np
import glob
import os

# Update this path to where your TIFs are
TIFF_FOLDER = "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/notebooks/temp_raw_downloads"

# 1. Get the first file
files = sorted(glob.glob(os.path.join(TIFF_FOLDER, "*.tif")))

for file_name in files:
    # 2. Read the image
    img = tifffile.imread(file_name)
    
    if img.dtype!="uint16":
        # 3. Print stats
        print(f"--- File Analysis ---")
        print(f"Filename: {os.path.basename(file_name)}")
        print(f"Data Type (dtype): {img.dtype}")
        print(f"Shape: {img.shape}")
        print(f"Min Value: {np.min(img)}")
        print(f"Max Value: {np.max(img)}")
    
    # 4. Check for potential overflow
    if np.max(img) > 65535:
        print("\n⚠️ WARNING: Max value exceeds 65535.")
        print("You MUST normalize (divide) this data before converting to float16.")

## Building the data loader


In [4]:
from torch.utils.data import Dataset
import zarr
import numpy as np
import monai
from monai.transforms import Compose, ScaleIntensity
import torch
import torch.nn as nn
import monai
from monai.networks.nets import UNet
from monai.transforms import (
    Compose, ScaleIntensity, EnsureType, 
    RandCoarseDropout, RandGaussianNoise
)

class ZarrVolumeDataset(Dataset):
    def __init__(self, zarr_path, transform_input, transform_deform, transform_output, patch_size=(128, 128, 6), threshold=0.0):
        self.zarr_path = zarr_path
        self.transform_input = transform_input
        self.transform_deform = transform_deform
        self.transform_output = transform_output
        self.patch_size = patch_size
        self.threshold = threshold  # Value below which we consider the pixel "background"
        print(f"zarr_path: {zarr_path}")
        # --- Robust Zarr Loading ---
        try:
            root = zarr.open(zarr_path, mode='r')
            print("Try")
        except:
            store = zarr.storage.LocalStore(zarr_path, mode='r')
            root = zarr.open(store, mode='r')
            
        if hasattr(root, 'shape'):
            self.vol = root
            self.shape = root.shape
        elif 'volume' in root:
            print('volume in root')
            self.vol = root['volume']
            self.shape = root['volume'].shape
        elif '0' in root:
            self.vol = root['0']
            self.shape = root['0'].shape
        else:
            raise ValueError("Could not find volume data.")

    def __len__(self):
        return 1000 

    def __getitem__(self, index):
        # We need to re-open the store inside the worker to be safe with multiprocessing
        # However, for pure read, we can often rely on the self.vol reference if using threads.
        # But to be 100% safe against "Pickling" errors, we re-reference:
        vol = self.vol 

        z_max = max(0, self.shape[0] - self.patch_size[0])
        y_max = max(0, self.shape[1] - self.patch_size[1])
        x_max = max(0, self.shape[2] - self.patch_size[2])

        # --- THE REJECTION SAMPLING LOOP ---
        # Try up to 20 times to find a non-empty chunk
        for attempt in range(100):
            # 1. Random Coordinates
            z_start = np.random.randint(0, z_max) if z_max > 0 else 0
            y_start = np.random.randint(0, y_max) if y_max > 0 else 0
            x_start = np.random.randint(0, x_max) if x_max > 0 else 0

            # 2. Load the chunk
            patch = vol[
                z_start : z_start + self.patch_size[0],
                y_start : y_start + self.patch_size[1],
                x_start : x_start + self.patch_size[2]
            ]

            # 3. Check if it contains data
            # If the max value in this patch is greater than our threshold (0), it's valid.
            if np.max(patch) > self.threshold:
                # Found valid data! Break the loop and process it.
                break
            
            # If we are here, the patch was empty. The loop continues to the next attempt.
        
        # Note: If the loop finishes 20 times and finds nothing, it will return the LAST empty patch.
        # This prevents the code from hanging forever if the file is truly empty.

        # 4. MONAI Formatting
        patch = patch.astype(np.float32) # Ensure float for transforms
        patch = patch[np.newaxis, ...]   # Add Channel dim -> (1, Z, Y, X)
        
        #if self.transform:
        patch = self.transform_input(patch)
        clean_patch = patch.clone()
        deform_patch = self.transform_deform(patch)
        #patch = self.transform_output(patch)

        return {
            'clean_patch':clean_patch,
            'deform_patch':deform_patch
        }

# --- Setup ---
transform_input = Compose([
    ScaleIntensity(),
])

# 1. The Dataset (Your existing Zarr code)
# ds = ZarrVolumeDataset(...) 
# loader = DataLoader(ds, batch_size=4, ...)

# 2. The Corruption Transforms (Applied ONLY to the Input)
# We want to force the model to fix heavy defects.
transform_deform = Compose([
    # Cut out 1-10 holes, each spatial size roughly 16x16x2
    RandCoarseDropout(
        holes=10, 
        spatial_size=(2, 16, 16), # TODO change to 16
        fill_value=0, # 
        prob=1.0 # Always apply
    ),
    # Add noise
    #RandGaussianNoise(prob=0.5, mean=0.0, std=0.1),
    EnsureType()
])

# --- Verification ---
# MONAI Pipeline Setup
train_transforms = Compose([
    ScaleIntensity(),
])

print("Initializing Dataset...")
ds = ZarrVolumeDataset(
    "/mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training/PHercParis4.zarr", 
    transform_input=transform_input,
    transform_deform=transform_deform,
    transform_output=None,
    patch_size=(6, 128, 128) 
)

print("Initializing DataLoader...")
loader = monai.data.DataLoader(ds, batch_size=1, num_workers=4)
data_loader = iter(loader)
first_batch = next(data_loader)
clean_patch = first_batch['clean_patch']
deform_patch = first_batch['deform_patch']
print("Fetching a batch to ensure it's not empty...")

print(f"Batch Max Value: {clean_patch.max()}")
if clean_patch.max() == 0:
    print("WARNING: The batch is still empty. Your threshold might be too high or the volume is empty.")
else:
    print("Success! Loaded a non-empty chunk.")


print(f"Success! Batch shape: {clean_patch.shape}")

Initializing Dataset...
zarr_path: /mounts/disk4_tiago_e_andre/vesuvius/Vesuvius/DataSet/Pre-training/PHercParis4.zarr
Try
volume in root
Initializing DataLoader...
Fetching a batch to ensure it's not empty...
Batch Max Value: 1.0
Success! Loaded a non-empty chunk.
Success! Batch shape: torch.Size([1, 1, 6, 128, 128])


In [3]:
first_batch = next(data_loader)
clean_patch = first_batch['clean_patch']
deform_patch = first_batch['deform_patch']

In [4]:
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
import numpy as np
import time



# 1. Extract the raw 3D volume from the batch
# MONAI batch shape is (Batch_Size, Channel, Dim1, Dim2, Dim3)
# We select Batch 0 and Channel 0
input_data = deform_patch[0, 0].cpu().numpy()

print(f"Batch Shape: {clean_patch.shape}")
print(f"Visualizing Sample Shape: {clean_patch.shape}")

# 2. Setup Interactive Viewer
def view_batch_slice(slice_idx, axis):
    plt.figure(figsize=(8, 8))
    
    # Allow slicing along different axes to debug orientation
    if axis == 0:
        # Slicing the first dimension (usually Z if (Z, Y, X))
        plt.imshow(input_data[slice_idx, :, :], cmap='gray')
        plt.xlabel("Axis 2")
        plt.ylabel("Axis 1")
    elif axis == 1:
        # Slicing the second dimension
        plt.imshow(input_data[:, slice_idx, :], cmap='gray')
        plt.xlabel("Axis 2")
        plt.ylabel("Axis 0")
    else:
        # Slicing the third dimension
        plt.imshow(input_data[:, :, slice_idx], cmap='gray')
        plt.xlabel("Axis 1")
        plt.ylabel("Axis 0")
        
    plt.title(f"Slice {slice_idx} along Axis {axis}")
    plt.colorbar()
    plt.show()

# 3. Create Slider
# We default to Axis 0, but you can change the axis variable below to 1 or 2
axis_to_scroll = 0 

interact(
    view_batch_slice, 
    slice_idx=IntSlider(
        min=0, 
        max=input_data.shape[axis_to_scroll]-1, 
        step=1, 
        value=input_data.shape[axis_to_scroll]//2,
        description='Slice'
    ),
    axis=IntSlider(min=0, max=2, step=1, value=0, description='View Axis')
);

Batch Shape: torch.Size([1, 1, 6, 128, 128])
Visualizing Sample Shape: torch.Size([1, 1, 6, 128, 128])


interactive(children=(IntSlider(value=3, description='Slice', max=5), IntSlider(value=0, description='View Axi…

## Building the pre-training process
* The technique will be simple masking (oclusion) and respective reconstruction.