In [None]:
import os
import numpy as np
import xml.etree.ElementTree as ET
from czifile import imread as read_czi
from tifffile import imread, imwrite, TiffFile
from tqdm import tqdm
import warnings
from collections import defaultdict

# =======================
# CONFIG
# =======================
CHANNEL_INDEX = 0        # <- index of the channel (0, 1, 2, ...) to export
CHANNEL_NAME  = None     # <- or a channel name ("GFP", "DAPI", ...); when it matches a name found in the file's metadata it is used instead of CHANNEL_INDEX
Z_STEP        = 25       # <- spacing between exported Z slices: indices 0, Z_STEP, 2*Z_STEP, ... (1 = every slice)
# Input file suffixes, matched case-insensitively with str.endswith on the file name
# (so '.ome.tif' is already covered by '.tif').
valid_exts    = ('.czi', '.tif', '.tiff', '.ome.tif', '.bigtif')

def zero_pad(i, width=3):
    """Format integer *i* as a decimal string zero-padded to at least *width* characters.

    Numbers that already need more than *width* digits are returned unpadded
    (e.g. ``zero_pad(1234) == "1234"``).
    """
    spec = "0" + str(width) + "d"
    return format(i, spec)

# Ignore every UserWarning for the rest of the session. Note this is
# process-wide: it also hides UserWarnings from any other code that runs here.
warnings.filterwarnings("ignore", category=UserWarning)

# The folder variables below must already be defined elsewhere (e.g. in an
# earlier notebook cell). If they are not, uncomment and adapt these paths:
# parent_path = r'C:\Users\Alex\Desktop\Mailis_lightsheet'
# data_path = os.path.join(parent_path, 'data')                # raw files (CZI/TIFF)
# extracted_path = os.path.join(parent_path, 'extracted_tiff') # extracted Z-slices
# cropped_path = os.path.join(parent_path, 'cropped_tiff')     # final cropped 2D images

# Create both output folders; existing folders are left untouched.
for _out_dir in (extracted_path, cropped_path):
    os.makedirs(_out_dir, exist_ok=True)

# =======================
# Metadata / channel helpers
# =======================
# Optional import: CziFile gives access to the XML metadata of .czi files.
# Only a failed import is tolerated; any other error should surface.
try:
    from czifile import CziFile
except ImportError:
    # Callers must check for None before using it.
    CziFile = None

# Canonical locations of the per-image channel list, tried in order:
# CZI documents keep it under Metadata/Information/Image/Dimensions/Channels,
# OME-XML under Image/Pixels (the first match is the first image series).
_CHANNEL_CONTAINERS = (
    ".//{*}Information/{*}Image/{*}Dimensions/{*}Channels",
    ".//{*}Image/{*}Pixels",
)


def _channel_names_from_xml(xml):
    """Return the channel names listed in an OME-XML or CZI metadata document.

    ``<Channel>`` elements can occur several times in one document (once per
    OME image series; in the Information, DisplaySetting and Experiment
    sections of a CZI). A document-wide search would therefore return
    duplicated names, so the canonical channel list is used when present and
    the document-wide search is only a fallback. Each channel contributes its
    ``Name`` attribute, else its ``Id``/``ID``; channels with none of these
    are skipped.

    Raises ``xml.etree.ElementTree.ParseError`` on malformed XML.
    """
    if isinstance(xml, bytes):
        xml = xml.decode('utf-8', errors='ignore')
    root = ET.fromstring(xml)

    channels = []
    for container_path in _CHANNEL_CONTAINERS:
        container = root.find(container_path)
        if container is not None:
            channels = container.findall("{*}Channel")
            if channels:
                break
    if not channels:
        # No canonical list found: fall back to every <Channel> in the document.
        channels = root.findall(".//{*}Channel")

    names = []
    for ch in channels:
        nm = ch.get('Name') or ch.get('Id') or ch.get('ID')
        if nm:
            names.append(nm)
    return names


