In [None]:
import torch
import numpy as np
from pathlib import Path
import tifffile
import czifile
from tqdm.notebook import tqdm  # Notebook-friendly
from monai.inferers import sliding_window_inference
from tnia.deeplearning.dl_helper import quantile_normalization
import torch.nn.functional as F

# Compatibility shim: rebinds the `np.float_` attribute to float32 for the rest of the session.
# NOTE(review): presumably needed because an imported dependency still references `np.float_`,
# which NumPy 2.0 removed - confirm which package requires it.
# NOTE(review): the removed alias referred to float64, so float32 here changes the precision seen
# by any code that reads `np.float_` - confirm this is intentional.
np.float_ = np.float32

# --------------------------- CONFIGURATION --------------------------- #

# Project root; every other location below is derived from it.
BASE_PATH = Path(r'C:/Users/Alex/Desktop/Mailis')
DATA_DIR = BASE_PATH.joinpath("data")                       # Input volumes are globbed from here
MODEL_PATH = BASE_PATH.joinpath("models", "full_brain.pth")  # Checkpoint passed to torch.load
OUTPUT_DIR = BASE_PATH.joinpath("Brain_volume")             # Destination of the exported stacks

# Sliding-window settings handed to sliding_window_inference (roi_size / sw_batch_size).
ROI_SIZE, BATCH_SIZE = 1024, 1

# Export behaviour: "classification" writes one binary mask per selected class,
# "probability" writes each selected class probability scaled to 0-255.
EXPORT_MODE = "classification"
# Indices of the model output classes to export (0 = BG, etc.)
EXPORT_CLASSES = [0, 1, 2]

# Glob patterns used to collect input files from DATA_DIR.
SUPPORTED_EXTENSIONS = [
    "*.czi",
    "*.tif",
    "*.tiff",
    "*.ome.tif",
    "*.ome.tiff",
    "*.btf",
    "*.ome.btf",
]

# --------------------------- INITIALIZATION --------------------------- #

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute code
# embedded in the file - only load checkpoints that come from a trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint that was saved
# on a CUDA machine can still be opened on a CPU-only one.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)

# The script only runs inference: switch layers such as dropout and batch normalisation
# to evaluation behaviour, since the pickled module keeps whatever mode it was saved in.
model.eval()

# Make sure the export folder exists before any file is processed.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --------------------------- IMAGE LOADING --------------------------- #

def load_image(path: Path) -> np.ndarray:
    """
    Read an image file with the reader matching its extension.

    Singleton axes are squeezed out of the loaded array, and the result must
    then have exactly three dimensions, which callers treat as (Z, Y, X).

    Raises:
        ValueError: if the extension is not handled, or if the squeezed array
            is not three-dimensional.
    """
    # Path.suffix only yields the final component, so compound names such as
    # "stack.ome.tif" resolve to ".tif" / ".tiff" / ".btf" here.
    tiff_suffixes = frozenset({".tif", ".tiff", ".btf"})
    suffix = path.suffix.lower()

    if suffix == ".czi":
        raw = czifile.imread(path)
    elif suffix in tiff_suffixes:
        raw = tifffile.imread(path)
    else:
        raise ValueError(f"Unsupported file format: {suffix}")

    volume = np.squeeze(raw)
    if volume.ndim == 3:
        return volume
    raise ValueError(f"Expected 3D image (Z, Y, X), got shape {volume.shape}")

# --------------------------- PREDICTION --------------------------- #

def predict(image_2d: np.ndarray, model: torch.nn.Module) -> np.ndarray:
    """
    Run sliding-window inference on a single 2D slice.

    The slice is passed through quantile_normalization, cast to float32 and
    given leading batch and channel axes before being sent to the model.
    A softmax over the channel axis turns the model output into per-class
    probabilities; the batch axis is dropped from the returned numpy array
    (expected shape (C, H, W) for a model emitting C class channels).
    """
    normalized = quantile_normalization(image_2d).astype(np.float32)

    # (H, W) -> (1, 1, H, W): batch and channel axes in front.
    batch = torch.from_numpy(normalized)[None, None].to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        raw_scores = sliding_window_inference(
            batch,
            roi_size=ROI_SIZE,
            sw_batch_size=BATCH_SIZE,
            predictor=model,
        )
        class_probs = torch.softmax(raw_scores, dim=1)

    without_batch = class_probs.squeeze(0)
    return without_batch.cpu().numpy()

