In [None]:
from IPython.display import HTML, display

# Inject CSS that lays out the element with id "rendered_cells" as a centred
# column. The max-width needs an explicit unit: a bare non-zero number is not a
# valid CSS length, so the declaration would be dropped by the browser.
display(HTML('''
<style>
#rendered_cells {
    position: relative !important;
    left: auto !important;
    right: auto !important;
    top: auto !important;
    bottom: auto !important;

    display: flex !important;
    flex-direction: column !important;
    align-items: center !important;
    justify-content: flex-start !important; /* ou center si vous voulez centrer verticalement */
    
    width: 100% !important;
    max-width: 1500px !important;
    margin: 0 auto !important;
    text-align: center !important;
}
</style>
'''))

# Centre pre-formatted text inside output areas (both selector generations).
display(HTML('''
<style>
.jp-OutputArea pre, .output_subarea pre {
    text-align: center !important;
}
</style>
'''))

In [2]:
## IMPORT LIBRARIES ##

import os
import json
import numpy as np
import torch
import pandas as pd
from cellpose import models
from aicsimageio import AICSImage
from tifffile import imwrite 
from skimage.io import imread, imsave
from skimage.measure import label
from skimage.morphology import dilation, disk
from skimage.transform import resize
from skimage.util import img_as_ubyte
from skimage.segmentation import find_boundaries
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from scipy import ndimage
import warnings

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, Markdown, HTML

# Choose the morphology backend. The cucim (GPU) versions of `dilation` and
# `disk` are used only when CUDA is available AND both cupy and cucim import
# successfully; in every other case the scikit-image versions are (re)bound.
# NOTE(review): the cucim functions presumably expect CuPy arrays, while
# PredictionDataset hands them NumPy patches - confirm on a GPU machine.
cuda = False
if torch.cuda.is_available():
    try:
        import cupy as cu
        from cucim.skimage.morphology import dilation, disk
    except ImportError:
        pass
    else:
        cuda = True

if not cuda:
    from skimage.morphology import dilation, disk

# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

#### FUNCTIONS ####

def load_checkpoint(filename):
    """Return the JSON content of ``filename``, or ``{}`` if the file is absent.

    Parameters
    ----------
    filename : str
        Path of the checkpoint file.

    Returns
    -------
    The object decoded by ``json.load`` (an empty dict when no file exists).
    """
    if not os.path.exists(filename):
        return {}

    with open(filename, 'r') as handle:
        state = json.load(handle)
    return state

def save_checkpoint(filename, data):
    """Write ``data`` to ``filename`` as indented JSON, atomically.

    The JSON is first written to ``<filename>.tmp`` and then moved over the
    target with ``os.replace``. If serialisation fails or the write is
    interrupted, the previous checkpoint is left untouched instead of being
    truncated (a truncated file would make the next resume fail to parse).

    Parameters
    ----------
    filename : str
        Destination path of the checkpoint.
    data :
        JSON-serialisable object to store.

    Raises
    ------
    TypeError
        Propagated from ``json.dump`` when ``data`` is not serialisable.
    OSError
        Propagated from the file operations.
    """
    tmp_name = f"{filename}.tmp"
    try:
        with open(tmp_name, 'w') as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_name, filename)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # cleans up after a failed or interrupted write.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)

def reset_checkpoints(directory=None):
    """Delete the three stage checkpoint files if they exist.

    Parameters
    ----------
    directory : str, optional
        Folder in which to look for the checkpoint files. When omitted the
        bare file names are used, i.e. they are resolved against the current
        working directory (the historical behaviour). Pass the ``Log`` folder
        used by ``run_entire_analysis`` to clear the checkpoints it writes.
    """
    cp_files = ["channel_split_checkpoint.json", "segmentation_checkpoint.json", "classification_checkpoint.json"]
    for cp in cp_files:
        cp_path = cp if directory is None else os.path.join(directory, cp)
        if os.path.exists(cp_path):
            os.remove(cp_path)