def get_channel_names(path):
    """Return the channel names stored in a CZI or OME-TIFF file's metadata.

    CZI metadata is read with ``CziFile`` when it is available. Any other
    file, or a .czi file when ``CziFile`` is None, is opened with ``TiffFile``
    and its OME-XML is used. This is best effort: missing metadata or any
    read/parse error yields None, so the caller can fall back to index-based
    channel selection.

    Parameters
    ----------
    path : str
        Path of the image file.

    Returns
    -------
    list of str or None
        Channel names in document order, or None when none could be found.
    """
    try:
        if path.lower().endswith('.czi') and CziFile is not None:
            with CziFile(path) as czi:
                xml = czi.metadata()
        else:
            with TiffFile(path) as tf:
                xml = tf.ome_metadata  # None for plain (non-OME) TIFFs
        names = _channel_names_from_xml(xml) if xml else []
    except Exception as err:  # deliberate best effort: metadata is optional
        print(f"  [info] Noms de channels illisibles pour {os.path.basename(path)} : {err}")
        names = []
    return names or None

def choose_channel_index(names, fallback=0):
    """Resolve which channel index to export.

    The module-level ``CHANNEL_NAME`` is matched (ignoring case and
    surrounding whitespace) against *names*. When no name is configured, the
    file has no channel names, or nothing matches, the result is
    ``CHANNEL_INDEX``, or *fallback* when ``CHANNEL_INDEX`` is None.

    Parameters
    ----------
    names : list of str or None
        Channel names, e.g. as returned by ``get_channel_names``.
    fallback : int
        Index used when ``CHANNEL_INDEX`` is None.

    Returns
    -------
    int
        The channel index to use. It is not range-checked here.
    """
    # The index actually used whenever name matching does not apply; also
    # reported in the messages below (previously they could print "None").
    default_idx = CHANNEL_INDEX if CHANNEL_INDEX is not None else fallback

    if not CHANNEL_NAME:
        return default_idx

    if not names:
        # A name was requested but the file exposes no channel names.
        print(f"[info] CHANNEL_NAME='{CHANNEL_NAME}' : aucun nom de channel dans les métadonnées. "
              f"On utilise l'index {default_idx}.")
        return default_idx

    wanted = str(CHANNEL_NAME).strip().lower()
    for i, n in enumerate(names):
        if str(n).strip().lower() == wanted:
            return i

    print(f"[info] CHANNEL_NAME='{CHANNEL_NAME}' introuvable dans {names}. "
          f"On utilise l'index {default_idx}.")
    return default_idx

# =======================
# Step 1: extract single-channel 2D Z-slices from the CZI/TIFF files
# =======================
image_files = sorted(f for f in os.listdir(data_path) if f.lower().endswith(valid_exts))
print(f"\n Found {len(image_files)} image files to process.\n")