# --------------------------- PROCESS A SINGLE FILE --------------------------- #

def process_file(image_path: Path, file_index: int, total_files: int):
    """
    Predict every Z slice of one volume and export the result as a (C, Z, Y, X) TIFF stack.

    Depending on EXPORT_MODE, each exported channel holds either a binary mask
    (1 where the arg-max class equals the selected class) or the class
    probability scaled to 0-255. One channel is written per entry of
    EXPORT_CLASSES, in that order.

    Args:
        image_path: Input volume; loaded through load_image.
        file_index: Zero-based position of this file, used only for the progress message.
        total_files: Total number of files, used only for the progress message.

    Raises:
        ValueError: if EXPORT_MODE is not "classification" or "probability", if an
            entry of EXPORT_CLASSES is outside the model's class range, or if
            load_image rejects the file.
    """
    # Fail fast on a mistyped mode: otherwise neither branch below would run and an
    # all-zero volume would be written under the "probability" name.
    if EXPORT_MODE not in ("classification", "probability"):
        raise ValueError(
            f"Invalid EXPORT_MODE {EXPORT_MODE!r}. Expected 'classification' or 'probability'."
        )

    print(f"\n[{file_index + 1}/{total_files}] Processing: {image_path.name}")
    image_stack = load_image(image_path)
    depth, height, width = image_stack.shape

    # Slice 0 is predicted up front so the class count can be validated before the
    # output volume is allocated; the result is reused in the loop below instead of
    # running inference on that slice a second time.
    first_probs = predict(image_stack[0], model)
    total_classes = first_probs.shape[0]

    for cls in EXPORT_CLASSES:
        if cls < 0 or cls >= total_classes:
            raise ValueError(f"Invalid class index {cls}. Model output has {total_classes} classes.")

    save_path = OUTPUT_DIR / f"{image_path.stem}_{EXPORT_MODE}.ome.tif"

    output_array = np.zeros((len(EXPORT_CLASSES), depth, height, width), dtype=np.uint8)

    for z in tqdm(range(depth), desc=f"Predicting slices ({depth})", leave=False):
        # Shape: (C, H, W)
        prob_map = first_probs if z == 0 else predict(image_stack[z], model)

        if EXPORT_MODE == "probability":
            for i, cls in enumerate(EXPORT_CLASSES):
                # Probabilities in [0, 1] are scaled to the uint8 range (truncating).
                output_array[i, z] = (prob_map[cls] * 255).astype(np.uint8)
        else:
            # Winner-takes-all label per pixel, then one binary mask per exported class.
            class_map = np.argmax(prob_map, axis=0)
            for i, cls in enumerate(EXPORT_CLASSES):
                output_array[i, z] = (class_map == cls).astype(np.uint8)

    # Write the full stack in one call; 'minisblack' keeps the channels as
    # separate grayscale planes and the axes metadata labels them as CZYX.
    tifffile.imwrite(
        save_path,
        output_array,
        photometric='minisblack',
        metadata={'axes': 'CZYX'},
        bigtiff=True
    )

    print(f"Saved output to: {save_path}")

# --------------------------- MAIN SCRIPT --------------------------- #

def main():
    """
    Main entry point: collect every supported image file in DATA_DIR and process them in sorted order.

    Prints a message and returns without doing anything else when no input
    file is found.
    """
    # The glob patterns overlap: a compound name such as "stack.ome.tif" matches
    # both "*.tif" and "*.ome.tif". Collecting into a set keeps each path once,
    # so no file is predicted and written twice.
    input_files = sorted({
        path
        for pattern in SUPPORTED_EXTENSIONS
        for path in DATA_DIR.glob(pattern)
    })

    if not input_files:
        print("No input image files found in the specified folder.")
        return

    print(f"{len(input_files)} image file(s) detected.")
    for idx, image_file in enumerate(tqdm(input_files, desc="Processing files")):
        process_file(image_file, idx, len(input_files))

if __name__ == "__main__":
    main()