def normalize_image(image, percentile_max, nbins=512, smoothing=2):
    """Rescale ``image`` to [0, 1] between a histogram "knee" and a percentile.

    The lower bound is the histogram bin edge just after the maximum of the
    (optionally smoothed) second difference of the cumulative histogram of
    the strictly positive pixels; the upper bound is the ``percentile_max``
    percentile of those same pixels.

    Parameters
    ----------
    image : np.ndarray
        Input intensities.
    percentile_max : float
        Percentile (0-100) of the positive pixels used as the upper bound.
    nbins : int
        Number of histogram bins.
    smoothing : int
        Width of the moving average applied to the second difference;
        values <= 1 disable smoothing.

    Returns
    -------
    np.ndarray
        float32 array in [0, 1] with the shape of ``image``. When ``image``
        has no positive pixel it is returned unchanged.
    """
    foreground = image[image > 0]
    if foreground.size == 0:
        return image

    # Cumulative histogram of the positive pixels, scaled to end at 1.
    counts, edges = np.histogram(foreground, bins=nbins)
    cumulative = np.cumsum(counts)
    cumulative = cumulative / cumulative[-1]

    # Second difference of the CDF (length nbins - 2), optionally smoothed.
    curvature = np.diff(np.diff(cumulative))
    if smoothing > 1:
        box = np.ones(smoothing) / smoothing
        curvature = np.convolve(curvature, box, mode='same')

    # Index of the strongest upward bend, shifted by one to address the edges.
    knee = np.argmax(curvature) + 1
    if not (0 <= knee < len(edges)):
        knee = 0  # fall back to the first edge when out of range

    low = edges[knee]
    high = np.percentile(foreground, percentile_max)
    span = (high - low) if (high > low) else 1e-9

    clipped = np.clip(image, low, high)
    keep = (clipped >= low)

    result = np.zeros_like(image, dtype=np.float32)
    result[keep] = (clipped[keep] - low) / span

    return np.clip(result, 0, 1)
def max_to_one(im):
    """Divide ``im`` by its maximum so that the largest value becomes 1.

    An array whose maximum is 0 (e.g. an all-zero patch) is returned as zeros
    instead of the NaN-filled array that a 0/0 division would produce.

    Parameters
    ----------
    im : np.ndarray
        Input array.

    Returns
    -------
    np.ndarray
        ``im`` divided by its maximum (true division, so integer inputs give a
        floating-point result).
    """
    peak = im.max()
    if peak == 0:
        # Dividing by 1 keeps the values (all <= 0) and the output dtype rules
        # of a normal division while avoiding 0/0 -> NaN.
        peak = 1

    return im / peak
    
## Collect label centroids

def compute_label_centroids(label_image):
    """Return the non-zero label ids of ``label_image`` and their centroids.

    Parameters
    ----------
    label_image : np.ndarray
        Label image; 0 is treated as background and ignored.

    Returns
    -------
    ids : np.ndarray
        Sorted unique non-zero labels (int32).
    centers :
        Result of ``scipy.ndimage.center_of_mass`` for those labels, computed
        with uniform weights, in array-index order (one coordinate per axis).
    """
    label_image = label_image.astype(np.int32, copy=False)

    # Unique ids of the foreground only (background 0 is dropped up front).
    ids = np.unique(label_image[label_image != 0])

    # Uniform weights turn the centre of mass into the geometric centroid.
    uniform = np.ones(label_image.shape, dtype=np.float32)
    centers = ndimage.center_of_mass(uniform, labels=label_image, index=ids)

    return ids, centers

## Prediction functions

