<a href="https://colab.research.google.com/github/Itamar-Horowitz/real-time-AutoDS/blob/main/AutoDS/Realtime_AutoDS_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **AutoDS**

---

<font size = 4> Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM) data, first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net-based network without skip connections. The network reconstructs 2D super-resolution images and is trained in a supervised manner on simulated high-density SMLM data for which the ground truth is available. These simulations are obtained from a random distribution of single molecules in the field of view and therefore do not imprint structural priors during training. The network outputs a super-resolution image with increased pixel density (typically an upsampling factor of 8 in each dimension).
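As a worked example of the upsampling arithmetic, assuming the default `pixel_size` of 107 nm used later in this notebook and an upsampling factor of 8, each super-resolved pixel covers 107 / 8 = 13.375 nm, so a 64×64 input frame becomes a 512×512 reconstruction. The helper `hr_grid` below is hypothetical and only illustrates this relationship:

```python
def hr_grid(frame_shape, pixel_size_nm, upsampling_factor=8):
    """Return the super-resolved image shape and its pixel size in nm.

    Hypothetical helper: for an integer upsampling factor, every camera
    pixel is split into upsampling_factor x upsampling_factor output pixels.
    """
    rows, cols = frame_shape
    hr_shape = (rows * upsampling_factor, cols * upsampling_factor)
    hr_pixel_nm = pixel_size_nm / upsampling_factor
    return hr_shape, hr_pixel_nm


shape, px = hr_grid((64, 64), pixel_size_nm=107)
print(shape, px)  # (512, 512) 13.375
```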

<font size = 4> AutoDS is an extension of Deep-STORM that automates the reconstruction process and alleviates the need for human intervention. It does so by automatically detecting the experimental conditions in the analyzed videos and automatically selecting a Deep-STORM model for the data processing from a set of pre-trained models.
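The model-selection step can be sketched as a mapping from the estimated emitter density and SNR to an index into the list of pre-trained models. The sketch below restates the logic of `ChooseNetByDifficulty_2025` in the inference cell under a hypothetical name, `choose_model_index`; denser data and lower SNR both push the choice toward a higher index, which the function name suggests corresponds to a more difficult imaging condition.

```python
import numpy as np


def choose_model_index(density, snr, num_models=4):
    """Map (emitter density, SNR) to a model index in [0, num_models - 1].

    Restates ChooseNetByDifficulty_2025 from the inference cell for illustration.
    """
    # Density score: 2 * density, rounded and clipped to [0, num_models - 1]
    norm_density = int(np.clip(np.round(2 * density), 0, num_models - 1))
    # SNR score: inverted so that a low SNR gives a high difficulty score
    norm_snr = num_models - 1 - int(np.clip(snr // 2, 0, num_models - 1))
    # Average the two difficulty scores and round to the nearest model index
    return int(np.round((norm_snr + norm_density) / 2))


print(choose_model_index(0.2, 10))  # 0  (sparse, bright data)
print(choose_model_index(2.0, 1))   # 3  (dense, dim data)
```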

<font size = 4> Additionally, the AutoDS pipeline splits each input frame into patches, which enables different regions of the field of view to be processed with different models. This mechanism led to an improvement in reconstruction quality beyond the capabilities of Deep-STORM.
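The geometry of this patch mechanism can be sketched as follows: each frame is cut into a `num_patches` × `num_patches` grid, every tile is extended by `overlap` pixels (reflect-padded at the frame border, as in `split_image_to_patches` in the inference cell), and after processing the overlap is cropped away before the tiles are stitched back. In AutoDS each tile would be reconstructed by its own selected model; here the tiles are stitched back unchanged to show that the bookkeeping is lossless. The function names are hypothetical, and the sketch assumes the frame size is divisible by `num_patches`.

```python
import numpy as np


def split_with_overlap(img, num_patches, overlap):
    """Split a 2D frame into num_patches x num_patches tiles, each padded by `overlap`."""
    h, w = img.shape
    ph, pw = h // num_patches, w // num_patches
    padded = np.pad(img, overlap, mode="reflect")  # mirror the frame border
    tiles = []
    for r in range(num_patches):  # row-major order, as in the notebook
        for c in range(num_patches):
            y0, x0 = r * ph, c * pw
            tiles.append(padded[y0:y0 + ph + 2 * overlap, x0:x0 + pw + 2 * overlap])
    return tiles


def stitch(tiles, num_patches, overlap):
    """Crop the overlap from each tile and reassemble the full frame."""
    core = [t[overlap:t.shape[0] - overlap, overlap:t.shape[1] - overlap] for t in tiles]
    rows = [np.hstack(core[r * num_patches:(r + 1) * num_patches]) for r in range(num_patches)]
    return np.vstack(rows)


frame = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
tiles = split_with_overlap(frame, num_patches=4, overlap=2)
print(len(tiles), tiles[0].shape)                   # 16 (20, 20)
print(np.array_equal(stitch(tiles, 4, 2), frame))   # True
```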


# **Before getting started**
---
<font size = 4> This notebook contains only the code required for inference on SMLM data using a set of pre-trained Deep-STORM models. For model training, see the [AutoDS training notebook](https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/blob/main/AutoDS/AutoDS_training.ipynb).

# **Run configuration**
---
<font size = 4>**`Data_folder`:** This folder should contain the images that you want to process with the trained networks.

<font size = 4>**`Result_folder`:** This folder will contain the CSV file of detected localizations.

<font size = 4>**`threshold`:** This parameter sets the threshold for local-maximum finding. A higher `threshold` results in fewer localizations. **DEFAULT: 10**

<font size = 4>**`neighborhood_size`:** This parameter sets the size of the neighborhood, in recovery pixels (CCD pixel/upsampling_factor), within which the prediction must be a local maximum. A larger `neighborhood_size` makes the prediction slower and may discard nearby localizations. **DEFAULT: 3**

<font size = 4>**`use_local_average`:** This parameter determines whether the prediction is locally averaged in a 3x3 neighborhood to obtain the final localizations. If set to **True**, inference will be slightly slower, depending on the size of the FOV. **DEFAULT: True**
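The three detection parameters above can be illustrated with a minimal sketch built on `scipy.ndimage.maximum_filter`. This is not the notebook's own `Maximafinder` layer: the function name `find_local_maxima` and the 3x3 center-of-mass refinement used for `use_local_average` are assumptions made for illustration, and the units of `threshold` depend on how the network output is scaled (the toy values below are arbitrary). With the default `pixel_size` of 107 nm and an upsampling factor of 8, a `neighborhood_size` of 3 recovery pixels spans about 40 nm.

```python
import numpy as np
from scipy.ndimage import maximum_filter


def find_local_maxima(density, threshold=10.0, neighborhood_size=3, use_local_average=True):
    """Return (y, x, confidence) arrays for local maxima of `density` above `threshold`."""
    # A pixel is a peak if it equals the maximum of its neighborhood_size window
    is_peak = density == maximum_filter(density, size=neighborhood_size)
    rows, cols = np.nonzero(is_peak & (density > threshold))
    ys, xs = rows.astype(float), cols.astype(float)
    conf = density[rows, cols].astype(float)
    if use_local_average:
        padded = np.pad(density, 1)  # zero-pad so border peaks get a full 3x3 window
        offsets = np.array([-1.0, 0.0, 1.0])
        for k, (r, c) in enumerate(zip(rows, cols)):
            win = padded[r:r + 3, c:c + 3]  # 3x3 window centred on (r, c)
            total = win.sum()
            # Sub-pixel refinement: intensity-weighted mean offset in y and x
            ys[k] = r + (win.sum(axis=1) @ offsets) / total
            xs[k] = c + (win.sum(axis=0) @ offsets) / total
    return ys, xs, conf


density = np.zeros((16, 16))
density[5, 7], density[5, 8] = 30.0, 10.0
ys, xs, conf = find_local_maxima(density)
print(ys, xs, conf)  # y = 5.0, x = 7.25, confidence = 30.0
```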

<font size = 4>**`num_patches`:** Determines the number of patches in each row and each column when splitting the frames into patches. The total number of patches is num_patches<sup>2</sup>. **DEFAULT: 4**

<font size = 4>**`batch_size`:** This parameter determines how many frames are processed in a single pass on the GPU. A higher `batch_size` makes prediction faster but uses more GPU memory. If an out-of-memory (OOM) error occurs, decrease `batch_size`. **DEFAULT: 1**
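The batching described above follows the pattern used in `reconstruct_patches_2025` in the inference cell: the input stack is walked in slices of `batch_size`, with a smaller final slice when the count does not divide evenly. A minimal sketch with a hypothetical `iter_batches` helper:

```python
import math


def iter_batches(num_items, batch_size=1):
    """Yield (start, end) index pairs covering num_items in chunks of batch_size."""
    n_batches = math.ceil(num_items / batch_size)
    for b in range(n_batches):
        start = b * batch_size
        # The last batch is truncated if num_items is not a multiple of batch_size
        yield start, min(num_items, start + batch_size)


print(list(iter_batches(10, batch_size=4)))  # [(0, 4), (4, 8), (8, 10)]
```

Larger batches mean fewer passes but a proportionally larger tensor on the GPU per pass, which is why lowering `batch_size` is the first remedy for OOM errors.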

<font size = 4>**The following parameters are relevant only if `interpolate_based_on_imaging_parameters` is checked:**

<font size = 4> - **`pixel_size` [nm]:** the pixel size of the analyzed video. **DEFAULT: 107**

<font size = 4> - **`wavelength` [nm]:** the emission wavelength of the analyzed video. **DEFAULT: 715**

<font size = 4> - **`numerical_aperture`:** the numerical aperture of the optical setup used to acquire the analyzed video. **DEFAULT: 1.49**

<font size = 4> - **`chunk_size`:** determines the number of patches analyzed in each prediction iteration. This parameter is used to manage compute resources in Google Colab: if you experience crashes due to RAM limitations, decrease the number of patches per chunk. **DEFAULT: 10000**
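The imaging parameters matter because `interpolate_frames` in the inference cell approximates the PSF as a Gaussian with standard deviation σ ≈ 0.21 λ / NA. For the defaults (715 nm, NA 1.49) this gives σ ≈ 100.8 nm, just under one 107 nm pixel. When a model's PSF is wider than the data's, the data is blurred by the missing width, since Gaussian widths add in quadrature. A minimal sketch of that arithmetic, with hypothetical function names and returning 0 where the notebook instead rescales pixels:

```python
import math


def psf_sigma_nm(wavelength_nm, numerical_aperture):
    """Gaussian approximation of the PSF standard deviation: 0.21 * lambda / NA."""
    return 0.21 * wavelength_nm / numerical_aperture


def matching_blur_sigma_px(model_wl, model_na, data_wl, data_na, pixel_size_nm):
    """Extra Gaussian blur (in pixels) that widens the data PSF to the model PSF.

    Gaussian widths add in quadrature, so the missing width is
    sqrt(sigma_model**2 - sigma_data**2). Returns 0.0 when the data PSF
    is already at least as wide as the model PSF.
    """
    diff_sq = psf_sigma_nm(model_wl, model_na) ** 2 - psf_sigma_nm(data_wl, data_na) ** 2
    return math.sqrt(diff_sq) / pixel_size_nm if diff_sq > 0 else 0.0


print(round(psf_sigma_nm(715, 1.49), 1))  # 100.8
print(matching_blur_sigma_px(715, 1.49, 600, 1.49, pixel_size_nm=107))
```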

# **Mount Google Drive**
---
Running the next cell will mount your Google Drive.

In [None]:
#@markdown Mount Google Drive
from google.colab import drive
drive.mount('/content/gdrive')

# **One-click inference**
---
Running the next cell will perform the following steps:
1. Installing required dependencies
2. Requesting GPU access
3. Downloading pre-trained models
4. Running the inference based on your configuration

In [None]:
Notebook_version = '1.2'
Network = 'AutoDS'

# Import keras modules and libraries from tensorflow.keras
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Activation, UpSampling2D, Conv2D, MaxPooling2D, BatchNormalization, Layer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K
from tensorflow.keras import losses
from tensorflow.keras.optimizers import Adam

# Other libraries
import scipy.optimize as opt
import scipy.io as sio
import scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile as tiff
from numpy.lib.stride_tricks import sliding_window_view
from scipy.ndimage import gaussian_laplace, maximum_filter, binary_dilation
from scipy.signal import fftconvolve
from skimage.morphology import white_tophat, disk
import h5py
import cv2
from skimage import io as skio  # aliased: the stdlib `io` imported below would otherwise shadow skimage.io
import sys, os, traceback
import csv
from PIL import Image
from PIL.TiffTags import TAGS
import math
from skimage.feature import peak_local_max
from scipy.ndimage import gaussian_filter, zoom
from tqdm import tqdm
from contextlib import contextmanager

# Create a variable to get and store relative base path
base_path = os.getcwd()

import io, json, zipfile, hashlib, shutil, urllib.request
from pathlib import Path

def _printer():
    # use global log() if you defined QUIET/log earlier; else print
    return log if 'log' in globals() else print

def _sha256(path):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()

def _download(url, dst_path):
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    _printer()(f"[models] downloading: {url}")
    urllib.request.urlretrieve(url, dst_path)

def _flatten_if_needed(target_dir, required_files):
    """
    If the extracted ZIP created a nested top-level folder (e.g., target_dir/diff_1/*),
    but we expect files directly under target_dir, move them up one level.
    """
    present = all(os.path.exists(os.path.join(target_dir, f)) for f in required_files)
    if present:
        return

    # look for a single subdir containing the stuff
    subdirs = [d for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))]
    if len(subdirs) == 1:
        candidate = os.path.join(target_dir, subdirs[0])
        # if moving would fix it, move contents up
        if all(os.path.exists(os.path.join(candidate, f)) for f in required_files):
            for name in os.listdir(candidate):
                shutil.move(os.path.join(candidate, name), os.path.join(target_dir, name))
            # remove now-empty subdir
            try:
                os.rmdir(candidate)
            except OSError:
                pass

def ensure_models(model_names, target_root="/content/AutoDS_models", model_manifest=None):
    """
    model_manifest schema (choose ONE per model):
      # ZIP asset per model (recommended)
      {
        "diff_1": {
          "zip_url": "<direct zip url>",
          "sha256":  "<optional sha256 of the zip>",
          "contains": ["best_weights.h5", "model_metadata.mat"]
        },
        ...
      }

      # Raw files (no zip)
      {
        "diff_1": {
          "file_urls": {
            "best_weights.h5": "<direct file url>",
            "model_metadata.mat": "<direct file url>"
          },
          "file_sha256": {             # optional, per-file
            "best_weights.h5": "<sha256>",
            "model_metadata.mat": "<sha256>"
          },
          "contains": ["best_weights.h5", "model_metadata.mat"]
        },
        ...
      }
    """
    if model_manifest is None:
        raise ValueError("ensure_models: model_manifest must be provided.")
    os.makedirs(target_root, exist_ok=True)

    for m in model_names:
        cfg = model_manifest[m]
        mdir = os.path.join(target_root, m)
        need_fetch = False

        # fast-path: check presence
        req = cfg.get("contains", [])
        if not os.path.isdir(mdir):
            need_fetch = True
        else:
            for f in req:
                if not os.path.exists(os.path.join(mdir, f)):
                    need_fetch = True
                    break

        if not need_fetch:
            _printer()(f"[models] found: {m}")
            continue

        _printer()(f"[models] preparing: {m}")
        os.makedirs(mdir, exist_ok=True)

        if "zip_url" in cfg:
            # ZIP flow
            zip_url = cfg["zip_url"]
            zip_path = os.path.join(target_root, f"{m}.zip")
            _download(zip_url, zip_path)

            if "sha256" in cfg:
                digest = _sha256(zip_path)
                if digest != cfg["sha256"]:
                    raise ValueError(f"SHA256 mismatch for {m} zip. expected {cfg['sha256']} got {digest}")

            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(mdir)
            os.remove(zip_path)

            # handle nested folder cases
            _flatten_if_needed(mdir, req)

        elif "file_urls" in cfg:
            # Per-file flow
            file_urls = cfg["file_urls"]
            file_sha = cfg.get("file_sha256", {})
            for fname, url in file_urls.items():
                dst = os.path.join(mdir, fname)
                _download(url, dst)
                if fname in file_sha:
                    digest = _sha256(dst)
                    if digest != file_sha[fname]:
                        raise ValueError(f"SHA256 mismatch for {m}/{fname}. expected {file_sha[fname]} got {digest}")
        else:
            raise ValueError(f"Model {m} manifest must have either 'zip_url' or 'file_urls'.")

        # Final presence check
        for f in req:
            if not os.path.exists(os.path.join(mdir, f)):
                raise FileNotFoundError(f"Model {m} missing required file after fetch: {f}")

        _printer()(f"[models] ready: {m}")

    return target_root

# --- Quiet/Preview flags ------------------------------------------------------
QUIET = False            # set True to suppress training/inference log output
HEADLESS_PREVIEW = True  # set True if you want to see the preview figures

def log(*args, **kwargs):
    if not QUIET:
        print(*args, **kwargs)

# Define where to fetch each model (raw files from the GitHub repository).
# Replace these URLs to use your own models; a per-model 'zip_url' entry is also supported.
model_names = ['diff_1', 'diff_2', 'diff_3', 'diff_4']
MODEL_MANIFEST = {
    "diff_1": {
        "file_urls": {
            "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/best_weights.h5",
            "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/model_metadata.mat",
        },
        "contains": ["best_weights.h5", "model_metadata.mat"]
    },
    "diff_2": {
        "file_urls": {
            "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/best_weights.h5",
            "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/model_metadata.mat",
        },
        "contains": ["best_weights.h5", "model_metadata.mat"]
    },
    "diff_3": {
        "file_urls": {
            "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/best_weights.h5",
            "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/model_metadata.mat",
        },
        "contains": ["best_weights.h5", "model_metadata.mat"]
    },
    "diff_4": {
        "file_urls": {
            "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/best_weights.h5",
            "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/model_metadata.mat",
        },
        "contains": ["best_weights.h5", "model_metadata.mat"]
    },
}
# Download (only if missing) and set prediction_model_path accordingly
prediction_model_path = ensure_models(model_names, target_root="/content/AutoDS_models", model_manifest=MODEL_MANIFEST)

def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):
  n_locs = xc_array.shape[0]
  xc_array_Corr = np.empty(n_locs)
  yc_array_Corr = np.empty(n_locs)

  for loc in range(n_locs):
    xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc] - 1]
    yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc] - 1]

  return (xc_array_Corr, yc_array_Corr)

def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):
  w = image_size[0]
  h = image_size[1]
  locImage = np.zeros((image_size[0],image_size[1]) )
  n_locs = len(xc_array)

  for e in range(n_locs):
    locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1

  return locImage

def estimate_drift_com_nm(img1, img2, pixel_size_nm, sigma=1.0, patch_radius=3):
    # Smooth images
    img1_smooth = gaussian_filter(img1.astype(np.float32), sigma=sigma)
    img2_smooth = gaussian_filter(img2.astype(np.float32), sigma=sigma)

    # Cross-correlation = convolution with the 180-degree-rotated second image
    corr = fftconvolve(img1_smooth, img2_smooth[::-1, ::-1], mode='same')

    # Define center of the image
    center_y, center_x = np.array(corr.shape) // 2

    # Define a crop region around the center
    y_min = max(0, center_y - patch_radius)
    y_max = min(corr.shape[0], center_y + patch_radius + 1)
    x_min = max(0, center_x - patch_radius)
    x_max = min(corr.shape[1], center_x + patch_radius + 1)

    # Crop around center
    patch = corr[y_min:y_max, x_min:x_max]

    # Find subpixel center of mass in cropped patch
    y_grid, x_grid = np.meshgrid(
        np.arange(y_min, y_max), np.arange(x_min, x_max), indexing='ij'
    )

    total = np.sum(patch)
    if total == 0:
        return 0.0, 0.0

    y_com = np.sum(patch * y_grid) / total
    x_com = np.sum(patch * x_grid) / total

    # Drift relative to center, in pixels
    dy_px = y_com - center_y
    dx_px = x_com - center_x

    # Limit the drift to patch_radius
    if abs(dy_px) > patch_radius or abs(dx_px) > patch_radius:
        return 0.0, 0.0  # or raise an exception

    # Convert to nanometers
    dy_nm = dy_px * pixel_size_nm
    dx_nm = dx_px * pixel_size_nm

    return dy_nm, dx_nm

def gaussian_interpolation_batch(data_batch, scale, sigma=1):
    """
    Applies Gaussian interpolation (smoothing and upsampling) to a batch of images.

    Parameters:
    - data_batch: A numpy array of shape (batch_size, height, width), where each entry is an image.
    - scale: The scaling factor for upsampling.
    - sigma: The standard deviation for the Gaussian filter (default is 1).

    Returns:
    - upsampled_data_batch: A numpy array containing the upsampled images.
    """
    upsampled_data_batch = []

    for data in data_batch:
        # Apply Gaussian filter to each image in the batch
        smoothed_data = gaussian_filter(data, sigma=sigma)

        # Upsample the smoothed image
        upsampled_data = zoom(smoothed_data, scale, order=3)  # Using cubic interpolation for smooth upsampling
        upsampled_data_batch.append(upsampled_data)

    # Convert the list of upsampled images back into a numpy array
    return np.array(upsampled_data_batch)

def interpolate_frames(tiff_stack, model_pixel_size, current_pixel_size,
                             model_wavelength, current_wavelength, model_NA, current_NA):
    # Compute ratio
    if model_pixel_size is None: model_pixel_size = current_pixel_size
    if model_wavelength is None: model_wavelength = current_wavelength
    if model_NA is None: model_NA = current_NA
    if current_wavelength is None: current_wavelength = model_wavelength = 1
    if current_NA is None: current_NA = model_NA = 1

    if len(tiff_stack.shape) == 2:
        tiff_stack = tiff_stack[None, :, :]

    scale_ratio_sq = (0.21 * model_wavelength / model_NA) ** 2 - (0.21 * current_wavelength / current_NA) ** 2
    if (scale_ratio_sq) > 0:
        scale_ratio = np.sqrt(scale_ratio_sq) / model_pixel_size
        interpolated_stack = np.stack([gaussian_filter(tiff_stack[i], scale_ratio) for i in range(tiff_stack.shape[0])])
    else:
        zoom_factors = (1, model_pixel_size / current_pixel_size, model_pixel_size / current_pixel_size)
        interpolated_stack = zoom(tiff_stack.astype(np.float32), zoom_factors, order=3)

    return interpolated_stack.astype(np.float32, copy=False)  # <-- ensure float32

def ChooseNetByDifficulty_2025(density, SNR):
    num_models = 4
    norm_density = np.max([np.min([int(np.round(2 * density)), num_models-1]), 0])
    norm_SNR = num_models - 1 - np.max([np.min([SNR//2, num_models - 1]), 0])
    return int(np.round((norm_SNR + norm_density) / 2))

def reconstruct_patches_2025(Images, patch_ind, frame_numbers, weights_file, num_patches, overlap, number_of_frames,
                             thresh=0.1, neighborhood_size=3, use_local_avg=False, upsampling_factor=8, pixel_size=None,
                             batch_size=1):
    pixel_size_hr = pixel_size / upsampling_factor

    # Convert Images to float32 Tensor and move to GPU
    Images = tf.convert_to_tensor(Images, dtype=tf.float32)
    if Images.ndim == 2:
        Images = tf.expand_dims(Images, axis=0)  # Ensure 3D shape
    K_frames, M, N = Images.shape

    # Determine dimensions of each predicted (cropped) patch.
    patch_height = M * upsampling_factor - 2 * overlap
    patch_width = N * upsampling_factor - 2 * overlap

    # Allocate the full reconstructed image (NumPy array on the host)
    reconstructed_image = np.zeros((patch_height * num_patches, patch_width * num_patches), dtype=np.float32)

    # Prepare lists for detections
    recon_xind, recon_yind, frame_index, confidence_list = [], [], [], []

    # Load the model on the GPU
    with tf.device('/GPU:0'):
        model = build_model_upsample((M, N, 1), lr=1e-3, upsampling_factor=upsampling_factor)
        model.load_weights(weights_file)

        # Create the post-processing layer
        max_layer = Maximafinder(thresh, neighborhood_size, use_local_avg)

        n_batches = int(np.ceil(K_frames / batch_size))
        for b in range(n_batches):
            start = b * batch_size
            end = min(K_frames, start + batch_size)
            nF = end - start

            # --- Select the current batch of patches ---
            batch_imgs = Images[start:end]  # Shape: (nF, M, N)

            # --- Run prediction (on the GPU when available) ---
            predicted_density = model(batch_imgs, training=False)
            # Subtract a 0.5 offset and clamp negative values to zero
            predicted_density = tf.nn.relu(predicted_density - 0.5).numpy()

            # Crop off the extra overlap; explicit end indices keep the full patch when overlap == 0
            H_hr, W_hr = predicted_density.shape[1], predicted_density.shape[2]
            cropped = predicted_density[:, overlap:H_hr - overlap, overlap:W_hr - overlap]
            cropped_pred = cropped[..., 0]

            # --- Post-processing: local-maxima detection ---
            bind, xind, yind, conf = max_layer(cropped)

            # Convert tensors to NumPy
            bind_np, xind_np, yind_np, conf_np = bind.numpy(), xind.numpy(), yind.numpy(), conf.numpy() / L2_weighting_factor

            # --- Place each patch in reconstructed image ---
            for i in range(nF):
                p_ind = patch_ind[start + i]
                y1 = patch_height * (p_ind // num_patches)
                x1 = patch_width * (p_ind % num_patches)

                # Accumulate the frame-averaged density into the full image (NumPy, on host)
                reconstructed_image[y1:y1 + patch_height, x1:x1 + patch_width] += (cropped_pred[i] / number_of_frames)

                # Collect detections
                det_idx = np.where(bind_np == i)[0]
                if det_idx.size:
                    recon_xind.extend((x1 + xind_np[det_idx]).tolist())
                    recon_yind.extend((y1 + yind_np[det_idx]).tolist())
                    frame_index.extend([frame_numbers[start + i] + 1] * det_idx.size)
                    confidence_list.extend(conf_np[det_idx].tolist())

    # Convert coordinates to physical units
    xind_final = (np.array(recon_xind) * pixel_size_hr).tolist()
    yind_final = (np.array(recon_yind) * pixel_size_hr).tolist()

    return reconstructed_image, [frame_index, xind_final, yind_final, confidence_list]


def split_image_to_patches(img, num_patches, overlap):
    # Determine the non-overlapping patch size.
    H, W = img.shape
    patch_h = H // num_patches
    patch_w = W // num_patches

    # Pad the image so that border patches have the proper overlap.
    padded_img = np.pad(img, ((overlap, overlap), (overlap, overlap)), mode='reflect')

    # Define the window (patch) shape including overlap.
    window_shape = (patch_h + 2 * overlap, patch_w + 2 * overlap)

    # Create a sliding window view of the padded image.
    # The sliding window view will have shape:
    # (padded_H - window_shape[0] + 1, padded_W - window_shape[1] + 1, window_shape[0], window_shape[1])
    patches_view = sliding_window_view(padded_img, window_shape)

    # Sample patches at strides equal to the basic patch size.
    patches_array = patches_view[0::patch_h, 0::patch_w, :, :]

    # Flatten the 2D grid of patches (row-major order) into a list.
    num_rows, num_cols, ph, pw = patches_array.shape
    patches_list = [patches_array[i, j].copy() for i in range(num_rows) for j in range(num_cols)]

    return patches_list

def gauss2d(xy, offset, amp, x0, y0, sigma):
  x, y = xy
  return offset + (amp * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2) - ((y - y0) ** 2) / (2 * sigma ** 2)))

def extract_all_features(Images, FOV_size, pixel_size):
    M, N = FOV_size
    patch_size = 7
    xy = np.zeros([2, int(patch_size ** 2)])
    for i1 in range(patch_size):
        for j1 in range(patch_size):
            xy[:, int(i1 + patch_size * j1)] = [i1, j1]

    if(len(Images.shape) == 2):
        Images = Images[None, :, :]

    peaks_first_frame = peak_local_max(Images[0],
                                       min_distance=patch_size // 2,
                                       threshold_abs=np.mean(Images[0]) + np.std(Images[0]))
    peaks_for_analysis = []
    cnt = 0
    for i in range(len(peaks_first_frame)):
        # Keep only isolated peaks (no other peak within an L1 distance < 2)
        if (np.sum(np.abs(peaks_first_frame[:, 0] - peaks_first_frame[i, 0]) +
                   np.abs(peaks_first_frame[:, 1] - peaks_first_frame[i, 1]) < 2) == 1):
            peaks_for_analysis.append([peaks_first_frame[i, 0], peaks_first_frame[i, 1]])
            cnt += 1
            if (cnt > 100):
                break

    peaks_for_analysis = np.array(peaks_for_analysis)
    number_of_PSFs_to_fit = np.min([100, peaks_for_analysis.shape[0]])

    sigmas_list = []
    gaussian_amp_list = []
    for i in range(number_of_PSFs_to_fit):
        down = np.max([0, peaks_for_analysis[i, 0] - patch_size // 2])
        up = np.min([M - 1, peaks_for_analysis[i, 0] + patch_size // 2])
        left = np.max([0, peaks_for_analysis[i, 1] - patch_size // 2])
        right = np.min([N - 1, peaks_for_analysis[i, 1] + patch_size // 2])
        zobs = (Images[0][down:up + 1, left:right + 1]).reshape(1, -1).squeeze()
        try:
            guess = [np.median(zobs), np.median(zobs), patch_size // 2, patch_size // 2, 1]
            bounds = ([0, 0, 0, 0, 0.5], [np.inf, np.inf, patch_size, patch_size, patch_size // 2])
            pred_params, uncert_cov = opt.curve_fit(gauss2d, xy, zobs, p0=guess, bounds=bounds)
        except Exception as e:
            continue
        fit = gauss2d(xy, *pred_params)
        if (1 - np.sqrt(np.mean((zobs / np.max(zobs) - fit / np.max(fit)) ** 2)) < 0.9):
            continue
        sigmas_list.append(pred_params[4])
        gaussian_amp_list.append(pred_params[1])

    if(len(sigmas_list) < 1):
        log("Did not find emitters for sigma estimation! setting sigma to 1 pixel")
        sigma = 1
        sigma_std = 0
    else:
        sigma = np.mean(sigmas_list)
        sigma_std = np.std(sigmas_list)

    mean_noise_list = []
    std_noise_list = []
    emitter_density_list = []
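    for i in range(np.min([Images.shape[0], 100])):
        # extract_features_frame requires psf_sigma; pass the PSF width (pixels) estimated above
        curr_mean_noise, curr_std_noise, signal_amp, curr_emitter_density = extract_features_frame(Images[i],
                                                                                                   pixel_size,
                                                                                                   psf_sigma=sigma,
                                                                                                   verbose=False)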
        mean_noise_list.append(curr_mean_noise)
        std_noise_list.append(curr_std_noise)
        emitter_density_list.append(curr_emitter_density)

    ADC_offset = np.mean(mean_noise_list)
    ReadOutNoise_ADC = np.mean(std_noise_list)
    gaussian_amp_mean = np.mean(gaussian_amp_list)
    gaussian_amp_std = np.std(gaussian_amp_list)
    emitter_density = np.mean(emitter_density_list)

    return ADC_offset, ReadOutNoise_ADC, gaussian_amp_mean, gaussian_amp_std, \
           emitter_density, sigma, sigma_std

def remove_zero_padding(image):
    image_array = np.array(image)
    non_zero_rows = np.where(image_array.sum(axis=1) != 0)
    non_zero_cols = np.where(image_array.sum(axis=0) != 0)
    cropped_image = image_array[non_zero_rows[0][0]:non_zero_rows[0][-1]+1, non_zero_cols[0][0]:non_zero_cols[0][-1]+1]
    return cropped_image

def subtract_smooth_background(im, sigma=3):
    return im - gaussian_filter(im, sigma)

def subtract_background_tophat(im, radius=15):
    return white_tophat(im, footprint=disk(radius))

def extract_features_frame(OrigImage, pixel_size, psf_sigma, offset=None, verbose=False):
    M, N = OrigImage.shape

    Image = OrigImage - gaussian_filter(OrigImage, sigma=5)

    if(offset is not None):
        if(np.percentile(gaussian_filter(Image, 2), 99) < 2 * Image.mean() or np.percentile(OrigImage, 99) < 2 * offset):
            if(verbose):
                plt.figure(figsize=(7, 7))
                plt.title("SNR is too low - ignoring patch")
                plt.imshow(OrigImage)
                plt.show()
            return np.mean(OrigImage), np.std(OrigImage), 0, 0

    log_image = -gaussian_laplace(Image, sigma=psf_sigma)  # negative = blob-like peaks

    # Local maxima filtering
    neighborhood_size = 3
    local_max = (log_image == maximum_filter(log_image, size=neighborhood_size))

    # Compute the threshold
    amp_threshold = np.mean(Image) + 0.5 * (np.percentile(Image, 99) - np.mean(Image))

    # Apply intensity threshold
    pcntl_threshold = np.percentile(Image, 85)
    binary_mask = np.logical_and(local_max, Image > np.max([amp_threshold, pcntl_threshold]))

    dilated_mask = binary_dilation(binary_mask, structure=np.ones((5, 5)))
    noise_mask = np.ones_like(binary_mask)
    noise_mask[dilated_mask] = 0

    if(np.sum(binary_mask) > 0):
        ADC_offset = np.mean(OrigImage[noise_mask])
        ReadOutNoise_ADC = np.std(OrigImage[noise_mask])
        Signal_amp = np.mean(OrigImage[binary_mask == 1])
        emitter_density = (10 ** 6) * float(np.sum(binary_mask)) / (M * N * pixel_size ** 2)
    else:
        if(verbose):
            log("Didn't find any emitters")
        return np.mean(OrigImage), np.std(OrigImage), 0, 0

    if(Signal_amp / ADC_offset < 2.5):
        if(emitter_density > 2):
            if(verbose):
                plt.figure(figsize=(8, 8))
                plt.title("SNR is too low for emitter density estimation")
                plt.imshow(OrigImage)
                plt.show()

            return ADC_offset, ReadOutNoise_ADC, Signal_amp, 0

    if(verbose):
        plt.figure(figsize=(10, 5))
        plt.subplot(131)
        plt.imshow(OrigImage)
        plt.title("Offset = {}".format(offset))
        plt.subplot(132)
        plt.imshow(binary_mask)
        plt.title("signal mask - emitter density {:.3f}".format(emitter_density))
        plt.subplot(133)
        plt.imshow(noise_mask)
        plt.title("noise mask - SNR pred {:.3f}".format(Signal_amp / ADC_offset))
        plt.show()

    return ADC_offset, ReadOutNoise_ADC, Signal_amp, emitter_density

def project_01(im):
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val)/(max_val - min_val)

def project_01_ret_vals(im):
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val)/(max_val - min_val), min_val, max_val

def normalize_im(im, dmean, dstd):
    im = np.squeeze(im)
    return (im - dmean)/dstd

def conv_bn_relu(nb_filter, rk, ck, name):
    def f(input_tensor):
        conv = Conv2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),
                      padding="same", use_bias=False,
                      kernel_initializer="Orthogonal", name='conv-'+name)(input_tensor)
        conv_norm = BatchNormalization(name='BN-'+name)(conv)
        conv_norm_relu = Activation("relu", name='Relu-'+name)(conv_norm)
        return conv_norm_relu
    return f

def CNN(input_tensor, names):
    Features1 = conv_bn_relu(32,3,3,names+'F1')(input_tensor)
    pool1 = MaxPooling2D(pool_size=(2,2), name=names+'Pool1')(Features1)
    Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2), name=names+'Pool2')(Features2)
    Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2), name=names+'Pool3')(Features3)
    Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)
    up5 = UpSampling2D(size=(2, 2), name=names+'Upsample1')(Features4)
    Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)
    up6 = UpSampling2D(size=(2, 2), name=names+'Upsample2')(Features5)
    Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)
    up7 = UpSampling2D(size=(2, 2), name=names+'Upsample3')(Features6)
    Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)
    return Features7

def buildModel(input_dim, initial_learning_rate=0.001):
    input_ = Input(shape=input_dim)
    act_ = CNN(input_, 'CNN')
    density_pred = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding="same",
                           activation="linear", use_bias=False,
                           kernel_initializer="Orthogonal", name='Prediction')(act_)
    model = Model(inputs=input_, outputs=density_pred)
    opt = Adam(learning_rate=initial_learning_rate)
    model.compile(optimizer=opt, loss=L1L2loss(input_dim))
    return model

def CNN_upsample(input, upsampling_factor):
    # Encoder
    x = Conv2D(32, (3, 3), padding='same', name="F1")(input)
    x = BatchNormalization(name="BN_1")(x)
    x = Activation('relu',name="ReLU_1")(x)

    x = Conv2D(64, (3, 3), padding='same', name="F2")(x)
    x = BatchNormalization(name="BN_2")(x)
    x = Activation('relu', name="ReLU_2")(x)

    x = Conv2D(128, (3, 3), padding='same', name="F3")(x)
    x = BatchNormalization(name="BN_3")(x)
    x = Activation('relu', name="ReLU_3")(x)

    x = Conv2D(256, (3, 3), padding='same', name="F4")(x)
    x = BatchNormalization(name="BN_4")(x)
    x = Activation('relu', name="ReLU_4")(x)

    # Decoder
    x = Conv2D(128, (3, 3), padding='same', name="F5")(x)
    x = BatchNormalization(name="BN_5")(x)
    x = Activation('relu', name="ReLU_5")(x)

    x = Conv2D(64, (3, 3), padding='same', name="F6")(x)
    x = BatchNormalization(name="BN_6")(x)
    x = Activation('relu', name="ReLU_6")(x)

    for ind in range(int(np.log2(upsampling_factor))):   # one 2x upsampling stage per power of two
        x = UpSampling2D(size=(2, 2), interpolation='bilinear', name="upsample_{}".format(ind+1))(x)
        x = Conv2D(32, (5, 5), padding='same', name="conv_upsample{}".format(ind+1))(x)
        x = BatchNormalization(name="BN_upsample{}".format(ind+1))(x)
        x = Activation('relu', name="ReLU_upsample{}".format(ind+1))(x)

    return x

def build_model_upsample(input_shape, lr=0.001, upsampling_factor=2):
    input_ = Input(shape=input_shape)
    act_ = CNN_upsample(input_, upsampling_factor)
    density_pred = Conv2D(1, kernel_size=(1, 1), strides=(1, 1), padding="same",
                                  activation="linear", use_bias = False,
                                  kernel_initializer="Orthogonal",name='Prediction')(act_)
    model = Model(inputs= input_, outputs=density_pred)
    opt = Adam(learning_rate=lr)
    model.compile(optimizer=opt, loss = custom_loss(input_shape))
    return model

def custom_loss(input_shape):
    def loss_fn(y_true, y_pred):
        heatmap_pred = tf.nn.conv2d(y_pred, gfilter, strides=1, padding='SAME')
        loss_heatmaps = tf.reduce_mean(tf.square(y_true - heatmap_pred))
        loss_spikes = tf.reduce_mean(tf.abs(y_pred))
        return loss_heatmaps + loss_spikes

    return loss_fn

def matlab_style_gauss2D(shape=(7,7), sigma=1):
    m, n = [(ss-1.)/2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]
    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
    h = h.astype(K.floatx())
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    h = h * 2.0
    h = h.astype('float32')
    return h

# Expand the filter dimensions
psf_heatmap = matlab_style_gauss2D(shape=(7,7), sigma=1)
gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])

# Combined MSE + L1 loss
def L1L2loss(input_shape):
    def bump_mse(heatmap_true, spikes_pred):
        heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')
        loss_heatmaps = losses.mean_squared_error(heatmap_true, heatmap_pred)
        loss_spikes = losses.mean_absolute_error(spikes_pred, tf.zeros(input_shape))
        return loss_heatmaps + loss_spikes
    return bump_mse

def getPixelSizeTIFFmetadata(TIFFpath, display=False):
  with Image.open(TIFFpath) as img:
    meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}

  ResolutionUnit = meta_dict['ResolutionUnit'][0]
  width = meta_dict['ImageWidth'][0]
  height = meta_dict['ImageLength'][0]
  xResolution = meta_dict['XResolution'][0]
  if len(xResolution) == 1:
    xResolution = xResolution[0]
  elif len(xResolution) == 2:
    xResolution = xResolution[0]/xResolution[1]
  else:
    log('Image resolution not defined.')
    xResolution = 1

  if ResolutionUnit == 2:
    pixel_size = 0.0254*1e9/xResolution   # 1 inch = 0.0254 m
  elif ResolutionUnit == 3:
    pixel_size = 0.01*1e9/xResolution
  else:
    log('Resolution unit not defined. Assuming: um')
    pixel_size = 1e3/xResolution

  if display:
    log('Pixel size obtained from metadata: '+str(pixel_size)+' nm')
    log('Image size: '+str(width)+'x'+str(height))

  return (pixel_size, width, height)

def saveAsTIF(path, filename, array, pixel_size):
  if (array.dtype == np.uint16):
    mode = 'I;16'
  elif (array.dtype == np.uint32):
    mode = 'I'
  else:
    mode = 'F'

  if len(array.shape) == 2:
    im = Image.fromarray(array)
    im.save(os.path.join(path, filename+'.tif'),
                  mode=mode,
                  resolution_unit=3,
                  resolution=0.01*1e9/pixel_size)
  elif len(array.shape) == 3:
    imlist = []
    for frame in array:
      imlist.append(Image.fromarray(frame))
    imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,
                  append_images=imlist[1:],
                  mode=mode,
                  resolution_unit=3,
                  resolution=0.01*1e9/pixel_size)
  return

class Maximafinder(Layer):
    def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):
        super(Maximafinder, self).__init__(**kwargs)
        self.thresh = tf.constant(thresh, dtype=tf.float32)
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

    def build(self, input_shape):
        if self.use_local_avg:
          self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])
          self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])
          self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])

    def call(self, inputs):
        max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)
        cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)
        indices = tf.where(cond)
        bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]
        confidence = tf.gather_nd(inputs, indices)

        if self.use_local_avg:
          x_image = K.conv2d(inputs, self.kernel_x, padding='same')
          y_image = K.conv2d(inputs, self.kernel_y, padding='same')
          sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')
          confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)
          x_local = tf.math.divide(tf.gather_nd(x_image, indices), tf.gather_nd(sum_image, indices))
          y_local = tf.math.divide(tf.gather_nd(y_image, indices), tf.gather_nd(sum_image, indices))
          xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)
          yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)
        else:
          xind = tf.cast(xind, dtype=tf.float32)
          yind = tf.cast(yind, dtype=tf.float32)

        return bind, xind, yind, confidence

    def get_config(self):
        base_config = super(Maximafinder, self).get_config()
        config = {}
        return dict(list(base_config.items()) + list(config.items()))

def iter_tiff_frames(path):
    """Yield the frames of a TIFF stack one at a time as float32 (H, W) arrays.
    Use count_tiff_frames() to get the total number of frames."""
    with tiff.TiffFile(path) as tif:
        for page in tif.pages:
            yield page.asarray().astype(np.float32)

def count_tiff_frames(path):
    with tiff.TiffFile(path) as tif:
        return len(tif.pages)

# --- ND2 helpers: detection, frame count, and per-frame iterator ---

def is_nd2(path: str) -> bool:
    try:
        import nd2
        return nd2.is_supported_file(path)
    except Exception:
        return path.lower().endswith(".nd2")


def count_nd2_frames(path: str) -> int:
    import nd2
    with nd2.ND2File(path) as f:
        # Preferred: number of sequence frames in the file
        try:
            return len(f.loop_indices)            # robust across T/Z/C acquisitions
        except Exception:
            # Fallback: product of looped axes
            sz = getattr(f, "sizes", {}) or {}
            prod = 1
            for ax in ("T", "Z", "C", "V"):       # V = positions/fields
                prod *= int(sz.get(ax, 1))
            return prod


def _nd2_to_2d(arr, channel=None):
    """
    Convert an ND2 frame (which may be 2D or 3D) to a 2D grayscale image.
    - If RGB (Y,X,3/4) -> take channel 0 (or specified)
    - If channels-first (C,Y,X) -> take channel 0 (or specified)
    - Otherwise squeeze or mean as a last resort.
    """
    import numpy as np
    a = np.asarray(arr)
    if a.ndim == 2:
        return a
    if a.ndim == 3:
        # (Y, X, C) RGB or multi-channel last
        if a.shape[-1] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[-1]) else 0
            return a[..., idx]
        # (C, Y, X) channels-first
        if a.shape[0] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[0]) else 0
            return a[idx, ...]
        # Unknown 3D (e.g., tiny Z), collapse politely
        return a.mean(axis=0)
    # Any other shape: squeeze to best-effort 2D
    a = a.squeeze()
    return a if a.ndim == 2 else a.reshape(a.shape[-2], a.shape[-1])

def iter_nd2_frames(path: str, channel: int | None = None):
    import nd2, numpy as np
    n = count_nd2_frames(path)
    with nd2.ND2File(path) as f:
        for i in range(n):
            fr = f.read_frame(i)
            fr2d = _nd2_to_2d(fr, channel=channel)
            yield fr2d.astype(np.float32, copy=False)   # <-- ensure float32

def getPixelSizeND2metadata(path, display=False):
    import nd2
    with nd2.ND2File(path) as f:
        try:
            vox_um = f.voxel_size()               # VoxelSize(x, y, z) in microns
        except Exception:
            return None, None, None
        px_nm = vox_um.x * 1e3                    # x pixel size in nm
        # Width/height from the array shape: (Y, X) are the last two axes
        try:
            h, w = f.shape[-2], f.shape[-1]
        except Exception:
            h = w = None
        if display:
            print(f"Pixel size (ND2): {px_nm:.2f} nm | image ~ {w}x{h}")
        return px_nm, w, h

def list_files_multi(directory, extensions):
    exts = {('.' + e.lower()) for e in extensions}
    for f in os.listdir(directory):
        if os.path.splitext(f)[1].lower() in exts:
            yield f

def list_files(directory, extension):
  return (f for f in os.listdir(directory) if f.endswith('.' + extension))

log('--------------------------------')
log('AutoDS installation complete.')

if tf.test.gpu_device_name() == '':
  log('You do not have GPU access.')
  log('Did you change your runtime?')
  log('If the runtime settings are correct then GPU might not be allocated to your session.')
  log('Expect slow performance. To access GPU try reconnecting later.')
else:
  log('You have GPU access')

log('Tensorflow version is ' + str(tf.__version__))

MAX_FILE_GB = 5.0  # warn & skip when file is larger than this

def _is_oom(exc: BaseException) -> bool:
    msg = (str(exc) or "").upper()
    return (
        isinstance(exc, tf.errors.ResourceExhaustedError) or
        isinstance(exc, MemoryError) or
        "OUT OF MEMORY" in msg or "OOM" in msg or
        exc.__class__.__name__ in {"_ArrayMemoryError"}
    )

@contextmanager
def catch_oom(phase: str, detail: str = "", on_oom="continue"):
    """
    Wrap any memory-heavy block. Prints a friendly message on OOM and continues.
    on_oom: "continue" (default) just prints and returns; any other value re-raises.
    """
    try:
        yield
    except Exception as e:
        if _is_oom(e):
            print(f"\n⚠️  OOM while {phase}{(' - ' + detail) if detail else ''}.")
            print("   Tip: reduce chunk_size/batch_size/upsampling, or downsample input.")
            if isinstance(e, tf.errors.ResourceExhaustedError) and getattr(e, "message", None):
                print("   TensorFlow says:", e.message.splitlines()[0][:200])
            else:
                traceback.print_exc(limit=1, file=sys.stdout)
            if on_oom != "continue":
                raise
        else:
            # Non-OOM: re-raise so real bugs are visible
            raise

verbose = False
# ------------------------------- User input -------------------------------
Data_folder = ""  #@param {type:"string"}
Result_folder = ""  #@param {type:"string"}

threshold = 10 #@param {type:"number"}
neighborhood_size = 3 #@param {type:"integer"}
use_local_average = False #@param {type:"boolean"}

num_patches = 4 #@param {type:"integer"}
overlap = 4
batch_size = 1 #@param {type:"integer"}

interpolate_based_on_imaging_parameters = True #@param {type:"boolean"}
get_pixel_size_from_file = False #@param {type:"boolean"}
pixel_size = 107 #@param {type:"number"}
wavelength = 715 #@param {type:"number"}
numerical_aperture = 1.49 #@param {type:"number"}

chunk_size = 10000 #@param {type:"integer"}

psf_sigma_nm = 0.21 * wavelength / numerical_aperture
psf_sigma_pixels = psf_sigma_nm / pixel_size

if get_pixel_size_from_file:
  pixel_size = None

matfile = sio.loadmat(os.path.join(prediction_model_path, model_names[0], 'model_metadata.mat'))
def _mat_scalar(key):
    """Return a scalar entry of the model metadata, or None if it is missing."""
    try:
        return np.array(matfile[key].item())
    except Exception:
        return None

model_wavelength = _mat_scalar('wavelength')
model_NA = _mat_scalar('numerical_aperture')
model_pixel_size = _mat_scalar('pixel_size')

if os.path.isdir(Data_folder):
    # iterate both TIFF and ND2
    for filename in list_files_multi(Data_folder, extensions=['tif','tiff','nd2']):

        # Install nd2 reader only when needed (optional)
        if filename.lower().endswith('.nd2'):
            try:
                import nd2  # already installed?
            except Exception:
                # IPython-only magic (Colab/Jupyter); outside a notebook run `pip install nd2` instead
                get_ipython().system('pip install -q nd2')

        in_path = os.path.join(Data_folder, filename)

        # --------- file size guard ----------
        try:
            file_size_gb = os.path.getsize(in_path) / 1e9
            if file_size_gb > MAX_FILE_GB:
                print(f"\n⚠️  {filename}: {file_size_gb:.2f} GB > {MAX_FILE_GB:.2f} GB.")
                print("   Video size is too big, please use Google Colab Pro or run locally.")
                continue
        except Exception:
            pass

        # --- Resolve pixel size if requested ---
        if get_pixel_size_from_file:
            if is_tiff(in_path):
                with catch_oom("reading TIFF pixel size", filename):
                    pixel_size, _, _ = getPixelSizeTIFFmetadata(in_path, True)
            elif is_nd2(in_path):
                with catch_oom("reading ND2 pixel size", filename):
                    px_nm, _, _ = getPixelSizeND2metadata(in_path, True)
                    pixel_size = px_nm if px_nm is not None else pixel_size  # leave unchanged if unknown
            # psf_sigma_pixels was computed from the form value; update it for this file's pixel size
            if pixel_size is not None:
                psf_sigma_pixels = psf_sigma_nm / pixel_size

        # --- Common model params ---
        upsampling_factor = np.array(matfile['upsampling_factor']).item()
        try:
            L2_weighting_factor = np.array(matfile['Normalization factor']).item()
        except Exception:
            L2_weighting_factor = 100

        patches_list = [[] for _ in model_names]
        patch_indices_list = [[] for _ in model_names]
        frame_numbers = [[] for _ in model_names]

        # --- Choose reader & frame count ---
        number_of_frames, frame_iter = None, None
        with catch_oom("opening stack", filename):
            if is_tiff(in_path):
                number_of_frames = count_tiff_frames(in_path)
                frame_iter = iter_tiff_frames(in_path)
                log(f'Loaded tiff stack with {number_of_frames} frames')
            elif is_nd2(in_path):
                number_of_frames = count_nd2_frames(in_path)
                frame_iter = iter_nd2_frames(in_path)
                log(f'Loaded ND2 stack with ~{number_of_frames} planes (T*Z*C)')
            else:
                log(f"Skipping unsupported file: {filename}")

        if frame_iter is None:
            print(f"⚠️  Skipping {filename} due to earlier error.")
            continue

        # Initialize sum image lazily
        sum_image = None

        log('Splitting stack to patches and selecting Deep-STORM model')
        for i, frame in enumerate(tqdm(frame_iter, total=number_of_frames if number_of_frames else None)):

            if sum_image is None:
                with catch_oom("allocating preview buffer", f"{filename} sum_image"):
                    sum_image = np.zeros_like(frame, dtype=np.float32)
                # If the allocation OOM'd, catch_oom swallowed it and sum_image is still None
                if sum_image is None:
                    print(f"⚠️  Skipping {filename}: failed to allocate preview buffer.")
                    break

            # Keep a running average for preview
            with catch_oom("accumulating preview", f"{filename} frame {i}"):
                sum_image += frame.astype(np.float32) / max(1, number_of_frames or 1)

            # Interpolate to match the trained model's effective resolution (OOM-prone)
            frame_i = None   # reset so a failed interpolation is not masked by the previous frame
            with catch_oom("interpolating frame", f"{filename} frame {i}"):
                if interpolate_based_on_imaging_parameters:
                    frame_i = interpolate_frames(
                        frame,
                        model_pixel_size, pixel_size,
                        model_wavelength, wavelength,
                        model_NA, numerical_aperture
                    )[0]
                else:
                    frame_i = frame

            if frame_i is None:
                # interpolation OOM'd; skip this frame
                continue

            M, N = frame_i.shape

            # Background subtraction + standardization (dtype-safe)
            fproc = np.asarray(frame_i, dtype=np.float32)   # <-- ensure float32
            # Python floats keep the array float32 under both NumPy 1.x and 2.x promotion rules
            p35   = float(np.percentile(fproc, 35))
            fproc = fproc - p35
            fproc = fproc - fproc.min()

            # float64 accumulators for mean/std for numerical stability; the array itself stays float32
            fmean = float(fproc.mean(dtype=np.float64))
            fstd  = float(fproc.std(dtype=np.float64)) + 1e-6
            fproc = (fproc - fmean) / fstd

            # Split into patches (may allocate moderately)
            patches = None   # reset so a failed split is not masked by the previous frame
            with catch_oom("splitting into patches", f"{filename} frame {i}"):
                patches = split_image_to_patches(fproc, num_patches, overlap)
            if patches is None:
                continue

            offset = fproc.mean()

            # Per-patch difficulty selection
            for m in range(num_patches):
                for n in range(num_patches):
                    down  = overlap if m == 0 else 0
                    up    = (M // num_patches) - overlap if m == num_patches - 1 else (M // num_patches)
                    left  = overlap if n == 0 else 0
                    right = (N // num_patches) - overlap if n == num_patches - 1 else (N // num_patches)

                    outputs = None   # reset so a failed extraction is not masked by the previous patch
                    with catch_oom("extracting features", f"{filename} frame {i} patch ({m},{n})"):
                        outputs = extract_features_frame(
                            patches[m*num_patches+n][down:up, left:right],
                            pixel_size,
                            psf_sigma_pixels,
                            offset=offset,
                            verbose=verbose
                        )
                    if outputs is None:
                        continue

                    curr_mean_noise, curr_std_noise, signal_amp, curr_emitter_density = outputs

                    if (signal_amp == 0 or curr_mean_noise == 0):
                        continue
                    if any(np.isnan(v) for v in (signal_amp, curr_mean_noise, curr_std_noise, curr_emitter_density)):
                        continue

                    difficulty_choice = ChooseNetByDifficulty_2025(curr_emitter_density, signal_amp/curr_mean_noise)
                    patches_list[difficulty_choice].append(patches[m*num_patches+n])
                    patch_indices_list[difficulty_choice].append(m*num_patches+n)
                    frame_numbers[difficulty_choice].append(i)

        # --------------------------------------------------------------------------
        # Model-selection summary and per-model reconstruction
        # --------------------------------------------------------------------------
        selected_model_hist = np.array([len(p) for p in patches_list], dtype=float)
        if HEADLESS_PREVIEW:
            with catch_oom("plotting model histogram", filename):
                plt.figure(figsize=(10, 6))
                plt.bar(np.arange(len(model_names)), selected_model_hist, width=0.8)
                plt.xticks(np.arange(len(model_names)), model_names)
                plt.xlabel('selected model')
                plt.ylabel('number of patches to be analyzed by this model')
                plt.show()

        # If nothing was collected, skip reconstruction gracefully
        if sum(len(p) for p in patches_list) == 0:
            print(f"ℹ️  No usable patches for {filename}; skipping reconstruction.")
            continue

        # M,N from last frame_i; safe because we had at least one patch
        patchwise_recon = np.zeros([M * upsampling_factor, N * upsampling_factor], dtype=np.float32)
        frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []

        log('Analyzing patches for each model')
        for model_num, model_name in enumerate(model_names):
            model_dir = os.path.join(prediction_model_path, model_name)
            if os.path.exists(model_dir):
                log(f"The {os.path.basename(model_dir)} model will be used.")
            else:
                log('!! WARNING: The chosen model does not exist !!')
                log('Please make sure you provide a valid model path before proceeding further.')

            if use_local_average:
                log('Using local averaging')

            if not os.path.exists(Result_folder):
                os.makedirs(Result_folder, exist_ok=True)
                log('Result folder was created.')

            if patches_list[model_num]:
                log("Reconstructing in chunks")
                total_chunks = -(-len(patches_list[model_num]) // chunk_size)   # ceiling division: no empty trailing chunk
                for chunk_num in tqdm(range(total_chunks)):
                    chunk_start = chunk_num * chunk_size
                    chunk_end = min((chunk_num + 1) * chunk_size, len(patches_list[model_num]))

                    with catch_oom("reconstructing chunk",
                                   detail=f"{model_name} [{chunk_start}:{chunk_end}] of {len(patches_list[model_num])}"):
                        pw_recon, loc_list = reconstruct_patches_2025(
                            patches_list[model_num][chunk_start:chunk_end],
                            patch_indices_list[model_num][chunk_start:chunk_end],
                            frame_numbers[model_num][chunk_start:chunk_end],
                            os.path.join(prediction_model_path, model_names[model_num], 'best_weights.h5'),
                            num_patches,
                            overlap * upsampling_factor,
                            number_of_frames,
                            threshold,
                            neighborhood_size=neighborhood_size,
                            use_local_avg=use_local_average,
                            upsampling_factor=upsampling_factor,
                            pixel_size=pixel_size,
                            batch_size=batch_size
                        )

                        # Only reached if the call above succeeded: an OOM exits this with-block early
                        if pw_recon is not None:
                            frame_number_list += loc_list[0]
                            x_nm_list += loc_list[1]
                            y_nm_list += loc_list[2]
                            confidence_au_list += loc_list[3]

                            patchwise_recon[:M//num_patches*upsampling_factor*num_patches,
                                            :N//num_patches*upsampling_factor*num_patches] += pw_recon

        ext = '_avg' if use_local_average else '_max'
        base = os.path.splitext(filename)[0]

        # Save outputs (guarded)
        with catch_oom("saving outputs", base):
            with open(os.path.join(Result_folder, f'Localizations_{base}{ext}.csv'), "w", newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])
                sort_ind = np.argsort(frame_number_list)
                locs = list(zip(list(np.array(frame_number_list)[sort_ind]),
                                list(np.array(x_nm_list)[sort_ind]),
                                list(np.array(y_nm_list)[sort_ind]),
                                list(np.array(confidence_au_list)[sort_ind])))
                writer.writerows(locs)

            pw_recon_tif = np.copy(patchwise_recon)
            cap = np.percentile(pw_recon_tif, 99.5)
            pw_recon_tif[pw_recon_tif > cap] = cap
            saveAsTIF(Result_folder, f'Predicted_patchwise_{base}', pw_recon_tif, pixel_size/upsampling_factor)

        log('--------------------------------------------------------------------')
        log('---------------------------- Previews ------------------------------')
        log('--------------------------------------------------------------------')
        if HEADLESS_PREVIEW:
            with catch_oom("plotting previews", filename):
                fig, axes = plt.subplots(1, 3, figsize=(20,16))
                axes[0].axis('off'); axes[0].imshow(sum_image); axes[0].set_title('Original', fontsize=15)
                axes[1].axis('off'); axes[1].imshow(patchwise_recon); axes[1].set_title('Prediction', fontsize=15)
                axes[2].axis('off'); axes[2].imshow(np.clip(patchwise_recon,
                                                            np.percentile(patchwise_recon, 1),
                                                            np.percentile(patchwise_recon, 99)))
                axes[2].set_title('Normalized Prediction', fontsize=15)
                plt.show()
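
<font size = 4>`CNN_upsample` builds its decoder from `int(np.log2(upsampling_factor))` stages, each doubling height and width, so the effective scale is the largest power of two not exceeding `upsampling_factor`. The minimal sketch below reproduces only that arithmetic (the helper name `effective_upsampling` is hypothetical); it is a reading of the loop above, not a documented constraint, but it suggests that a factor such as 6 would silently produce 4× output.

```python
import numpy as np

def effective_upsampling(upsampling_factor: int) -> int:
    """Total scale produced by CNN_upsample's decoder loop (hypothetical helper)."""
    # Same stage count as the loop in CNN_upsample: one 2x UpSampling2D per stage
    num_stages = int(np.log2(upsampling_factor))
    return 2 ** num_stages

for factor in (2, 4, 6, 8):
    print(factor, "->", effective_upsampling(factor))
```

<font size = 4>Power-of-two factors (2, 4, 8) map exactly; the reconstruction buffers in the prediction loop are sized with `upsampling_factor` itself, so other values would likely cause a shape mismatch.
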
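
<font size = 4>Both `L1L2loss` and `custom_loss` blur the predicted spike map with the fixed Gaussian `gfilter`, compare it to the target heatmap with a mean squared error, and add the mean absolute value of the prediction as a sparsity (L1) term. The NumPy sketch below re-implements that idea for a single image; it is not the Keras code path, and it omits the eps-clipping in `matlab_style_gauss2D`. Because the Gaussian is symmetric, correlation and convolution coincide, so for one image the value should match the Keras loss up to floating-point rounding.

```python
import numpy as np

def gaussian_kernel(size=7, sigma=1.0):
    """Normalized Gaussian scaled to sum to 2, like matlab_style_gauss2D (no eps-clipping)."""
    r = (size - 1) / 2.0
    y, x = np.ogrid[-r:r + 1, -r:r + 1]
    h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
    return 2.0 * h / h.sum()

def correlate_same(image, kernel):
    """Zero-padded 'same' correlation, as conv2d(..., padding='same') does for odd kernels."""
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros(image.shape, dtype=np.float64)
    for i in range(kh):
        for j in range(kw):
            out += kernel[i, j] * padded[i:i + image.shape[0], j:j + image.shape[1]]
    return out

def bump_mse(heatmap_true, spikes_pred, kernel):
    """MSE between target heatmap and blurred prediction, plus an L1 penalty on the spikes."""
    heatmap_pred = correlate_same(spikes_pred, kernel)
    return np.mean((heatmap_true - heatmap_pred) ** 2) + np.mean(np.abs(spikes_pred))

kernel = gaussian_kernel()
spikes = np.zeros((16, 16))
spikes[8, 8] = 1.0
target = correlate_same(spikes, kernel)     # a perfect prediction of this target
print(bump_mse(target, spikes, kernel))     # only the L1 term remains: 1/256 = 0.00390625
```

<font size = 4>The L1 term is what pushes the network toward sparse, spike-like outputs: even a perfect heatmap match still pays for every non-zero pixel.
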

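<font size = 4>TIFF stores resolution as pixels per `ResolutionUnit` (2 = inch, 3 = centimetre), so converting to a pixel size in nanometres is a single division, using 1 inch = 2.54 × 10⁷ nm and 1 cm = 10⁷ nm. The sketch below spells out the unit conversion behind `getPixelSizeTIFFmetadata` (reading) and `saveAsTIF` (writing, in pixels per cm); the helper names are hypothetical.

```python
NM_PER_UNIT = {2: 2.54e7, 3: 1e7}  # TIFF ResolutionUnit: 2 = inch, 3 = centimetre

def pixel_size_nm(x_resolution, resolution_unit):
    """Pixel size in nm from a TIFF XResolution given in pixels per unit."""
    return NM_PER_UNIT[resolution_unit] / x_resolution

def pixels_per_cm(pixel_size_in_nm):
    """Inverse used when writing: resolution in pixels per cm (ResolutionUnit 3)."""
    return NM_PER_UNIT[3] / pixel_size_in_nm

print(pixel_size_nm(254000, 2))                  # 254,000 pixels per inch -> 100.0 nm
print(pixel_size_nm(pixels_per_cm(107.0), 3))    # round trip, ~107 nm
```

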
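<font size = 4>`Maximafinder` marks a pixel as a localization when it equals the maximum of its `neighborhood_size` × `neighborhood_size` neighbourhood and exceeds `threshold`; with `use_local_average`, it refines the position by the 3 × 3 centre of mass and reports the 3 × 3 sum as confidence. The NumPy sketch below mirrors that logic for one 2D map, assuming an odd neighbourhood size (the function name `find_maxima` is hypothetical, and this is not the TensorFlow layer itself).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def find_maxima(image, thresh, nhood=3, use_local_avg=False):
    """Local maxima above thresh in a 2D map; returns (x, y, confidence) arrays."""
    image = np.asarray(image, dtype=np.float64)
    pad = nhood // 2
    # Pad with -inf so border pixels are compared only with real neighbours,
    # as MaxPooling2D(padding='same') does
    padded = np.pad(image, pad, constant_values=-np.inf)
    local_max = sliding_window_view(padded, (nhood, nhood)).max(axis=(-2, -1))
    rows, cols = np.nonzero((local_max > thresh) & (local_max == image))
    x, y = cols.astype(np.float64), rows.astype(np.float64)
    confidence = image[rows, cols]
    if use_local_avg:
        # 3x3 centre of mass around each peak, zero outside the image (like kernel_x/kernel_y/kernel_sum)
        win = sliding_window_view(np.pad(image, 1), (3, 3))[rows, cols]   # shape (n, 3, 3)
        offsets = np.array([-1.0, 0.0, 1.0])
        total = win.sum(axis=(1, 2))
        x = x + (win * offsets[None, None, :]).sum(axis=(1, 2)) / total
        y = y + (win * offsets[None, :, None]).sum(axis=(1, 2)) / total
        confidence = total
    return x, y, confidence

img = np.zeros((12, 12))
img[5, 7], img[5, 8] = 20.0, 10.0               # one emitter, brighter on its left pixel
print(find_maxima(img, thresh=5))                # peak at x=7, y=5 with confidence 20
print(find_maxima(img, thresh=5, use_local_avg=True))  # x shifted toward the dimmer neighbour
```

<font size = 4>The weaker pixel at (5, 8) is not reported separately because its neighbourhood maximum is 20, which illustrates why a larger `neighborhood_size` merges nearby emitters.

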
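<font size = 4>Because `catch_oom` swallows an out-of-memory error, execution resumes right after the `with` block, and any variable assigned inside it keeps its previous value, possibly one left over from an earlier loop iteration or file. Checking whether a name exists (for example with `locals()`) therefore cannot distinguish a fresh result from a stale one; resetting a sentinel to `None` before each guarded block can. The sketch uses a hypothetical stand-in, `swallow_memory_errors`, and a simulated `MemoryError`.

```python
from contextlib import contextmanager

@contextmanager
def swallow_memory_errors(log):
    """Hypothetical stand-in for catch_oom: record a MemoryError and carry on."""
    try:
        yield
    except MemoryError as exc:
        log.append(f"skipped: {exc}")

def risky(step):
    if step == 1:
        raise MemoryError("simulated OOM")
    return step * 10

log, results = [], []
for step in range(3):
    value = None                      # sentinel, reset on every iteration
    with swallow_memory_errors(log):
        value = risky(step)
    if value is None:                 # the guarded block did not finish
        continue
    results.append(value)

print(results, log)   # [0, 20] ['skipped: simulated OOM']
```

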
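<font size = 4>The PSF width passed to the patch feature extraction comes from the Gaussian approximation of a diffraction-limited spot, σ ≈ 0.21 λ/NA, converted to camera pixels. With the default form values (λ = 715 nm, NA = 1.49, pixel size 107 nm) the arithmetic works out as:

$$
\sigma_{\mathrm{nm}} = 0.21\,\frac{\lambda}{\mathrm{NA}} = 0.21 \times \frac{715\ \mathrm{nm}}{1.49} \approx 100.8\ \mathrm{nm},
\qquad
\sigma_{\mathrm{px}} = \frac{\sigma_{\mathrm{nm}}}{\text{pixel size}} = \frac{100.8\ \mathrm{nm}}{107\ \mathrm{nm}} \approx 0.94\ \mathrm{px}
$$

<font size = 4>`psf_sigma_pixels` scales inversely with the pixel size in effect, so a wrong pixel size changes the apparent emitter width proportionally.
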
#@markdown Drift Correction

## **6.2 Drift correction**
# @markdown ##Data parameters
Loc_file_path = "" #@param {type:"string"}
# @markdown Provide information about original data. Get the info automatically from the raw data?
Get_info_from_file = False #@param {type:"boolean"}
original_image_path = "" #@param {type:"string"}
# @markdown Otherwise, provide the image width and height (in pixels); the pixel size (in nm) is taken from `pixel_size` in the prediction cell above
image_width = 40 #@param {type:"integer"}
image_height = 40 #@param {type:"integer"}

# @markdown ##Drift correction parameters
visualization_pixel_size = 10 #@param {type:"number"}
number_of_bins = 10 #@param {type:"integer"}
polynomial_fit_degree = 4 #@param {type:"integer"}

# @markdown ##Saving parameters
save_path = "" #@param {type:"string"}

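# Optionally read width, height and pixel size from the original TIFF
if Get_info_from_file:
    pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)

# Read the localizations in
LocData = pd.read_csv(Loc_file_path)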

# Calculate a few variables
Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))
Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))
nFrames = LocData['frame'].max() + 1

x_max = max(LocData['x [nm]'])
y_max = max(LocData['y [nm]'])
image_size = (Mhr, Nhr)
n_locs = len(LocData.index)

print('Image size: '+str(image_size))
print('Number of frames in data: '+str(nFrames))
print('Number of localizations in data: '+str(n_locs))

blocksize = math.ceil(nFrames/number_of_bins)
print('Number of frames per block: '+str(blocksize))

blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()
xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)
yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)

# Preparing the Reference image
ImagesRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size=image_size, pixel_size=visualization_pixel_size)
# Rotate by 180° (flip both axes) so that fftconvolve(ImagesRef, ImageBlock) below computes a cross-correlation
ImagesRef = np.flip(np.flip(ImagesRef, axis=0), axis=1)

xDrift = np.zeros(number_of_bins)
yDrift = np.zeros(number_of_bins)

filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]

with open(os.path.join(save_path, filename_no_extension+"_DriftCorrectionData.csv"), "w", newline='') as file:
    writer = csv.writer(file)

    # Write the header in the csv file
    writer.writerow(["Block #", "x-drift [nm]", "y-drift [nm]"])

    for b in tqdm(range(number_of_bins)):
        blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()
        xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)
        yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)

        ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size=image_size, pixel_size=visualization_pixel_size)

        XC = fftconvolve(gaussian_filter(ImagesRef, 2), gaussian_filter(ImageBlock, 2), mode='same')
        yDrift[b], xDrift[b] = estimate_drift_com_nm(ImagesRef, ImageBlock, visualization_pixel_size, sigma=1.0, patch_radius=np.min([Mhr, Nhr])//2)

        # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)
        # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)
        # Drifts are written relative to the first block
        writer.writerow([str(b), str((xDrift[b]-xDrift[0])), str((yDrift[b]-yDrift[0]))])

print('--------------------------------------------------------------------')

print('Fitting drift data...')
bin_number = np.arange(number_of_bins)*blocksize + blocksize/2
xDrift = (xDrift-xDrift[0])
yDrift = (yDrift-yDrift[0])

xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)
yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)

xDriftFit = np.poly1d(xDriftCoeff)
yDriftFit = np.poly1d(yDriftCoeff)
bins = np.arange(nFrames)
xDriftInterpolated = xDriftFit(bins)
yDriftInterpolated = yDriftFit(bins)

# ------------------ Displaying the image results ------------------

plt.figure(figsize=(15,10))
plt.plot(bin_number,xDrift, 'r+', label='x-drift')
plt.plot(bin_number,yDrift, 'b+', label='y-drift')
plt.plot(bins,xDriftInterpolated, 'r-', label='x-drift (fit)')
plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')
plt.title('Cross-correlation estimated drift')
plt.ylabel('Drift [nm]')
plt.xlabel('Frame number')
plt.legend()
plt.show()

# ------------------ Actual drift correction -------------------

print('Correcting localization data...')
xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)
yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)
frames = LocData['frame'].to_numpy(dtype=np.int32)

xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)
ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size=image_size, pixel_size=visualization_pixel_size)
ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size=image_size, pixel_size=visualization_pixel_size)

# ------------------ Displaying the image results ------------------
fig, axs = plt.subplots(1, 2, figsize=(15, 7.5), sharex=True, sharey=True)
# Raw
axs[0].axis('off')
axs[0].imshow(np.log(ImageRaw + 1e-3), cmap='gray')
axs[0].set_title('Raw', fontsize=15)
# Corrected
axs[1].axis('off')
axs[1].imshow(np.log(ImageCorr + 1e-3), cmap='gray')
axs[1].set_title('Corrected', fontsize=15)
plt.show()

# ------------------ Table with info -------------------
driftCorrectedLocData = pd.DataFrame()
driftCorrectedLocData['frame'] = frames
driftCorrectedLocData['x [nm]'] = xc_array_Corr
driftCorrectedLocData['y [nm]'] = yc_array_Corr
driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']

driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'), index=False)
print('-------------------------------')
print('Corrected localizations saved.')
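
<font size = 4>The per-block drift comes from comparing each block's histogram image with the reference built from the first block. `estimate_drift_com_nm` is defined outside this cell, so the sketch below is a generic stand-in rather than that function: a plain FFT cross-correlation that returns the integer-pixel shift at the correlation peak, assuming periodic image boundaries. The function name `xcorr_shift` is hypothetical; multiplying by `visualization_pixel_size` converts pixels to nanometres.

```python
import numpy as np

def xcorr_shift(reference, block):
    """Integer (dy, dx) such that block ≈ np.roll(reference, (dy, dx), axis=(0, 1))."""
    # Circular cross-correlation via FFT; its peak sits at the displacement
    xc = np.fft.ifft2(np.conj(np.fft.fft2(reference)) * np.fft.fft2(block)).real
    peak = np.unravel_index(np.argmax(xc), xc.shape)
    # Indices in the upper half of each axis correspond to negative shifts
    return tuple(int(p) if p <= n // 2 else int(p) - n for p, n in zip(peak, xc.shape))

rng = np.random.default_rng(0)
reference = rng.random((64, 64))
block = np.roll(reference, (3, -5), axis=(0, 1))   # synthetic drift of (+3, -5) pixels
dy, dx = xcorr_shift(reference, block)
visualization_pixel_size = 10  # nm, as in the form above
print(dy * visualization_pixel_size, dx * visualization_pixel_size)  # 30 -50
```

<font size = 4>Integer-pixel peaks limit the precision to one visualization pixel; sub-pixel methods (for example a centre of mass around the peak) refine this.
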

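<font size = 4>The polynomial fit turns the `number_of_bins` per-block estimates, placed at each block's centre frame, into one drift value per frame. `correctDriftLocalization` is defined outside this cell; the sketch assumes it subtracts each frame's interpolated drift from that frame's localizations, which is an assumption about its behaviour. It uses a synthetic quadratic drift and a degree-2 fit (instead of `polynomial_fit_degree = 4`) so that the fit is exact.

```python
import numpy as np

def true_drift(t):
    """Synthetic x-drift in nm as a function of frame number (illustration only)."""
    return 0.05 * t - 2e-5 * t**2

number_of_bins, blocksize = 10, 100
bin_centres = np.arange(number_of_bins) * blocksize + blocksize / 2   # centre frame of each block
x_drift = true_drift(bin_centres) - true_drift(bin_centres[0])        # relative to the first block

fit = np.poly1d(np.polyfit(bin_centres, x_drift, 2))
n_frames = number_of_bins * blocksize
x_drift_per_frame = fit(np.arange(n_frames))                          # one value per frame

# Assumed correction step: subtract the drift of each localization's own frame
frames = np.array([0, 250, 999])
x_nm = np.array([1000.0, 1000.0, 1000.0])
x_corrected = x_nm - x_drift_per_frame[frames]
print(np.round(x_corrected, 3))
```

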

# **Version log**
---
<font size = 4>**v1.0**: Initial implementation

<font size = 4>**v1.1**: Improving resource management. Automatic model loading.

<font size = 4>**v1.2**: ND2 compatibility. Further improvement of memory management.

---


# **Thank you for using AutoDS!**