for file_idx, filename in enumerate(image_files, 1):
    print(f"[{file_idx}/{len(image_files)}] Extracting slices from: {filename}")
    in_path = os.path.join(data_path, filename)

    # Read the whole file into memory.
    if filename.lower().endswith('.czi'):
        vol = read_czi(in_path)
    else:
        vol = imread(in_path)

    vol = np.squeeze(vol)  # drop singleton axes

    # Select a single channel and bring the array to (Z, Y, X).
    # Common layouts: (C, Z, Y, X) or (Z, C, Y, X).
    if vol.ndim == 4:
        ch_names = get_channel_names(in_path)
        ch_idx = choose_channel_index(ch_names, fallback=0)

        # Heuristic: the channel axis is usually small (<= 8) compared to Z/Y/X.
        if vol.shape[0] <= 8 and vol.shape[1] > 8:
            c_axis, layout = 0, "(C,Z,Y,X)"
        elif vol.shape[1] <= 8 and vol.shape[0] > 8:
            c_axis, layout = 1, "(Z,C,Y,X)"
        else:
            c_axis, layout = None, None

        if c_axis is None:
            # Non-standard layout: merge the two leading axes into Z
            # (channels end up interleaved; no precise channel selection).
            print("  [warn] Format 4D non standard, fallback flatten. Sélection précise de channel indisponible.")
            vol = vol.reshape((vol.shape[0] * vol.shape[1], vol.shape[-2], vol.shape[-1]))
        else:
            n_ch = vol.shape[c_axis]
            if ch_names and -len(ch_names) <= ch_idx < len(ch_names):
                msg = ch_names[ch_idx]
            else:
                msg = f"index {ch_idx}"
            print(f"  -> Interprétation {layout}, channels={n_ch}, choix={msg}")
            # Fail with an explicit message instead of a bare IndexError.
            if not -n_ch <= ch_idx < n_ch:
                raise ValueError(
                    f"Channel index {ch_idx} out of range for {n_ch} channels in {filename}."
                )
            vol = vol[ch_idx] if c_axis == 0 else vol[:, ch_idx]  # -> (Z, Y, X)

    # A single 2D (Y, X) image becomes a one-slice Z stack.
    if vol.ndim == 2:
        vol = vol[None, ...]

    if vol.ndim != 3:
        raise ValueError(f"Unexpected image shape: {vol.shape} in {filename} (expect 3D (Z,Y,X)).")

    # Slices are always written as uint16: wider integers wrap around and
    # floats are truncated, so flag any data that is not already (u)int8/uint16-safe.
    if not (vol.dtype.kind in 'bu' and vol.dtype.itemsize <= 2):
        print(f"  [warn] dtype {vol.dtype} converti en uint16 : les valeurs hors [0, 65535] seront altérées.")

    # Pad Z indices to at least 3 digits, wider for stacks of 1000+ slices, so
    # that sorting the file names lexicographically keeps the Z order.
    pad_width = max(3, len(str(vol.shape[0] - 1)))
    stem = os.path.splitext(filename)[0]

    # Export the selected single-channel Z-slices.
    for i in tqdm(range(0, vol.shape[0], Z_STEP), desc='Saving Z-slices'):
        out_name = f"{stem}_z{zero_pad(i, pad_width)}.tif"
        out_path = os.path.join(extracted_path, out_name)
        imwrite(out_path, vol[i].astype(np.uint16))

print("\n Step 1 complete: 2D slices (single-channel) extracted from image files.")

# =======================
# Step 2: group slices by file prefix + crop everything to a common size
# =======================
# Group by prefix (e.g. M20E1_z025.tif -> M20E1). rsplit cuts at the LAST
# "_z", so stems that themselves contain "_z" stay intact.
# NOTE: every .tif in extracted_path is used, including leftovers from
# earlier runs (e.g. with another Z_STEP).
stacks = defaultdict(list)
for fname in sorted(os.listdir(extracted_path)):
    if fname.lower().endswith('.tif'):
        prefix = fname.rsplit('_z', 1)[0]
        stacks[prefix].append(os.path.join(extracted_path, fname))

all_files = [path for file_list in stacks.values() for path in file_list]

min_shape = None
if not all_files:
    print(f"\n No .tif slices found in {extracted_path}; nothing to crop.")
else:
    # Smallest shape over all slices, read from the TIFF headers only
    # (series 0, the same series imread loads) so pixel data is read once.
    shapes = []
    for path in all_files:
        with TiffFile(path) as tf:
            shapes.append(tf.series[0].shape)
    min_shape = tuple(int(v) for v in np.min(shapes, axis=0))

    print(f"\n  Cropping all images to size: {min_shape}")

    # Crop (top-left corner kept) and save to the final output folder.
    for files in stacks.values():
        for f in sorted(files):
            img = imread(f)
            cropped = img[:min_shape[0], :min_shape[1]]
            out_path = os.path.join(cropped_path, os.path.basename(f))
            imwrite(out_path, cropped.astype(np.uint16))

    print(f"\n Step 2 complete: All cropped images saved to: {cropped_path}")