class PredictionDataset(Dataset):
    """Serve one 4-channel patch per labelled object for the classifier.

    For each entry of ``label_ids`` a square patch of side
    ``2 * half_patch_size`` is cut around the matching centroid from the
    (zero-padded) image and label map. The item is a float32 tensor in
    ``(C, H, W)`` order with the 3 image channels followed by the binary mask
    of that single label.

    Centroids are given as ``(row, col)``; in the code below ``cx`` is the row
    index and ``cy`` the column index.

    Parameters
    ----------
    image : np.ndarray
        Image indexed as ``(row, col, channel)``; it is written into the 3
        image channels of the patch.
    labels : np.ndarray
        2-D label map aligned with ``image``.
    label_ids : sequence
        Label values to classify; defines the dataset length.
    centroids : sequence
        ``(row, col)`` centroid for each entry of ``label_ids``.
    half_patch_size : int
        Half of the patch side length, in pixels.
    device :
        Device the returned tensor is moved to.
    config_dict : dict
        Must provide ``["options"]["dilation"]["dilate_mask"]`` (a bool or a
        "true"/"false" style string) and, when dilation is enabled,
        ``["options"]["dilation"]["str_element_size"]``.
    """

    def __init__(self, image, labels, label_ids, centroids, half_patch_size, device, config_dict):
        super().__init__()
        self.image = image
        self.labels = labels
        self.label_ids = label_ids
        self.centroids = centroids
        self.half_patch_size = half_patch_size
        self.device = device
        self.config_dict = config_dict

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # Resolve the dilation settings once instead of re-parsing the config
        # and rebuilding the structuring element for every patch.
        dilation_cfg = self.config_dict["options"]["dilation"]
        dilate_flag = dilation_cfg["dilate_mask"]
        if isinstance(dilate_flag, str):
            dilate_flag = json.loads(dilate_flag.lower())
        self._do_dilate = bool(dilate_flag)
        # NOTE(review): when the cucim backend is active, `disk`/`dilation`
        # presumably work on CuPy arrays while the patches are NumPy - confirm.
        self._strel = disk(int(dilation_cfg["str_element_size"])) if self._do_dilate else None

        # Pad so that a patch centred on any in-image centroid stays in bounds.
        pad = half_patch_size + 1

        self.image = np.pad(
            self.image,
            ((pad, pad), (pad, pad), (0, 0)),
            mode="constant"
        )
        if self.labels.dtype != np.int32:
            self.labels = self.labels.astype(np.int32, copy=False)
        self.labels = np.pad(
            self.labels,
            ((pad, pad), (pad, pad)),
            mode="constant"
        )

    def __len__(self):
        return len(self.label_ids)

    def __getitem__(self, idx):
        """Build the patch tensor for the ``idx``-th label.

        Raises
        ------
        RuntimeError
            If the patch cannot be built. The error is raised (chained to the
            original exception) rather than returning ``None``, because the
            default DataLoader collate function cannot batch ``None`` items and
            would hide the real cause.
        """
        try:
            row, col = self.centroids[idx]
            hps = self.half_patch_size

            # Shift the centroid into padded coordinates.
            cx = int(row) + hps + 1
            cy = int(col) + hps + 1
            xmin, xmax = cx - hps, cx + hps
            ymin, ymax = cy - hps, cy + hps

            # Extract patch
            patch_img = self.image[xmin:xmax, ymin:ymax, :].copy()
            patch_mask = self.labels[xmin:xmax, ymin:ymax].copy()

            label_id = self.label_ids[idx]

            # Normalize
            patch_img = max_to_one(patch_img)

            # Binarize: keep only the label being classified
            patch_mask[patch_mask != label_id] = 0
            patch_mask[patch_mask == label_id] = 1

            # Dilation if stated in the config file; the image is then masked
            # by the dilated label.
            if self._do_dilate:
                patch_mask = dilation(patch_mask, self._strel)
                patch_img *= patch_mask[..., None]

            # Mask concatenation as 4th channel
            out = np.zeros((patch_img.shape[0], patch_img.shape[1], 4), dtype=np.float32)
            out[..., :3] = patch_img
            out[..., 3] = patch_mask

            # Transformation to (C, H, W) + Tensor
            out = self.transform(out)
            return out.to(self.device)

        except Exception as e:
            raise RuntimeError(f"[Dataset] Erreur index {idx} : {e}") from e

## Save prediction image

def save_colored_predictions_downsample(
    labels, 
    predictions, 
    used_labels, 
    myotube_image, 
    output_path, 
    factor=1
):
    """Save an RGB overview with label outlines coloured by predicted class.

    The myotube image is downsampled by ``factor`` and replicated to RGB; the
    inner boundary pixels of every label that has a prediction are painted
    red when the prediction equals 0 ("Nuclei Out") and green otherwise
    ("Nuclei In"). The result is converted with ``img_as_ubyte`` and written
    to ``output_path``.

    Parameters
    ----------
    labels : np.ndarray
        2-D label map at full resolution.
    predictions : sequence
        Predicted class for each entry of ``used_labels``.
    used_labels : sequence
        Label values the predictions refer to.
    myotube_image : np.ndarray
        2-D image, or ``(H, W, 3)`` image which is averaged over channels.
    output_path : str
        Destination file.
    factor : int
        Integer downsampling factor (>= 1).

    Raises
    ------
    ValueError
        If ``factor`` is smaller than 1 or if ``predictions`` and
        ``used_labels`` differ in length.

    Notes
    -----
    With ``factor > 1`` several boundary pixels can land on the same output
    pixel; in that case green ("Nuclei In") takes precedence.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")
    if len(predictions) != len(used_labels):
        raise ValueError("Number of predictions != Number of labels.")

    if myotube_image.ndim == 3 and myotube_image.shape[2] == 3:
        myotube_image = np.mean(myotube_image, axis=2)
    H, W = myotube_image.shape

    newH, newW = H//factor, W//factor
    myotube_ds = resize(myotube_image, (newH, newW),
                        preserve_range=True,
                        anti_aliasing=True)

    myotube_rgb = np.stack([myotube_ds]*3, axis=-1)

    # Keep label values only on the inner boundary of each object.
    boundaries = find_boundaries(labels, mode='inner')
    boundary_labels = labels.copy()
    boundary_labels[~boundaries] = 0

    label_to_pred = dict(zip(used_labels, predictions))

    color_0 = [1.0, 0.0, 0.0]  # Red = Nuclei Out
    color_1 = [0.0, 1.0, 0.0]  # Green = Nuclei In

    # Boundary pixel coordinates, their label and their downsampled position.
    ys, xs = np.nonzero(boundary_labels)
    pixel_labels = boundary_labels[ys, xs]
    yd = ys // factor
    xd = xs // factor
    inside = (yd < newH) & (xd < newW)

    # Look each distinct label up once (instead of once per pixel):
    # -1 = no prediction, 0 = class 0, 1 = any other class.
    distinct, inverse = np.unique(pixel_labels, return_inverse=True)
    codes = np.full(len(distinct), -1, dtype=np.int8)
    for i, lb in enumerate(distinct):
        pred_class = label_to_pred.get(lb, None)
        if pred_class is not None:
            codes[i] = 0 if pred_class == 0 else 1
    pixel_codes = codes[inverse]

    paint_out = inside & (pixel_codes == 0)
    paint_in = inside & (pixel_codes == 1)
    myotube_rgb[yd[paint_out], xd[paint_out]] = color_0
    myotube_rgb[yd[paint_in], xd[paint_in]] = color_1

    myotube_rgb_8 = img_as_ubyte(myotube_rgb)
    imsave(output_path, myotube_rgb_8)

#### ANALYSIS ####

def run_entire_analysis(
    myotube_channel, nuclei_channel, dia, percentile_max,
    half_patch_size, batch_size, split_images, run_segmentation, seg_device, run_classification, class_device, save_normalization, save_prediction, downsampling_factor,
    parent_directory_value, seg_directory_value, class_directory_value
):
    """Run the channel-splitting, segmentation and classification stages.

    Every stage is optional. Each one records the files it has completed in a
    JSON checkpoint under ``<parent>/Log`` and skips those files on later
    calls. Outputs are written to sub-folders of ``parent_directory_value``
    (``Images``, ``Nuclei``, ``Masks``, ``Predictions``, ``Normalized_images``)
    and the per-image statistics to ``Fusion_index.xlsx``.

    Parameters
    ----------
    myotube_channel, nuclei_channel : int
        Indices used to extract the two channels from each input image.
    dia : int
        Diameter passed to the Cellpose model.
    percentile_max : float
        Upper percentile passed to ``normalize_image``.
    half_patch_size : int
        Half side length of the classification patches.
    batch_size : int
        Batch size of the classification DataLoader.
    split_images, run_segmentation, run_classification : bool
        Enable the corresponding stage.
    seg_device, class_device : str
        ``"GPU"`` requests GPU execution; any other value means CPU.
    save_normalization : bool
        Save the normalised myotube images.
    save_prediction : bool
        Save the colour-coded prediction overlays.
    downsampling_factor : int
        Downsampling factor of the prediction overlays.
    parent_directory_value : str
        Folder containing the input images.
    seg_directory_value : str
        Path of the Cellpose pretrained model (segmentation stage only).
    class_directory_value : str
        Path of the classifier checkpoint (classification stage only);
        ``Config.json`` is read from the same folder.
    """

    myotube_folder = os.path.join(parent_directory_value, "Images")
    nuclei_folder = os.path.join(parent_directory_value, "Nuclei")
    masks_folder = os.path.join(parent_directory_value, "Masks")
    predictions_output_dir = os.path.join(parent_directory_value, "Predictions")
    log_folder = os.path.join(parent_directory_value, "Log")
    norm_dir = os.path.join(parent_directory_value, "Normalized_images")

    os.makedirs(log_folder, exist_ok=True)

    ## IMAGE SPLITTING ##

    if split_images:

        os.makedirs(myotube_folder, exist_ok=True)
        os.makedirs(nuclei_folder, exist_ok=True)

        channel_checkpoint_file = os.path.join(log_folder,"channel_split_checkpoint.json")
        channel_checkpoint = load_checkpoint(channel_checkpoint_file)
        processed_channel_files = channel_checkpoint.get("processed_files", [])

        display(Markdown("## CHANNEL SPLITTING ##"))
        file_list = [f for f in os.listdir(parent_directory_value) if f.lower().endswith(('.tif', '.tiff', '.ome.tiff', '.czi'))]
        total_files = len(file_list)

        channel_progress = widgets.IntProgress(min=0, max=total_files, description='Progress ', style={'description_width': 'initial'}, layout=widgets.Layout(width='30%'))
        display(widgets.HBox([channel_progress], layout=widgets.Layout(justify_content='center')))

        # The folder itself always exists at this point (the Log folder was
        # created inside it), so the meaningful check is for an empty listing.
        if not file_list:
            print(f"\nNo images found in {parent_directory_value}")
        else:
            for idx, file in enumerate(file_list, 1):
                if file in processed_channel_files:
                    print(f"\n{file} already processed.")
                    continue

                file_path = os.path.join(parent_directory_value, file)
                try:
                    # Note: ".ome.tiff" also ends with ".tiff", so only ".czi"
                    # files reach the AICSImage branch.
                    if file.lower().endswith((".tiff", ".tif")):
                        image = imread(file_path)
                    elif file.lower().endswith((".ome.tiff", ".czi")):
                        aics_image = AICSImage(file_path, reconstruct_mosaic=False)
                        image = aics_image.data[0]
                        print(image.shape)
                        if image.ndim == 5:
                            image = image[0, 0].transpose((1, 2, 0))
                        elif image.ndim == 4:
                            image = image[0].transpose((1, 2, 0))
                        else:
                            continue
                    else:
                        continue

                    # Extract and save channels
                    # NOTE(review): the guard tests the LAST axis
                    # (shape[2] >= 2) while the channels are taken from the
                    # FIRST axis - confirm the expected axis order for both the
                    # TIFF and the CZI paths.
                    if image.ndim == 3 and image.shape[2] >= 2:
                        myotube = image[myotube_channel, :, :]
                        nuclei = image[nuclei_channel, :, :]
                        myotube_save_path = os.path.join(myotube_folder, f"{os.path.splitext(file)[0]}.tif")
                        imsave(myotube_save_path, myotube.astype(np.uint16))
                        nuclei_save_path = os.path.join(nuclei_folder, f"{os.path.splitext(file)[0]}.tif")
                        imsave(nuclei_save_path, nuclei.astype(np.uint16))
                        print(f"Processed and saved channels for: {file}")
                    else:
                        print(f"Skipped {file}: not enough channels.")
                except Exception as e:
                    print(f"Failed processing {file}: {e}")
                    continue  # Move on to next file

                # Update checkpoint
                processed_channel_files.append(file)
                channel_checkpoint["processed_files"] = processed_channel_files
                save_checkpoint(channel_checkpoint_file, channel_checkpoint)
                channel_progress.value = idx

    ## SEGMENTATION ##

    if run_segmentation:

        os.makedirs(masks_folder, exist_ok=True)

        # Use the GPU exactly when the user asked for it.
        gpu = (seg_device == "GPU")

        seg_checkpoint_file = os.path.join(log_folder,"segmentation_checkpoint.json")
        seg_checkpoint = load_checkpoint(seg_checkpoint_file)
        segmented_files = seg_checkpoint.get("segmented_files", [])

        image_files = [
            os.path.join(nuclei_folder, file)
            for file in os.listdir(nuclei_folder)
            if file.lower().endswith((".tif", ".tiff"))
        ]

        seg_model = models.CellposeModel(gpu=gpu, pretrained_model=seg_directory_value)

        display(HTML("<br><br>"))
        display(Markdown("## SEGMENTATION ##"))
        print(f"\nSegmentation device:", seg_device)
        print(f"Loading segmentation model from: {seg_directory_value}\n")

        seg_progress = widgets.IntProgress(min=0, max=len(image_files), description='Progress ',
                                   style={'description_width': 'initial'}, layout=widgets.Layout(width='30%'))
        display(widgets.HBox([seg_progress], layout=widgets.Layout(justify_content='center')))

        def loop(file_path, diameter):
            """Segment one nuclei image, save its mask, return the max label."""
            base_name = os.path.basename(file_path)
            name_without_ext = os.path.splitext(base_name)[0]
            image = AICSImage(file_path).data[0,0]
            masks, flows, styles = seg_model.eval(
                image, 
                diameter=diameter,
                channels=[0, 0], 
                normalize=True
            )
            if masks.max() > 0:
                mask_save_path = os.path.join(masks_folder, name_without_ext + ".tif")
                imwrite(mask_save_path, masks.astype(np.uint32))
                print(f"Processed ans saved mask for : {base_name}")
            else:
                print(f"No cells detected in {base_name}")

            return np.max(masks)

        for idx, fpath in enumerate(image_files, 1):
            if fpath in segmented_files:
                print(f"{fpath} already segmented.")
                continue
            try:
                loop(fpath, dia)
            except Exception as e:
                print(f"Failed segmentation for {fpath} : {e}")
                continue

            segmented_files.append(fpath)
            seg_checkpoint["segmented_files"] = segmented_files
            save_checkpoint(seg_checkpoint_file, seg_checkpoint)
            seg_progress.value = idx

    ## CLASSIFICATION ##

    if run_classification:

        os.makedirs(predictions_output_dir, exist_ok=True)

        # Resolved here (not at the top) so that the other stages still run
        # when no classifier file has been selected.
        config_path = os.path.join(os.path.dirname(class_directory_value), "Config.json")

        class_checkpoint_file = os.path.join(log_folder,"classification_checkpoint.json")
        class_checkpoint = load_checkpoint(class_checkpoint_file)
        classified_files = class_checkpoint.get("classified_files", [])
        # Statistics of already-classified images are kept in the checkpoint so
        # that a resumed run exports them too instead of overwriting the
        # spreadsheet with the new images only.
        prediction_data = class_checkpoint.get("results", [])

        # Get image list
        myotube_files = [f for f in os.listdir(myotube_folder) if f.lower().endswith(('.tif', '.tiff'))]

        with open(config_path, 'r') as f:
            config_dict = json.load(f)

        # Choose device
        if class_device == "GPU":
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            if device.type == "cuda":
                class_real_device = "GPU"
            else :
                class_real_device = "CPU"
                print("CUDA is not avalaible - Classification will use CPU")
        else :
            device = torch.device("cpu")
            class_real_device = class_device

        # Load Classifier

        display(HTML("<br><br>"))
        display(Markdown("## CLASSIFICATION ##"))
        print(f"\nClassification device:", class_real_device)
        print(f"Loading classification model from: {class_directory_value}\n")

        class_progress = widgets.IntProgress(min=0, max=len(myotube_files), description='Progress ', 
                                     style={'description_width': 'initial'}, layout=widgets.Layout(width='30%'))
        display(widgets.HBox([class_progress], layout=widgets.Layout(justify_content='center')))

        # SECURITY: weights_only=False unpickles arbitrary Python objects; only
        # load classifier files from a trusted source.
        checkpoint_model = torch.load(class_directory_value, map_location = device, weights_only = False)
        model = checkpoint_model["model"].to(device)
        model.eval()

        for idx, img_file in enumerate(sorted(myotube_files), 1):
            if img_file in classified_files:
                print(f"{img_file} already classified.")
                continue

            img_path = os.path.join(myotube_folder, img_file)
            mask_path = os.path.join(masks_folder, img_file)
            if not os.path.exists(mask_path):
                print(f"No matching mask for {img_file}, skipping.")
                continue

            # Needed for the statistics below whether or not the overlay is saved.
            base_name = os.path.splitext(img_file)[0]

            try:
                # Load Image
                image_data = imread(img_path).astype(np.float32)
                if image_data.ndim == 2:
                    image_data = np.stack([image_data] * 3, axis=-1)
                image_data = normalize_image(image_data, percentile_max)

                # Save normalization
                if save_normalization:
                    os.makedirs(norm_dir, exist_ok=True)
                    norm_image = img_as_ubyte(image_data)
                    norm_path = os.path.join(norm_dir, img_file)
                    imsave(norm_path, norm_image)
                    print(f"Normalized image saved for {img_file}")

                # Load mask and calculate centroids
                mask_data = imread(mask_path)
                label_ids, centroids = compute_label_centroids(mask_data)

                # Prepare dataset and dataloader
                dataset = PredictionDataset(
                    image=image_data,
                    labels=mask_data,
                    label_ids=label_ids,
                    centroids=centroids,
                    half_patch_size=half_patch_size,
                    device=device,
                    config_dict=config_dict
                )
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                image_preds = []
                with torch.no_grad():
                    for batch_tensor in dataloader:
                        if batch_tensor is None or batch_tensor.size(0) == 0:
                            continue
                        outputs = model(batch_tensor)
                        preds = outputs.argmax(dim=1)
                        image_preds.extend(preds.cpu().numpy())

                # Save prediction image if selected
                if save_prediction:
                    output_path = os.path.join(predictions_output_dir, f"{base_name}_prediction.tif")
                    save_colored_predictions_downsample(
                        labels=mask_data,
                        predictions=image_preds,
                        used_labels=label_ids,
                        myotube_image=image_data,
                        output_path=output_path,
                        factor=downsampling_factor
                    )
                    print(f"Prediction image saved for {img_file}")

                # Calculate stats (plain Python numbers so they can be stored
                # in the JSON checkpoint)
                pred_array = np.array(image_preds)
                total_labels = len(pred_array)
                num_ones = int(np.sum(pred_array == 1))
                num_zeros = int(np.sum(pred_array == 0))
                fusion_index = (num_ones / total_labels * 100.0) if total_labels else 0.0
                prediction_data.append({
                    "Image Name": base_name,
                    "Total Number of Nuclei": total_labels,
                    "Nuclei In": num_ones,
                    "Nuclei Out": num_zeros,
                    "Fusion Index (%)": fusion_index
                })

                # Update checkpoint
                classified_files.append(img_file)
                class_checkpoint["classified_files"] = classified_files
                class_checkpoint["results"] = prediction_data
                save_checkpoint(class_checkpoint_file, class_checkpoint)
                class_progress.value = idx

            except Exception as e:
                print(f"Failed to classify {img_file}: {e}")
                continue

        # Export results
        predictions_df = pd.DataFrame(prediction_data)
        output_file_path = os.path.join(parent_directory_value, "Fusion_index.xlsx")
        predictions_df.to_excel(output_file_path, index=False)

        print(f"\nResults exported to: {output_file_path}")
        display(Markdown("<h3 style='text-align:center'>✅ <b>Classification complete!</b></h3>"))
        
#### SET-UP ####

# Selection widgets

def _centered_chooser_box(caption_html, chooser):
    """Stack a centred HTML caption over ``chooser``; return (label, hbox, box)."""
    label = widgets.HTML(
        value=caption_html,
        layout=widgets.Layout(margin="5px")
    )
    hbox = widgets.HBox([chooser],
        layout=widgets.Layout(justify_content='center', width='100%')
    )
    box = widgets.VBox(
        [label, hbox],
        layout=widgets.Layout(align_items='center', width='80%')
    )
    return label, hbox, box

# Folder (directories only)
folder_chooser = FileChooser('../', show_only_dirs=True)
folder_label, folder_hbox, folder_box = _centered_chooser_box(
    "<b>Folder with your images</b>", folder_chooser
)

# Segmentation
seg_chooser = FileChooser('../')
seg_label, seg_hbox, seg_box = _centered_chooser_box(
    "<b>Segmentation model file</b>", seg_chooser
)

# Classification (the analysis reads "Config.json" next to the selected file,
# so the caption uses that exact, case-sensitive name)
class_chooser = FileChooser('../')
class_label, class_hbox, class_box = _centered_chooser_box(
    "<b>Classification model file</b> (Make sure Config.json is in the same folder)", class_chooser
)

# Grouping blocks in a Vbox
file_choosers_box = widgets.VBox(
    [folder_box, seg_box, class_box],
    layout=widgets.Layout(align_items='center', width='100%')
)

# Aliases under which the callbacks read the selected paths
parent_directory = folder_chooser
seg_directory = seg_chooser
class_directory = class_chooser

# Parameters widgets

# Shared look of every parameter widget
style = {'description_width': '200px'}
layout_widget = {'width': '400px'}

# Numeric parameters
myotube_channel_widget = widgets.IntText(value=0, description='Myotube channel', style=style, layout=layout_widget)
nuclei_channel_widget = widgets.IntText(value=1, description='Nuclei channel', style=style, layout=layout_widget)
dia_widget = widgets.IntText(value=24, description='Cellpose diameter', style=style, layout=layout_widget)
percentile_max_widget = widgets.FloatText(value=99.7, description='Normalization max percentile', style=style, layout=layout_widget)
half_patch_size_widget = widgets.IntText(value=100, description='Half patch size', style=style, layout=layout_widget)
batch_size_widget = widgets.IntText(value=512, description='Batch size', style=style, layout=layout_widget)

# Stage switches and their devices
split_images_widget = widgets.Checkbox(value=True, description='Split channels', style=style, layout=layout_widget)
run_segmentation_widget = widgets.Checkbox(value=True, description='Run segmentation', style=style, layout=layout_widget)
seg_device_widget = widgets.Dropdown(
    options=['GPU', 'CPU'],
    value='GPU',
    description='Device for segmentation',
    style=style,
    layout=layout_widget
)
run_classification_widget = widgets.Checkbox(value=True, description='Run classification', style=style, layout=layout_widget)
class_device_widget = widgets.Dropdown(
    options=['GPU', 'CPU'],
    value='GPU',
    description='Device for classification',
    style=style,
    layout=layout_widget
)

# Output options
save_normalization_widget = widgets.Checkbox(value=False, description='Save normalization', style=style, layout=layout_widget)
save_prediction_widget = widgets.Checkbox(value=True, description='Save prediction', style=style, layout=layout_widget)
downsampling_factor_widget = widgets.IntText(value=1, description='Downsampling factor', style=style, layout=layout_widget)

# Each "Run ..." checkbox is immediately followed by its device dropdown, in
# the same order for segmentation and classification.
params_box = widgets.VBox(
    [
        widgets.HTML(value="<h3>Adjust parameters and launch</h3>"),
        myotube_channel_widget,
        nuclei_channel_widget,
        dia_widget,
        percentile_max_widget,
        half_patch_size_widget,
        batch_size_widget,
        split_images_widget,
        run_segmentation_widget,
        seg_device_widget,
        run_classification_widget,
        class_device_widget,
        save_normalization_widget,
        save_prediction_widget,
        downsampling_factor_widget
    ],
    layout=widgets.Layout(align_items='center', width='80%')
)

# Launch / Resume buttons share one layout object.
button_layout = widgets.Layout(width="200px", height="50px", margin="20px")
run_button = widgets.Button(
    description='Launch',
    layout=button_layout,
    button_style='success',
)
resume_button = widgets.Button(
    description='Resume',
    layout=button_layout,
    button_style='warning',
)

buttons_box = widgets.HBox(
    children=[run_button, resume_button],
    layout=widgets.Layout(justify_content='center', width='80%'),
)

# Output area that receives everything the callbacks print or display.
output = widgets.Output()
analysis_box = widgets.VBox(
    children=[buttons_box, output],
    layout=widgets.Layout(align_items='center', width='80%'),
)

header = widgets.HTML(value="<h1 style='text-align:center'>Welcome to MyoFuse 1.0.0</h1>")

# Main Container: header, file choosers, parameters, then buttons + output.
main_ui = widgets.VBox(
    children=[header, file_choosers_box, params_box, analysis_box],
    layout=widgets.Layout(align_items='center', justify_content='center', width='100%'),
)

# Checkpoint files that run_entire_analysis writes into "<images folder>/Log".
_CHECKPOINT_FILENAMES = (
    "channel_split_checkpoint.json",
    "segmentation_checkpoint.json",
    "classification_checkpoint.json",
)


def _clear_log_checkpoints(parent_dir):
    """Delete the checkpoint files stored in ``<parent_dir>/Log``, if any."""
    if not parent_dir:
        # No folder selected: nothing to clear (the analysis reports the error).
        return
    log_folder = os.path.join(parent_dir, "Log")
    for cp_name in _CHECKPOINT_FILENAMES:
        cp_path = os.path.join(log_folder, cp_name)
        if os.path.exists(cp_path):
            os.remove(cp_path)


def _run_analysis_from_widgets():
    """Read every parameter widget and call ``run_entire_analysis``.

    Arguments are passed by keyword so that each widget value reaches the
    parameter of the same name regardless of the order in the signature.
    """
    run_entire_analysis(
        myotube_channel=myotube_channel_widget.value,
        nuclei_channel=nuclei_channel_widget.value,
        dia=dia_widget.value,
        percentile_max=percentile_max_widget.value,
        half_patch_size=half_patch_size_widget.value,
        batch_size=batch_size_widget.value,
        split_images=split_images_widget.value,
        run_segmentation=run_segmentation_widget.value,
        seg_device=seg_device_widget.value,
        run_classification=run_classification_widget.value,
        class_device=class_device_widget.value,
        save_normalization=save_normalization_widget.value,
        save_prediction=save_prediction_widget.value,
        downsampling_factor=downsampling_factor_widget.value,
        parent_directory_value=parent_directory.value,
        seg_directory_value=seg_directory.value,
        class_directory_value=class_directory.value,
    )


# Callback for launch button
def on_run_button_click(b):
    """Start a fresh analysis: clear all checkpoints, then run every stage."""
    with output:
        output.clear_output() 
        reset_checkpoints()
        try:
            # The checkpoints actually used by the analysis live in the Log
            # folder of the selected directory; clear them so that "Launch"
            # really starts from scratch.
            _clear_log_checkpoints(parent_directory.value)
            _run_analysis_from_widgets()
        except Exception as e:
            print("Erreur :", e)

# Callback for resume button
def on_resume_button_click(b):
    """Run the analysis again, keeping checkpoints so finished files are skipped."""
    with output:
        output.clear_output()
        print("Resuming analysis with existing checkpoints...")

        try:
            _run_analysis_from_widgets()
        except Exception as e:
            print("Erreur :", e)

# Wire each button to its click handler, then render the interface.
for _button, _handler in (
    (run_button, on_run_button_click),
    (resume_button, on_resume_button_click),
):
    _button.on_click(_handler)

display(main_ui)

VBox(children=(HTML(value="<h1 style='text-align:center'>Welcome to MyoFuse 1.0.0</h1>"), VBox(children=(VBox(…