<a href="https://colab.research.google.com/github/Itamar-Horowitz/real-time-AutoDS/blob/main/AutoDS/Realtime_AutoDS_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **AutoDS**

---

<font size = 4> Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).

<font size = 4> AutoDS is an extension of Deep-STORM automating the reconstruction process and aleviating the need in human intervension. This is done by automatic detection of the experimental condition in the analyzed videos and automatic selection of a Deep-STORM model out of a set of pre-trained model for the data processing.

<font size = 4> Additionally, AutoDS pipeline splits each input frame into patches and enables processing of different regions in the field-of-view with different models. This mechanism led to an improvment in the reconstruction quality beyond the capabilities of Deep-STORM.


# **Before getting started**
---
<font size = 4> This notebook contains the code required only for inference of SMLM data using a set of pre-trained Deep-STORM models. For model training please follow this [link](https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/blob/main/AutoDS/AutoDS_training.ipynb).

# **Run configuration**
---
<font size = 4>**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.

<font size = 4>**`Result_folder`:** This folder will contain the found localizations csv.

<font size = 4>**`threshold`:** This paramter determines threshold for local maxima finding. A higher `threshold` will result in less localizations. **DEFAULT: 10**

<font size = 4>**`neighborhood_size`:** This paramter determines size of the neighborhood within which the prediction needs to be a local maxima in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**

<font size = 4>**`use_local_average`:** This paramter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**

<font size = 4>**`num_patches`:** Determines the number of patches in each row and each column after splitting the frames to patches. The total number of patches will be num_patches<sup>2</sup>. **DEFAULT: 4**

<font size = 4>**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 1**

<font size = 4>**The following parameters are relevant only if `interpolate_based_on_imaging_parameters` is checked:**

<font size = 4> - **`pixel_size` [nm]:** the pixels size of the analyzed video. **DEFAULT: 107**

<font size = 4> - **`wavelength` [nm]:** the emission wavelength of the analyzed video. **DEFAULT: 715**

<font size = 4> - **`numerical_aperture`:** the optical setup numerical aperture of the analyzed video. **DEFAULT: 1.49**

<font size = 4> - **`chunk_size`:** determine the number of patches that will be analyzed in each prediction iteration. This parameter is used for managing compute resources in Google Colab. If you are facing crashes due to RAM memory limitation, decrease the number of patches per chunk. **DEFAULT: 10000**

# Download testing flie

In [14]:
Data_folder = "https://github.com/Itamar-Horowitz/real-time-AutoDS/tree/41c8f1af1c8e513cffc314e496424f642f0ddf92/dataset/TOM20_10nM/1" #@param {type:"string"}

# ============================================================================
# GITHUB DATA DOWNLOAD UTILITIES
# ============================================================================
def download_github_file(url, destination):
    """Download a single file from GitHub, handling Git LFS if needed"""
    # Convert GitHub web URL to raw content URL if needed
    if 'github.com' in url and '/blob/' in url:
        url = url.replace('github.com', 'raw.githubusercontent.com').replace('/blob/', '/')

    os.makedirs(os.path.dirname(destination), exist_ok=True)
    print(f"Downloading: {url}")

    # Add headers to avoid GitHub's HTML wrapper
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})

    with urllib.request.urlopen(req) as response:
        content = response.read()

        # Check if this is a Git LFS pointer file
        if b'version https://git-lfs.github.com/spec/' in content[:200]:
            print("  Detected Git LFS file, extracting download URL...")
            # Parse LFS pointer to get actual file URL
            content_str = content.decode('utf-8')
            for line in content_str.split('\n'):
                if line.startswith('oid sha256:'):
                    oid = line.split(':')[1].strip()
                    # Construct LFS download URL
                    # Extract user/repo from original URL
                    parts = url.split('/')
                    user = parts[3]
                    repo = parts[4]
                    lfs_url = f"https://media.githubusercontent.com/media/{user}/{repo}/master/{'/'.join(parts[6:])}"
                    print(f"  LFS URL: {lfs_url}")

                    # Download the actual file
                    req_lfs = urllib.request.Request(lfs_url, headers={'User-Agent': 'Mozilla/5.0'})
                    with urllib.request.urlopen(req_lfs) as lfs_response:
                        content = lfs_response.read()
                    break

        # Write content to file
        with open(destination, 'wb') as out_file:
            out_file.write(content)

    # Verify file was downloaded correctly
    file_size = os.path.getsize(destination)
    if file_size < 1000:  # Files smaller than 1KB are likely error pages
        with open(destination, 'rb') as f:
            content_check = f.read(100)
            if b'<!DOCTYPE' in content_check or b'<html' in content_check:
                raise ValueError(f"Downloaded HTML instead of binary file. URL may be incorrect.")

    print(f"Saved to: {destination} ({file_size / (1024*1024):.2f} MB)")


def download_tiff_files_fallback(user, repo, commit, dir_path, local_path):
    """Fallback method to download TIFF files when API fails"""
    os.makedirs(local_path, exist_ok=True)
    downloaded = []

    # Try common TIFF file numbering patterns
    for i in range(1, 100):  # Try up to 100 files
        for ext in ['.tif', '.tiff']:
            file_name = f"{i}{ext}"
            raw_url = f"https://raw.githubusercontent.com/{user}/{repo}/{commit}/{dir_path}/{file_name}"
            dest_path = os.path.join(local_path, file_name)

            try:
                urllib.request.urlretrieve(raw_url, dest_path)
                print(f"Downloaded: {file_name}")
                downloaded.append(dest_path)
                break  # Found file with this number, try next
            except:
                continue  # File doesn't exist, try next

        if i > 10 and len(downloaded) == 0:
            break  # Stop if first 10 attempts fail

    if len(downloaded) == 0:
        print("No files found with fallback method.")
    else:
        print(f"Downloaded {len(downloaded)} files using fallback method")

    return downloaded


def download_github_directory(repo_url, local_path, branch='main'):
    """
    Download all files from a GitHub directory
    Uses git clone with LFS support as primary method

    Args:
        repo_url: GitHub directory URL (e.g., https://github.com/user/repo/tree/branch/path/to/dir)
        local_path: Local directory to save files
        branch: Git branch name (default: 'main')
    """
    # Parse the GitHub URL
    parts = repo_url.split('github.com/')[-1].split('/')
    if len(parts) < 5:
        raise ValueError("Invalid GitHub directory URL")

    user = parts[0]
    repo = parts[1]

    # Find where the path starts (after 'tree' and branch/commit)
    if 'tree' in parts:
        tree_idx = parts.index('tree')
        path_parts = parts[tree_idx + 2:]  # Skip 'tree' and branch/commit
        dir_path = '/'.join(path_parts)
    else:
        dir_path = '/'.join(parts[3:])

    # Get the commit/branch from URL
    if 'tree' in parts:
        commit = parts[parts.index('tree') + 1]
    else:
        commit = branch

    # Method 2: Fall back to API-based download
    print("\nStart downloading data file")

    api_url = f"https://api.github.com/repos/{user}/{repo}/contents/{dir_path}?ref={commit}"

    try:
        req = urllib.request.Request(api_url, headers={'User-Agent': 'Mozilla/5.0'})
        with urllib.request.urlopen(req) as response:
            files_data = json.loads(response.read().decode())
    except Exception as e:
        return download_tiff_files_fallback(user, repo, commit, dir_path, local_path)

    os.makedirs(local_path, exist_ok=True)

    downloaded_files = []
    for item in files_data:
        if item['type'] == 'file':
            file_name = item['name']

            # Download TIFF, TIF, and ND2 files
            if file_name.lower().endswith(('.tif', '.tiff', '.nd2')):
                # Use raw.githubusercontent.com for binary files
                file_url = f"https://raw.githubusercontent.com/{user}/{repo}/{commit}/{dir_path}/{file_name}"
                dest_path = os.path.join(local_path, file_name)

                try:
                    download_github_file(file_url, dest_path, commit_or_branch=commit)
                    downloaded_files.append(dest_path)
                except Exception as e:
                    print(f"Failed to download {file_name}: {e}")
                    # Try using download_url from API as fallback
                    if 'download_url' in item and item['download_url']:
                        try:
                            download_github_file(item['download_url'], dest_path, commit_or_branch=commit)
                            downloaded_files.append(dest_path)
                        except Exception as e2:
                            print(f"Also failed with API URL: {e2}")

    return downloaded_files

# ============================================================================
# DOWNLOAD DATA FROM GITHUB
# ============================================================================
if Data_folder.startswith('http'):
    downloaded_files = download_github_directory(Data_folder, Data_folder)



Start downloading data file
No files found with fallback method.


# **V1: Original TensorFlow Version**


In [10]:
import os
import sys
import traceback
import urllib.request
from contextlib import contextmanager

import numpy as np
import tifffile as tiff
import torch
from PIL import Image
from PIL.TiffTags import TAGS

def log(*args, **kwargs):
    if not config.QUIET:
        print(*args, **kwargs)

def list_files_multi(directory, extensions):
    exts = {('.' + e.lower()) for e in extensions}
    for f in os.listdir(directory):
        if os.path.splitext(f)[1].lower() in exts:
            yield f

def _is_oom(exc: BaseException) -> bool:
    msg = (str(exc) or "").upper()
    return (
        isinstance(exc, torch.cuda.OutOfMemoryError) or
        isinstance(exc, MemoryError) or
        "OUT OF MEMORY" in msg or "OOM" in msg or
        exc.__class__.__name__ in {"_ArrayMemoryError", "OutOfMemoryError"}
    )

@contextmanager
def catch_oom(phase: str, detail: str = "", on_oom="continue"):
    """
    Wrap any memory-heavy block. Prints a friendly message on OOM and continues.
    on_oom: "continue" (default) just prints and returns; any other value re-raises.
    """
    try:
        yield
    except Exception as e:
        if _is_oom(e):
            print(f"\n⚠️  OOM while {phase}{(' - ' + detail) if detail else ''}.")
            print("   Tip: reduce chunk_size/batch_size/upsampling, or downsample input.")
            if isinstance(e, torch.cuda.OutOfMemoryError):
                # PyTorch OOM messages are in str(e) directly
                msg_line = str(e).splitlines()[0][:200]
                print("   PyTorch says:", msg_line)
            else:
                traceback.print_exc(limit=1, file=sys.stdout)
            if on_oom != "continue":
                raise
        else:
            # Non-OOM: re-raise so real bugs are visible
            raise

# ============================================================================
# 1. TIFF File Operations
# ============================================================================

def getPixelSizeTIFFmetadata(TIFFpath, display=False):
    """Extract pixel size from TIFF metadata"""
    with Image.open(TIFFpath) as img:
        meta_dict = {TAGS[key]: img.tag[key] for key in img.tag.keys()}

    ResolutionUnit = meta_dict['ResolutionUnit'][0]
    width = meta_dict['ImageWidth'][0]
    height = meta_dict['ImageLength'][0]
    xResolution = meta_dict['XResolution'][0]

    if len(xResolution) == 1:
        xResolution = xResolution[0]
    elif len(xResolution) == 2:
        xResolution = xResolution[0] / xResolution[1]
    else:
        print('Image resolution not defined.')
        xResolution = 1

    if ResolutionUnit == 2:
        pixel_size = 0.025 * 1e9 / xResolution
    elif ResolutionUnit == 3:
        pixel_size = 0.01 * 1e9 / xResolution
    else:
        print('Resolution unit not defined. Assuming: um')
        pixel_size = 1e3 / xResolution

    if display:
        print(f'Pixel size from metadata: {pixel_size} nm')
        print(f'Image size: {width}x{height}')

    return pixel_size, width, height

def saveAsTIF(path, filename, array, pixel_size):
    """Save array as TIFF with metadata"""
    if array.dtype == np.uint16:
        mode = 'I;16'
    elif array.dtype == np.uint32:
        mode = 'I'
    else:
        mode = 'F'

    if len(array.shape) == 2:
        im = Image.fromarray(array)
        im.save(os.path.join(path, filename + '.tif'),
               mode=mode,
               resolution_unit=3,
               resolution=0.01 * 1e9 / pixel_size)
    elif len(array.shape) == 3:
        imlist = []
        for frame in array:
            imlist.append(Image.fromarray(frame))
        imlist[0].save(os.path.join(path, filename + '.tif'),
                      save_all=True,
                      append_images=imlist[1:],
                      mode=mode,
                      resolution_unit=3,
                      resolution=0.01 * 1e9 / pixel_size)

def is_tiff(path):
    """Check if file is TIFF"""
    return path.lower().endswith(('.tif', '.tiff'))

def iter_tiff_frames(path):
    """Iterate over TIFF frames"""
    with tiff.TiffFile(path) as tif:
        for page in tif.pages:
            yield page.asarray().astype(np.float32)

def count_tiff_frames(path):
    """Count frames in TIFF file"""
    with tiff.TiffFile(path) as tif:
        return len(tif.pages)

# ============================================================================
# 2. ND2 File Operations
# ============================================================================

def is_nd2(path):
    """Check if file is ND2"""
    try:
        import nd2
        return nd2.is_supported_file(path)
    except Exception:
        return path.lower().endswith(".nd2")

def count_nd2_frames(path):
    """Count frames in ND2 file"""
    import nd2
    with nd2.ND2File(path) as f:
        try:
            return len(f.loop_indices)
        except Exception:
            sz = getattr(f, "sizes", {}) or {}
            prod = 1
            for ax in ("T", "Z", "C", "V"):
                prod *= int(sz.get(ax, 1))
            return prod

def _nd2_to_2d(arr, channel=None):
    """Convert ND2 frame to 2D"""
    a = np.asarray(arr)
    if a.ndim == 2:
        return a
    if a.ndim == 3:
        if a.shape[-1] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[-1]) else 0
            return a[..., idx]
        if a.shape[0] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[0]) else 0
            return a[idx, ...]
        return a.mean(axis=0)
    a = a.squeeze()
    return a if a.ndim == 2 else a.reshape(a.shape[-2], a.shape[-1])

def iter_nd2_frames(path, channel=None):
    """Iterate over ND2 frames"""
    import nd2
    n = count_nd2_frames(path)
    with nd2.ND2File(path) as f:
        for i in range(n):
            fr = f.read_frame(i)
            fr2d = _nd2_to_2d(fr, channel=channel)
            yield fr2d.astype(np.float32, copy=False)

def getPixelSizeND2metadata(path, display=False):
    """Extract pixel size from ND2 metadata"""
    import nd2
    with nd2.ND2File(path) as f:
        vox_um = getattr(f, "voxel_size", None)
        if vox_um is None:
            return None, None, None
        px_nm = vox_um[2] * 1e3
        try:
            h, w = f.shape[-2], f.shape[-1]
        except Exception:
            h = w = None
        if display:
            print(f"Pixel size (ND2): {px_nm:.2f} nm | image ~ {w}x{h}")
        return px_nm, w, h

# ============================================================================
# 3. Drift Correction Functions
# ============================================================================

def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):
    """Apply drift correction to localizations"""
    n_locs = xc_array.shape[0]
    xc_array_Corr = np.empty(n_locs)
    yc_array_Corr = np.empty(n_locs)

    for loc in range(n_locs):
        xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc] - 1]
        yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc] - 1]

    return xc_array_Corr, yc_array_Corr

def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size=(64, 64), pixel_size=100):
    """Convert localizations to histogram image"""
    w, h = image_size
    locImage = np.zeros(image_size)
    n_locs = len(xc_array)

    for e in range(n_locs):
        y_idx = int(max(min(round(yc_array[e] / pixel_size), w - 1), 0))
        x_idx = int(max(min(round(xc_array[e] / pixel_size), h - 1), 0))
        locImage[y_idx][x_idx] += 1

    return locImage

def estimate_drift_com_nm(img1, img2, pixel_size_nm, sigma=1.0, patch_radius=3):
    """Estimate drift using center of mass of cross-correlation"""
    from scipy.ndimage import gaussian_filter
    from scipy.signal import fftconvolve

    # Smooth images
    img1_smooth = gaussian_filter(img1.astype(np.float32), sigma=sigma)
    img2_smooth = gaussian_filter(img2.astype(np.float32), sigma=sigma)

    # Cross-correlation
    corr = fftconvolve(img1_smooth, img2_smooth, mode='same')

    # Center of image
    center_y, center_x = np.array(corr.shape) // 2

    # Crop around center
    y_min = max(0, center_y - patch_radius)
    y_max = min(corr.shape[0], center_y + patch_radius + 1)
    x_min = max(0, center_x - patch_radius)
    x_max = min(corr.shape[1], center_x + patch_radius + 1)

    patch = corr[y_min:y_max, x_min:x_max]

    # Center of mass
    y_grid, x_grid = np.meshgrid(
        np.arange(y_min, y_max), np.arange(x_min, x_max), indexing='ij'
    )

    total = np.sum(patch)
    if total == 0:
        return 0.0, 0.0

    y_com = np.sum(patch * y_grid) / total
    x_com = np.sum(patch * x_grid) / total

    # Drift in pixels
    dy_px = y_com - center_y
    dx_px = x_com - center_x

    if abs(dy_px) > patch_radius or abs(dx_px) > patch_radius:
        return 0.0, 0.0

    # Convert to nm
    dy_nm = dy_px * pixel_size_nm
    dx_nm = dx_px * pixel_size_nm

    return dy_nm, dx_nm

# ============================================================================
# 4. Model Download Utilities
# ============================================================================

def ensure_models(model_names, target_root="/content/AutoDS_models", model_manifest=None):
    if model_manifest is None:
        raise ValueError("model_manifest must be provided.")

    os.makedirs(target_root, exist_ok=True)

    for m in model_names:
        cfg = model_manifest[m]
        mdir = os.path.join(target_root, m)
        need_fetch = False

        req = cfg.get("contains", [])
        if not os.path.isdir(mdir):
            need_fetch = True
        else:
            for f in req:
                if not os.path.exists(os.path.join(mdir, f)):
                    need_fetch = True
                    break

        if not need_fetch:
            print(f"[models] found: {m}")
            continue

        print(f"[models] preparing: {m}")
        os.makedirs(mdir, exist_ok=True)

        if "file_urls" in cfg:
            file_urls = cfg["file_urls"]
            for fname, url in file_urls.items():
                dst = os.path.join(mdir, fname)
                print(f"[models] downloading: {url}")
                urllib.request.urlretrieve(url, dst)
        else:
            raise ValueError(f"Model {m} manifest must have 'file_urls'.")

        for f in req:
            if not os.path.exists(os.path.join(mdir, f)):
                raise FileNotFoundError(f"Model {m} missing required file: {f}")

        print(f"[models] ready: {m}")

    return target_root

import numpy as np
import scipy.optimize as opt
from numpy.lib.stride_tricks import sliding_window_view
from scipy.ndimage import gaussian_filter, zoom
from scipy.ndimage import gaussian_laplace, maximum_filter, binary_dilation
import torch
import torch.nn.functional as F

# ============================================================================
# 1. Image Preprocessing Functions
# ============================================================================

def normalize_im_01(im):
    """Normalize image to [0, 1]"""
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val) / (max_val - min_val)

def normalize_im_01_ret_vals(im):
    """Normalize and return normalization parameters"""
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val) / (max_val - min_val), min_val, max_val

def normalize_im(im, dmean, dstd):
    """Normalize image with given mean and std"""
    im = np.squeeze(im)
    return (im - dmean) / dstd

def subtract_smooth_background(im, sigma=3):
    """Subtract smoothed background"""
    return im - gaussian_filter(im, sigma)

def remove_zero_padding(image):
    """Remove zero padding from image"""
    image_array = np.array(image)
    non_zero_rows = np.where(image_array.sum(axis=1) != 0)
    non_zero_cols = np.where(image_array.sum(axis=0) != 0)
    cropped_image = image_array[non_zero_rows[0][0]:non_zero_rows[0][-1]+1,
                                non_zero_cols[0][0]:non_zero_cols[0][-1]+1]
    return cropped_image

# ============================================================================
# 2. Patch Splitting
# ============================================================================

def split_image_to_patches(img, num_patches, overlap):
    """
    Split image into overlapping patches

    Args:
        img: Input image (H, W)
        num_patches: Number of patches per dimension
        overlap: Overlap size in pixels

    Returns:
        List of patches
    """
    H, W = img.shape
    patch_h = H // num_patches
    patch_w = W // num_patches

    # Pad image for border patches
    padded_img = np.pad(img, ((overlap, overlap), (overlap, overlap)), mode='reflect')

    # Window shape including overlap
    window_shape = (patch_h + 2 * overlap, patch_w + 2 * overlap)

    # Create sliding window view
    patches_view = sliding_window_view(padded_img, window_shape)

    # Sample at regular intervals
    patches_array = patches_view[0::patch_h, 0::patch_w, :, :]

    # Flatten to list
    num_rows, num_cols, ph, pw = patches_array.shape
    patches_list = [patches_array[i, j].copy()
                   for i in range(num_rows)
                   for j in range(num_cols)]

    return patches_list

# ============================================================================
# 3. Interpolation and Scaling
# ============================================================================

def gaussian_interpolation_batch(data_batch, scale, sigma=1):
    """Apply Gaussian interpolation to batch of images"""
    upsampled_data_batch = []

    for data in data_batch:
        smoothed_data = gaussian_filter(data, sigma=sigma)
        upsampled_data = zoom(smoothed_data, scale, order=3)
        upsampled_data_batch.append(upsampled_data)

    return np.array(upsampled_data_batch)

def interpolate_frames(tiff_stack, model_pixel_size, current_pixel_size,
                      model_wavelength, current_wavelength,
                      model_NA, current_NA):
    """Interpolate frames to match model parameters"""
    # Set defaults
    if model_pixel_size is None:
        model_pixel_size = current_pixel_size
    if model_wavelength is None:
        model_wavelength = current_wavelength
    if model_NA is None:
        model_NA = current_NA
    if current_wavelength is None:
        current_wavelength = model_wavelength = 1
    if current_NA is None:
        current_NA = model_NA = 1

    if len(tiff_stack.shape) == 2:
        tiff_stack = tiff_stack[None, :, :]

    # Compute scaling ratio based on optical parameters
    scale_ratio_sq = ((0.21 * model_wavelength / model_NA) ** 2 -
                     (0.21 * current_wavelength / current_NA) ** 2)

    if scale_ratio_sq > 0:
        scale_ratio = np.sqrt(scale_ratio_sq) / model_pixel_size
        interpolated_stack = np.stack([
            gaussian_filter(tiff_stack[i], scale_ratio)
            for i in range(tiff_stack.shape[0])
        ])
    else:
        zoom_factors = (1,
                       model_pixel_size / current_pixel_size,
                       model_pixel_size / current_pixel_size)
        interpolated_stack = zoom(tiff_stack.astype(np.float32),
                                 zoom_factors, order=3)

    return interpolated_stack.astype(np.float32, copy=False)

# ============================================================================
# 4. Feature Extraction
# ============================================================================

def gauss2d(xy, offset, amp, x0, y0, sigma):
    """2D Gaussian function for fitting"""
    x, y = xy
    return offset + (amp * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2) -
                                  ((y - y0) ** 2) / (2 * sigma ** 2)))

def extract_features_frame(OrigImage, pixel_size, psf_sigma, offset=None, verbose=False):
    """
    Extract features from a single frame

    Returns:
        ADC_offset: Mean background
        ReadOutNoise_ADC: Std of background
        Signal_amp: Mean signal amplitude
        emitter_density: Density of emitters (per μm²)
    """
    M, N = OrigImage.shape

    # Subtract smooth background
    Image = OrigImage - gaussian_filter(OrigImage, sigma=5)

    # Check if SNR is sufficient
    if offset is not None:
        if (np.percentile(gaussian_filter(Image, 2), 99) < 2 * Image.mean() or
            np.percentile(OrigImage, 99) < 2 * offset):
            if verbose:
                print("SNR too low - ignoring patch")
            return np.mean(OrigImage), np.std(OrigImage), 0, 0

    # Laplacian of Gaussian for blob detection
    log_image = -gaussian_laplace(Image, sigma=psf_sigma)

    # Local maxima filtering
    neighborhood_size = 3
    local_max = (log_image == maximum_filter(log_image, size=neighborhood_size))

    # Intensity threshold
    amp_threshold = np.mean(Image) + 0.5 * (np.percentile(Image, 99) - np.mean(Image))
    pcntl_threshold = np.percentile(Image, 85)

    # Binary mask for emitters
    binary_mask = np.logical_and(local_max,
                                 Image > np.max([amp_threshold, pcntl_threshold]))

    # Dilate and create noise mask
    dilated_mask = binary_dilation(binary_mask, structure=np.ones((5, 5)))
    noise_mask = np.ones_like(binary_mask)
    noise_mask[dilated_mask] = 0

    if np.sum(binary_mask) > 0:
        ADC_offset = np.mean(OrigImage[noise_mask])
        ReadOutNoise_ADC = np.std(OrigImage[noise_mask])
        Signal_amp = np.mean(OrigImage[binary_mask == 1])
        emitter_density = (10 ** 6) * float(np.sum(binary_mask)) / (M * N * pixel_size ** 2)
    else:
        if verbose:
            print("Didn't find any emitters")
        return np.mean(OrigImage), np.std(OrigImage), 0, 0

    # Additional SNR check
    if Signal_amp / ADC_offset < 2.5:
        if emitter_density > 2:
            if verbose:
                print("SNR too low for emitter density estimation")
            return ADC_offset, ReadOutNoise_ADC, Signal_amp, 0

    return ADC_offset, ReadOutNoise_ADC, Signal_amp, emitter_density

# ============================================================================
# 5. Model Selection
# ============================================================================

def ChooseNetByDifficulty_2025(density, SNR):
    """ Choose network based on density and SNR """
    num_models = 4
    norm_density = np.max([np.min([int(np.round(2 * density)), num_models - 1]), 0])
    norm_SNR = num_models - 1 - np.max([np.min([SNR // 2, num_models - 1]), 0])
    return int(np.round((norm_SNR + norm_density) / 2))

# ============================================================================
# Module-level kernel cache (shared across all calls)
_kernel_cache = {}

def _get_gaussian_kernel(sigma, device):
    """Generate Gaussian kernel for smoothing"""
    key = f'gauss_{sigma}_{device}'
    if key not in _kernel_cache:
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        _kernel_cache[key] = kernel.view(1, 1, kernel_size, kernel_size)
    return _kernel_cache[key]


def _get_log_kernel(sigma, device):
    """Generate Laplacian of Gaussian kernel for blob detection"""
    key = f'log_{sigma}_{device}'
    if key not in _kernel_cache:
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        r2 = xx ** 2 + yy ** 2
        kernel = -(1 / (np.pi * sigma ** 4)) * (1 - r2 / (2 * sigma ** 2)) * torch.exp(-r2 / (2 * sigma ** 2))
        _kernel_cache[key] = kernel.view(1, 1, kernel_size, kernel_size)
    return _kernel_cache[key]


def percentile_batch(tensor, percentile):
    """Calculate percentile for batched tensors"""
    flat = tensor.flatten(1)
    result = torch.quantile(flat, percentile / 100.0, dim=1)
    return result

def extract_features_batch(patches_tensor, pixel_size, psf_sigma, offset_array=None,
                           verbose=False, device='cuda'):
    """Fully GPU-accelerated batch feature extraction"""
    B, H, W = patches_tensor.shape
    device = patches_tensor.device

    # Add channel dimension for conv operations: [B, 1, H, W]
    patches_4d = patches_tensor.unsqueeze(1)

    # 1. Gaussian filtering
    gauss_kernel = _get_gaussian_kernel(5, device)
    padding = gauss_kernel.shape[-1] // 2
    smooth_bg = F.conv2d(patches_4d, gauss_kernel, padding=padding)
    Image = patches_4d - smooth_bg # [B, 1, H, W]

    # 2. LoG filtering
    log_kernel = _get_log_kernel(psf_sigma, device)
    padding = log_kernel.shape[-1] // 2
    log_image = -F.conv2d(Image, log_kernel, padding=padding) # [B, 1, H, W]

    # 3. Local maxima
    local_max = F.max_pool2d(log_image, kernel_size=3, stride=1, padding=1) # [B, 1, H, W]

    # 4. Thresholding (all with size [B, 1, 1, 1])
    img_mean = Image.mean(dim=(2, 3), keepdim=True)
    img_99 = percentile_batch(Image.squeeze(1), 99).view(B, 1, 1, 1)
    img_85 = percentile_batch(Image.squeeze(1), 85).view(B, 1, 1, 1)
    threshold = torch.max(img_mean + 0.5 * (img_99 - img_mean), img_85)

    # 5. Binary masks (batch-wise)
    binary_mask = torch.logical_and(log_image == local_max, Image >= threshold)
    mask_float = binary_mask.float()
    dilated = F.max_pool2d(mask_float, kernel_size=5, stride=1, padding=2)
    noise_mask = (dilated < 0.5)

    # 6. PRE-COMPUTE SNR check data on GPU as a batch
    gauss_kernel_2 = _get_gaussian_kernel(2, device)
    padding_2 = (gauss_kernel_2.shape[-1] // 2)
    gauss_smooth = F.conv2d(Image, gauss_kernel_2, padding=padding_2)

    # Pre-compute percentiles on GPU (batch-wise)
    gauss_99 = percentile_batch(gauss_smooth.squeeze(1), 99) #[B]
    patch_99 = percentile_batch(patches_tensor, 99)  # [B]
    img_mean_flat = img_mean.squeeze()  # [B]

    # 7. Statistics on CPU
    patches_cpu = patches_tensor.cpu().numpy()
    binary_mask_cpu = binary_mask.squeeze(1).cpu().numpy()
    noise_mask_cpu = noise_mask.squeeze(1).cpu().numpy()

    # Move pre-computed values to CPU
    gauss_99_cpu = gauss_99.cpu().numpy()
    patch_99_cpu = patch_99.cpu().numpy()
    img_mean_cpu = img_mean_flat.cpu().numpy()

    results = []
    pixel_area = pixel_size * pixel_size

    for i in range(B):
        patch = patches_cpu[i]
        emitter_mask = binary_mask_cpu[i]
        noise_m = noise_mask_cpu[i]
        patch_offset = offset_array[i]

        if patch_offset is not None:
            if (gauss_99_cpu[i] < 2 * img_mean_cpu[i] or
                patch_99_cpu[i] < 2 * patch_offset):
                if verbose:
                    print(f"Patch {i}: SNR too low - ignoring patch")
                results.append((patch.mean(), patch.std(), 0.0, 0.0))
                continue

        num_emitters = emitter_mask.sum()
        if num_emitters == 0:
            if verbose:
                print(f"Patch {i}: Didn't find any emitters")
            results.append((patch.mean(), patch.std(), 0.0, 0.0))
            continue

        ADC_offset = patch[noise_m].mean()
        ReadOutNoise_ADC = patch[noise_m].std()
        Signal_amp = patch[emitter_mask].mean()
        emitter_density = 1e6 * float(num_emitters) / (H * W * pixel_area)

        # Additional SNR check
        if Signal_amp / (ADC_offset + 1e-8) < 2.5:
            if emitter_density > 2:
                if verbose:
                    print(f"Patch {i}: SNR too low for emitter density estimation")
                results.append((float(ADC_offset), float(ReadOutNoise_ADC),
                                float(Signal_amp), 0.0))
                continue

        results.append((float(ADC_offset), float(ReadOutNoise_ADC),
                        float(Signal_amp), float(emitter_density)))

    return results


def preprocess_frames_batch(frames_batch, device='cuda'):
    """GPU-accelerated batch preprocessing of frames"""
    B, H, W = frames_batch.shape

    # Calculate 35th percentile for each frame (on GPU)
    frames_flat = frames_batch.reshape(B, -1)
    p35 = torch.quantile(frames_flat, 0.35, dim=1, keepdim=True)
    p35 = p35.view(B, 1, 1)

    # Subtract 35th percentile
    frames_processed = frames_batch - p35

    # Subtract minimum
    frames_min = frames_processed.reshape(B, -1).min(dim=1, keepdim=True)[0]
    frames_min = frames_min.view(B, 1, 1)
    frames_processed = frames_processed - frames_min

    # Calculate mean and std for normalization
    frames_mean = frames_processed.reshape(B, -1).double().mean(dim=1).float()
    frames_std = frames_processed.reshape(B, -1).double().std(dim=1).float() + 1e-6
    frames_mean_batch = frames_mean.view(B, 1, 1)
    frames_std_batch = frames_std.view(B, 1, 1)

    # Normalize
    frames_processed = (frames_processed - frames_mean_batch) / frames_std_batch

    # Calculate offsets
    offsets = frames_processed.reshape(B, -1).mean(dim=1)

    return frames_processed, offsets


def interpolate_frames_batch(frames_batch, model_pixel_size, current_pixel_size,
                                  model_wavelength, current_wavelength,
                                  model_NA, current_NA, device='cuda'):
    """GPU-accelerated batch interpolation for multiple frames"""
    # Handle None values
    if model_pixel_size is None: model_pixel_size = current_pixel_size
    if model_wavelength is None: model_wavelength = current_wavelength
    if model_NA is None: model_NA = current_NA
    if current_wavelength is None: current_wavelength = model_wavelength = 1
    if current_NA is None: current_NA = model_NA = 1

    # Calculate scale ratio
    scale_ratio_sq = (0.21 * model_wavelength / model_NA) ** 2 - \
                     (0.21 * current_wavelength / current_NA) ** 2

    if scale_ratio_sq > 0:
        # Gaussian smoothing path
        scale_ratio = np.sqrt(scale_ratio_sq) / model_pixel_size
        kernel = _get_gaussian_kernel(scale_ratio, device)

        # Apply Gaussian filter to all frames at once
        frames_4d = frames_batch.unsqueeze(1)  # (B, 1, H, W)
        padding = kernel.shape[-1] // 2
        interpolated = F.conv2d(frames_4d, kernel, padding=padding).squeeze(1)
    else:
        # Zoom/resize path
        zoom_factor = model_pixel_size / current_pixel_size

        if zoom_factor != 1.0:
            # Use bilinear interpolation on GPU
            new_h = int(frames_batch.shape[1] * zoom_factor)
            new_w = int(frames_batch.shape[2] * zoom_factor)

            frames_4d = frames_batch.unsqueeze(1)
            interpolated = F.interpolate(frames_4d, size=(new_h, new_w),
                                        mode='bicubic', align_corners=False).squeeze(1)
        else:
            interpolated = frames_batch

    return interpolated

def split_image_to_patches_batch(img_batch, num_patches, overlap, device='cuda'):
    """ Split tensor of images into overlapping patches """
    # Handle both 2D and 3D input
    if img_batch.dim() == 2:
        img_batch = img_batch.unsqueeze(0)  # (H, W) -> (1, H, W)

    # Determine the non-overlapping patch size
    B, H, W = img_batch.shape
    patch_h = H // num_patches
    patch_w = W // num_patches

    # Pad image for border patches (reflection padding as in the original)
    padded = F.pad(img_batch.unsqueeze(1), # (B, 1, H, W)
                    (overlap, overlap, overlap, overlap),
                    mode='reflect').squeeze(1) # (B, H+2*overlap, W+2*overlap)

    # Calculate window shape including overlap
    window_h = patch_h + 2 * overlap
    window_w = patch_w + 2 * overlap

    # create sliding windows along height and then along width with the patch_h and patch_w as the step
    patches = padded.unfold(1, window_h, patch_h).unfold(2, window_w, patch_w)
    # Shape: (B, num_patches, num_patches, window_h, window_w)

    # Reshape to (B, num_patches * num_patches, window_h, window_w)
    # Flatten the 2D grid of patches for every frame (row-major order).
    B, num_rows, num_cols, ph, pw = patches.shape
    patches = patches.reshape(B, num_rows * num_cols, ph, pw)

    return patches

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# ============================================================================
# 1. Basic CNN Model (without upsampling)
# ============================================================================

class CNNModel(nn.Module):
    def __init__(self, in_channels=1):
        super(CNNModel, self).__init__()

        # Encoder
        self.features1 = ConvBNReLU(in_channels, 32, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.features2 = ConvBNReLU(32, 64, 3)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.features3 = ConvBNReLU(64, 128, 3)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.features4 = ConvBNReLU(128, 512, 3)

        # Decoder
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features5 = ConvBNReLU(512, 128, 3)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features6 = ConvBNReLU(128, 64, 3)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features7 = ConvBNReLU(64, 32, 3)

        # Prediction head
        self.prediction = nn.Conv2d(32, 1, 1, stride=1, padding=0, bias=False)
        nn.init.orthogonal_(self.prediction.weight)

    def forward(self, x):
        # Encoder
        x = self.features1(x)
        x = self.pool1(x)

        x = self.features2(x)
        x = self.pool2(x)

        x = self.features3(x)
        x = self.pool3(x)

        # Bottleneck
        x = self.features4(x)

        # Decoder
        x = self.upsample1(x)
        x = self.features5(x)

        x = self.upsample2(x)
        x = self.features6(x)

        x = self.upsample3(x)
        x = self.features7(x)

        # Prediction
        x = self.prediction(x)
        return x


# ============================================================================
# 2. CNN Building Blocks - optimized with fused Conv+BN+ReL operations
# ============================================================================

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None):
        super(ConvBNReLU, self).__init__()

        if padding is None:
            padding = kernel_size // 2

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Initialize with Orthogonal (similar to Keras)
        nn.init.orthogonal_(self.conv.weight)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# ============================================================================
# 3. CNN Model with Upsampling - optimized with fused Conv+BN+ReL
# ============================================================================

class CNNUpsample(nn.Module):
    def __init__(self, in_channels=1, upsampling_factor=8):
        super(CNNUpsample, self).__init__()
        self.upsampling_factor = upsampling_factor

        # Encoder with fused blocks
        self.conv_bn_relu1 = ConvBNReLU(in_channels, 32, 3, 1)
        self.conv_bn_relu2 = ConvBNReLU(32, 64, 3, 1)
        self.conv_bn_relu3 = ConvBNReLU(64, 128, 3, 1)
        self.conv_bn_relu4 = ConvBNReLU(128, 256, 3, 1)

        # Decoder with fused blocks
        self.conv_bn_relu5 = ConvBNReLU(256, 128, 3, 1)
        self.conv_bn_relu6 = ConvBNReLU(128, 64, 3, 1)

        # OPTIMIZED: Upsampling blocks with 3x3 kernels + fused Conv+BN+ReLU
        num_upsample_blocks = int(np.log2(upsampling_factor))
        self.upsample_blocks = nn.ModuleList()

        for i in range(num_upsample_blocks):
            in_ch = 64 if i == 0 else 32
            block = nn.ModuleDict({
                'upsample': nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                'conv_bn_relu': ConvBNReLU(in_ch, 32, 5, 1)
            })
            self.upsample_blocks.append(block)

        # Prediction head
        self.prediction = nn.Conv2d(32, 1, 1, stride=1, padding=0, bias=False)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # Encoder
        x = self.conv_bn_relu1(x)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.conv_bn_relu4(x)

        # Decoder
        x = self.conv_bn_relu5(x)
        x = self.conv_bn_relu6(x)

        # Upsampling
        for block in self.upsample_blocks:
            x = block['upsample'](x)
            x = block['conv_bn_relu'](x)

        # Prediction
        x = self.prediction(x)
        return x


# ============================================================================
# 1. Gaussian Filter for Loss Computation
# ============================================================================

def matlab_style_gauss2D(shape=(7, 7), sigma=1):
    """Create 2D Gaussian kernel matching MATLAB style"""
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]
    h = np.exp(-(x*x + y*y) / (2. * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    h = h * 2.0
    return h.astype(np.float32)

# Create Gaussian filter as a tensor
psf_heatmap = matlab_style_gauss2D(shape=(7, 7), sigma=1)
# Shape: [out_channels, in_channels, height, width] -> [1, 1, 7, 7]
gfilter = torch.from_numpy(psf_heatmap).view(1, 1, 7, 7)

# ============================================================================
# 2. Custom Loss Functions
# ============================================================================

class L1L2Loss(nn.Module):
    """Combined L1 + L2 loss with Gaussian filtering"""
    def __init__(self, input_shape):
        super(L1L2Loss, self).__init__()
        self.input_shape = input_shape
        # Register Gaussian filter as buffer (moves with model to GPU)
        self.register_buffer('gfilter', gfilter)

    def forward(self, spikes_pred, heatmap_true):
        # Apply Gaussian convolution to predictions
        heatmap_pred = F.conv2d(spikes_pred, self.gfilter, padding=3)

        # MSE loss on heatmaps
        loss_heatmaps = F.mse_loss(heatmap_pred, heatmap_true)

        # L1 loss on spikes (sparsity)
        loss_spikes = torch.mean(torch.abs(spikes_pred))

        return loss_heatmaps + loss_spikes

class CustomLoss(nn.Module):
    """Custom loss for upsampling model"""
    def __init__(self, input_shape):
        super(CustomLoss, self).__init__()
        self.input_shape = input_shape
        self.register_buffer('gfilter', gfilter)

    def forward(self, y_pred, y_true):
        # Apply Gaussian convolution
        heatmap_pred = F.conv2d(y_pred, self.gfilter, padding=3)

        # MSE on heatmaps
        loss_heatmaps = torch.mean((y_true - heatmap_pred) ** 2)

        # L1 on predictions (sparsity)
        loss_spikes = torch.mean(torch.abs(y_pred))

        return loss_heatmaps + loss_spikes

# ============================================================================
# 3. Maxima Finder Layer (Peak Detection)
# ============================================================================

class MaximaFinder(nn.Module):
    """Find local maxima in predicted density maps"""
    def __init__(self, thresh=0.1, neighborhood_size=3, use_local_avg=False):
        super(MaximaFinder, self).__init__()
        self.thresh = thresh
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

        if use_local_avg:
            # Sobel-like kernels for local averaging
            kernel_x = torch.tensor([[[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_y = torch.tensor([[[-1, -1, -1], [0, 0, 0], [1, 1, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_sum = torch.ones(1, 1, 3, 3, dtype=torch.float32)

            self.register_buffer('kernel_x', kernel_x)
            self.register_buffer('kernel_y', kernel_y)
            self.register_buffer('kernel_sum', kernel_sum)

    def forward(self, inputs):
        # Max pooling to find local maxima
        max_pool = F.max_pool2d(inputs, kernel_size=self.nhood,
                               stride=1, padding=self.nhood//2)

        # Condition: value is local max AND above threshold
        cond = (max_pool > self.thresh) & (max_pool == inputs)

        # Get indices where condition is True
        indices = torch.nonzero(cond, as_tuple=False)  # (N, 4): [batch, channel, y, x]

        bind = indices[:, 0]  # batch indices
        yind = indices[:, 2]  # y coordinates
        xind = indices[:, 3]  # x coordinates

        # Gather confidence values
        confidence = inputs[bind, indices[:, 1], yind, xind]

        # Convert to float for potential subpixel refinement
        xind = xind.float()
        yind = yind.float()

        # Subpixel refinement using local averaging
        if self.use_local_avg:
            # Ensure kernels match input dtype
            kernel_x = self.kernel_x.to(inputs.dtype)
            kernel_y = self.kernel_y.to(inputs.dtype)
            kernel_sum = self.kernel_sum.to(dtype=inputs.dtype)

            # Compute gradients
            # Sobel-like kernels for local averaging
            x_image = F.conv2d(inputs, kernel_x, padding=1)
            y_image = F.conv2d(inputs, kernel_y, padding=1)
            sum_image = F.conv2d(inputs, kernel_sum, padding=1)

            # Gather at detected locations
            gathered_sum = sum_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_x = x_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_y = y_image[bind, indices[:, 1], yind.long(), xind.long()]

            # Compute local offsets
            x_local = gathered_x / (gathered_sum + 1e-6)
            y_local = gathered_y / (gathered_sum + 1e-6)

            # Update positions and confidence
            xind = xind + x_local
            yind = yind + y_local
            confidence = gathered_sum

        return bind, xind, yind, confidence

# ============================================================================
# 6. Maxima Finder Layer (Peak Detection)
# ============================================================================

class MaximaFinder(nn.Module):
    """Find local maxima in predicted density maps"""
    def __init__(self, thresh=0.1, neighborhood_size=3, use_local_avg=False):
        super(MaximaFinder, self).__init__()
        self.thresh = thresh
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

        if use_local_avg:
            # Sobel-like kernels for local averaging
            kernel_x = torch.tensor([[[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_y = torch.tensor([[[-1, -1, -1], [0, 0, 0], [1, 1, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_sum = torch.ones(1, 1, 3, 3, dtype=torch.float32)

            self.register_buffer('kernel_x', kernel_x)
            self.register_buffer('kernel_y', kernel_y)
            self.register_buffer('kernel_sum', kernel_sum)

    def forward(self, inputs):
        # Max pooling to find local maxima
        max_pool = F.max_pool2d(inputs, kernel_size=self.nhood,
                               stride=1, padding=self.nhood//2)

        # Condition: value is local max AND above threshold
        cond = (max_pool > self.thresh) & (max_pool == inputs)

        # Get indices where condition is True
        indices = torch.nonzero(cond, as_tuple=False)  # (N, 4): [batch, channel, y, x]

        bind = indices[:, 0]  # batch indices
        yind = indices[:, 2]  # y coordinates
        xind = indices[:, 3]  # x coordinates

        # Gather confidence values
        confidence = inputs[bind, indices[:, 1], yind, xind]

        # Convert to float for potential subpixel refinement
        xind = xind.float()
        yind = yind.float()

        # Subpixel refinement using local averaging
        if self.use_local_avg:
            # Ensure kernels match input dtype
            kernel_x = self.kernel_x.to(inputs.dtype)
            kernel_y = self.kernel_y.to(inputs.dtype)
            kernel_sum = self.kernel_sum.to(dtype=inputs.dtype)

            # Compute gradients
            x_image = F.conv2d(inputs, kernel_x, padding=1)
            y_image = F.conv2d(inputs, kernel_y, padding=1)
            sum_image = F.conv2d(inputs, kernel_sum, padding=1)

            # Gather at detected locations
            gathered_sum = sum_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_x = x_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_y = y_image[bind, indices[:, 1], yind.long(), xind.long()]

            # Compute local offsets
            x_local = gathered_x / (gathered_sum + 1e-6)
            y_local = gathered_y / (gathered_sum + 1e-6)

            # Update positions and confidence
            xind = xind + x_local
            yind = yind + y_local
            confidence = gathered_sum

        return bind, xind, yind, confidence

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import time
import h5py
import re
import os


# ============================================================================
# 1. Model Builder Function
# ============================================================================

def build_model_upsample(input_shape, lr=0.001, upsampling_factor=8):
    """
    Build upsampling model for PyTorch

    Args:
        input_shape: Tuple (H, W, C) - note: will be converted to (C, H, W)
        lr: Learning rate
        upsampling_factor: Upsampling factor

    Returns:
        model: PyTorch model
        optimizer: Adam optimizer
        criterion: Loss function
    """
    from c_models_and_layers import CNNUpsample, CustomLoss

    # Convert from (H, W, C) to (C, H, W)
    in_channels = input_shape[2] if len(input_shape) == 3 else 1

    model = CNNUpsample(in_channels=in_channels,
                        upsampling_factor=upsampling_factor)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = CustomLoss(input_shape)

    return model, optimizer, criterion


# ============================================================================
# 2. Weight Loading - Support both PyTorch and Keras formats
# ============================================================================

def load_model_weights(model, weights_path, verbose=True):
    """
    Load model weights from either PyTorch (.pth) or Keras (.h5) format

    Args:
        model: PyTorch model
        weights_path: Path to weights file (.pth or .h5)
        verbose: Print loading progress
    """
    if weights_path.endswith('.pth'):
        load_pytorch_weights(model, weights_path, verbose=verbose)
    elif weights_path.endswith('.h5'):
        load_keras_weights_to_pytorch(model, weights_path, verbose=verbose)
    else:
        raise ValueError(f"Unsupported weights format: {weights_path}. "
                         f"Expected .pth or .h5 file")


def load_pytorch_weights(model, pth_path, verbose=True):
    """
    Load PyTorch native weights from .pth file

    Args:
        model: PyTorch model
        pth_path: Path to .pth weights file
        verbose: Print loading progress
    """
    if verbose:
        print(f"Loading PyTorch weights from {pth_path}")

    # Get device from model
    device = next(model.parameters()).device

    # Load state dict
    state_dict = torch.load(pth_path, map_location=device)

    # Load weights into model
    model.load_state_dict(state_dict)

    if verbose:
        print("✓ PyTorch weights loaded successfully!")


def load_keras_weights_to_pytorch(model, h5_path, verbose=True):
    """
    Load Keras weights from H5 file to PyTorch model

    Supports fused Conv+BN+ReLU blocks while maintaining compatibility.

    Args:
        model: PyTorch model (CNNUpsample with fused blocks)
        h5_path: Path to Keras H5 weights file
        verbose: Print loading progress
    """
    if verbose:
        print(f"Loading Keras weights from {h5_path}")

    # Get device from model
    device = next(model.parameters()).device

    with h5py.File(h5_path, 'r') as f:
        # Get all layer names from the H5 file
        if 'model_weights' in f:
            weight_group = f['model_weights']
        else:
            weight_group = f

        # Extract layer names
        if hasattr(weight_group, 'attrs') and 'layer_names' in weight_group.attrs:
            layer_names = [n.decode('utf8') if isinstance(n, bytes) else n
                           for n in weight_group.attrs['layer_names']]
        else:
            layer_names = list(weight_group.keys())

        if verbose:
            print(f"Found {len(layer_names)} layers in H5 file")

        # Create a dictionary to store weights
        keras_weights = {}

        for layer_name in layer_names:
            if layer_name not in weight_group:
                continue

            layer_group = weight_group[layer_name]

            if not hasattr(layer_group, 'keys'):
                continue

            # Get weight names for this layer
            if hasattr(layer_group, 'attrs') and 'weight_names' in layer_group.attrs:
                weight_names = [n.decode('utf8') if isinstance(n, bytes) else n
                                for n in layer_group.attrs['weight_names']]
            else:
                weight_names = list(layer_group.keys())

            # Extract weights
            layer_weights = {}
            for weight_name in weight_names:
                if '/' in weight_name:
                    weight_key = weight_name.split('/')[-1]
                else:
                    weight_key = weight_name

                try:
                    weight_value = layer_group[weight_name][()]
                    layer_weights[weight_key] = weight_value
                except:
                    try:
                        weight_value = layer_group[weight_key][()]
                        layer_weights[weight_key] = weight_value
                    except:
                        if verbose:
                            print(f"  Warning: Could not load {weight_name} from {layer_name}")

            if layer_weights:
                keras_weights[layer_name] = layer_weights

        if verbose:
            print(f"Extracted weights from {len(keras_weights)} layers")

        # Assign to PyTorch model with fused blocks
        _assign_weights_to_model(model, keras_weights, device, verbose=verbose)

    if verbose:
        print("✓ Keras weights loaded successfully!")


def _assign_weights_to_model(model, keras_weights, device, verbose=True):
    """Helper function to assign Keras weights to PyTorch model with fused blocks"""

    # Mapping from Keras layer names to PyTorch fused block names
    name_mapping = {
        'F1': 'conv_bn_relu1',
        'BN_1': 'conv_bn_relu1',
        'F2': 'conv_bn_relu2',
        'BN_2': 'conv_bn_relu2',
        'F3': 'conv_bn_relu3',
        'BN_3': 'conv_bn_relu3',
        'F4': 'conv_bn_relu4',
        'BN_4': 'conv_bn_relu4',
        'F5': 'conv_bn_relu5',
        'BN_5': 'conv_bn_relu5',
        'F6': 'conv_bn_relu6',
        'BN_6': 'conv_bn_relu6',
        'Prediction': 'prediction',
    }

    model_dict = dict(model.named_modules())
    loaded_count = 0

    # Load encoder and decoder layers (now fused blocks)
    for keras_name, pytorch_name in name_mapping.items():
        if keras_name not in keras_weights:
            continue

        if pytorch_name not in model_dict:
            continue

        module = model_dict[pytorch_name]
        weights = keras_weights[keras_name]

        # Check if this is a fused ConvBNReLU block
        if hasattr(module, 'conv') and hasattr(module, 'bn'):
            # This is a fused block - load into conv and bn sub-modules

            # Load Conv2d weights
            if 'kernel:0' in weights:
                kernel = weights['kernel:0']
                kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                module.conv.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                loaded_count += 1
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name}.conv (Conv2d)")

            # Load BatchNorm weights
            if 'gamma:0' in weights:
                module.bn.weight.data = torch.from_numpy(weights['gamma:0']).float().to(device)
            if 'beta:0' in weights:
                module.bn.bias.data = torch.from_numpy(weights['beta:0']).float().to(device)
            if 'moving_mean:0' in weights:
                module.bn.running_mean.data = torch.from_numpy(weights['moving_mean:0']).float().to(device)
            if 'moving_variance:0' in weights:
                module.bn.running_var.data = torch.from_numpy(weights['moving_variance:0']).float().to(device)

            if any(k in weights for k in ['gamma:0', 'beta:0']):
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name}.bn (BatchNorm)")

        # Load prediction layer (not fused)
        elif isinstance(module, nn.Conv2d):
            if 'kernel:0' in weights:
                kernel = weights['kernel:0']
                kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                module.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                loaded_count += 1
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name} (Conv2d)")

            if 'bias:0' in weights and module.bias is not None:
                bias = weights['bias:0']
                module.bias.data = torch.from_numpy(bias).float().to(device)

    # Load upsampling blocks (now with fused conv_bn_relu)
    for keras_name in keras_weights.keys():
        if 'conv_upsample' in keras_name or 'BN_upsample' in keras_name:
            match = re.search(r'(\d+)', keras_name)
            if match:
                idx = int(match.group(1)) - 1

                if idx >= len(model.upsample_blocks):
                    continue

                weights = keras_weights[keras_name]

                if 'conv_upsample' in keras_name:
                    # Access the fused block's conv layer
                    fused_block = model.upsample_blocks[idx]['conv_bn_relu']

                    if 'kernel:0' in weights and hasattr(fused_block, 'conv'):
                        kernel = weights['kernel:0']
                        kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                        fused_block.conv.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                        loaded_count += 1
                        if verbose:
                            print(f"  ✓ Loaded {keras_name} -> upsample_blocks[{idx}]['conv_bn_relu'].conv")

                elif 'BN_upsample' in keras_name:
                    # Access the fused block's bn layer
                    fused_block = model.upsample_blocks[idx]['conv_bn_relu']

                    if hasattr(fused_block, 'bn'):
                        if 'gamma:0' in weights:
                            fused_block.bn.weight.data = torch.from_numpy(weights['gamma:0']).float().to(device)
                        if 'beta:0' in weights:
                            fused_block.bn.bias.data = torch.from_numpy(weights['beta:0']).float().to(device)
                        if 'moving_mean:0' in weights:
                            fused_block.bn.running_mean.data = torch.from_numpy(weights['moving_mean:0']).float().to(
                                device)
                        if 'moving_variance:0' in weights:
                            fused_block.bn.running_var.data = torch.from_numpy(weights['moving_variance:0']).float().to(
                                device)
                        loaded_count += 1
                        if verbose:
                            print(f"  ✓ Loaded {keras_name} -> upsample_blocks[{idx}]['conv_bn_relu'].bn")

    if verbose:
        print(f"\n✓ Successfully loaded {loaded_count} layer weights")


# ============================================================================
# 3. Main Reconstruction Function with Global Profiling
# ============================================================================
def reconstruct_patches_2025_pytorch(
        patches, patch_indices, frame_numbers,
        model_num,
        num_patches, overlap,
        number_of_frames, threshold, neighborhood_size=3,
        use_local_avg=True, upsampling_factor=8,
        pixel_size=233, batch_size=32, L2_weighting_factor=100,
        profiler=None, precision_mode="float32", use_metadata = False):

    pixel_size_hr = pixel_size / upsampling_factor

    # Get device and precision mode from cache
    device = get_device()

    patches = patches.float().to(device)

     # Convert patches to appropriate precision
    if isinstance(patches, np.ndarray):
        patches = torch.from_numpy(patches).float().to(device)
    else:
        patches = patches.to(device)

    # Apply precision conversion
    if precision_mode == 'fp16' and device.type == 'cuda':
        patches = patches.half()
    elif precision_mode == 'fp8' and device.type == 'cuda':
        patches = patches.to(dtype=torch.float8_e4m3fn)

    if patches.ndim == 2:
        patches = patches.unsqueeze(0)  # Ensure 3D shape
    K_frames, M, N = patches.shape

    # Determine dimensions of each predicted (cropped) patch
    upsampled_patch_h = M * upsampling_factor - 2 * overlap
    upsampled_patch_w = N * upsampling_factor - 2 * overlap

    # Create full image tensor on GPU
    #dtype = torch.float16 if (precision_mode == "fp16" and device.type == 'cuda') else torch.float32
    reconstructed_image = torch.zeros((upsampled_patch_h * num_patches, upsampled_patch_w * num_patches),
                                      dtype=torch.float32, device=device)

    # Prepare lists for detections
    recon_xind, recon_yind, frame_index, confidence_list = [], [], [], []

    # Store predicted patches for each input patch
    all_predicted_patches = []

    with torch.cuda.device(0):
        # Get model from cache
        model = get_model(model_num)
        model.eval()

        # Create the post-processing layer
        max_layer = MaximaFinder(threshold, neighborhood_size, use_local_avg).to(device)

        '''
        # Convert maxima finder to appropriate precision
        if precision_mode == 'fp16' and device.type == 'cuda':
            max_layer = max_layer.half()
        elif precision_mode == 'fp8' and device.type == 'cuda':
            max_layer = max_layer.to(dtype=torch.float8_e4m3fn)
        '''

        # Process in batches
        n_batches = int(np.ceil(K_frames / batch_size))

        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(K_frames, start_idx + batch_size)
            nF = end_idx - start_idx

            # --- Move input batch to GPU ---
            batch_imgs = patches[start_idx:end_idx].to(device)  # Shape: (nF, M, N)

            # add channel dim to match conv2D
            batch_imgs = batch_imgs.unsqueeze(1)  # Shape: (nF, 1, M, N)

            # --- Run prediction on GPU ---
            profiler.start_timer("model forward")
            with torch.no_grad():
                if precision_mode == 'fp16' and device.type == 'cuda':
                    # Use automatic mixed precision for FP16
                    with torch.amp.autocast('cuda', dtype=torch.float16):
                        predicted_density = model(batch_imgs)

                elif precision_mode == 'int8' and device.type == 'cuda':
                    # INT8 weight-only quantization doesn't need special autocast
                    # The model handles quantization internally
                    predicted_density = model(batch_imgs)

                else:
                    # Float32 with optional autocast for better performance
                    with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
                        predicted_density = model(batch_imgs)

            profiler.stop_timer("model forward")

            # Post-processing
            predicted_density = torch.relu(predicted_density - 0.5)

            # Crop off extra overlap
            cropped_pred = predicted_density[:, 0, overlap:-overlap, overlap:-overlap]

            profiler.start_timer("localizations detection")
            # --- Post-processing on GPU ---
            # Maxima detection
            bind, xind, yind, conf = max_layer(predicted_density[:, :, overlap:-overlap, overlap:-overlap])

            # Convert tensors to NumPy (only when needed)
            bind_np = bind.cpu().numpy()
            xind_np = xind.cpu().numpy()
            yind_np = yind.cpu().numpy()
            conf_np = conf.cpu().numpy() / L2_weighting_factor
            profiler.stop_timer("localizations detection")

            profiler.start_timer("reconstruction image building")
            # --- Place each patch in reconstructed image ---
            for i in range(nF):
                p_ind = patch_indices[start_idx + i]
                y1 = upsampled_patch_h * (p_ind // num_patches)
                x1 = upsampled_patch_w * (p_ind % num_patches)

                # Use PyTorch addition instead of NumPy
                reconstructed_image[y1:y1 + upsampled_patch_h,
                x1:x1 + upsampled_patch_w].add_(cropped_pred[i] / number_of_frames)

                # Collect detections (CPU operations)
                det_idx = np.where(bind_np == i)[0]
                if det_idx.size:
                    recon_xind.extend((x1 + xind_np[det_idx]).tolist())
                    recon_yind.extend((y1 + yind_np[det_idx]).tolist())
                    frame_index.extend([frame_numbers[start_idx + i] + 1] * det_idx.size)
                    confidence_list.extend(conf_np[det_idx].tolist())

                all_predicted_patches.append(cropped_pred[i].cpu().numpy())

            profiler.stop_timer("reconstruction image building")

    # Convert coordinates to physical units
    xind_final = (np.array(recon_xind) * pixel_size_hr).tolist()
    yind_final = (np.array(recon_yind) * pixel_size_hr).tolist()

    # Return reconstructed image, localizations, and predicted patches
    return reconstructed_image, [frame_index, xind_final, yind_final, confidence_list], all_predicted_patches


# ============================================================================
# 4. Weight Validation Function
# ============================================================================

def validate_model_weights(model, verbose=True):
    """
    Validate that model weights are loaded correctly

    Args:
        model: PyTorch model with loaded weights
        verbose: Print validation details

    Returns:
        bool: True if weights appear valid
    """
    if verbose:
        print("\n" + "=" * 70)
        print("VALIDATING MODEL WEIGHTS")
        print("=" * 70)

    issues = []

    # Check encoder/decoder fused blocks
    for i in range(1, 7):
        block_name = f'conv_bn_relu{i}'
        if hasattr(model, block_name):
            block = getattr(model, block_name)

            # Check conv weights
            conv_weights = block.conv.weight.data
            if torch.all(conv_weights == 0):
                issues.append(f"{block_name}.conv weights are all zeros")
            elif torch.isnan(conv_weights).any():
                issues.append(f"{block_name}.conv weights contain NaN")

            # Check BN parameters
            if torch.all(block.bn.weight.data == 1) and torch.all(block.bn.bias.data == 0):
                issues.append(f"{block_name}.bn parameters are uninitialized (gamma=1, beta=0)")

            if verbose:
                print(f"  {block_name}.conv: shape={tuple(conv_weights.shape)}, "
                      f"mean={conv_weights.mean().item():.6f}, std={conv_weights.std().item():.6f}")
                print(f"  {block_name}.bn: gamma_mean={block.bn.weight.mean().item():.6f}, "
                      f"beta_mean={block.bn.bias.mean().item():.6f}")

    # Check upsampling blocks
    if verbose:
        print(f"\n  Upsampling blocks: {len(model.upsample_blocks)} blocks")

    for idx, block_dict in enumerate(model.upsample_blocks):
        fused_block = block_dict['conv_bn_relu']

        conv_weights = fused_block.conv.weight.data
        expected_kernel_size = 5
        actual_kernel_size = conv_weights.shape[2]

        if actual_kernel_size != expected_kernel_size:
            issues.append(f"upsample_blocks[{idx}] has {actual_kernel_size}x{actual_kernel_size} kernel, "
                          f"expected {expected_kernel_size}x{expected_kernel_size}")

        if torch.all(conv_weights == 0):
            issues.append(f"upsample_blocks[{idx}].conv weights are all zeros")

        if verbose:
            print(f"  upsample_blocks[{idx}].conv: shape={tuple(conv_weights.shape)}, "
                  f"kernel_size={actual_kernel_size}x{actual_kernel_size}, "
                  f"mean={conv_weights.mean().item():.6f}")

    # Check prediction layer
    pred_weights = model.prediction.weight.data
    if torch.all(pred_weights == 0):
        issues.append("prediction layer weights are all zeros")

    if verbose:
        print(f"\n  prediction: shape={tuple(pred_weights.shape)}, "
              f"mean={pred_weights.mean().item():.6f}")

    # Report results
    if verbose:
        print("\n" + "=" * 70)
        if issues:
            print("⚠️ VALIDATION WARNINGS:")
            for issue in issues:
                print(f"  - {issue}")
        else:
            print("✓ ALL WEIGHTS VALIDATED SUCCESSFULLY")
        print("=" * 70)

import time
import torch
from collections import defaultdict
from typing import Dict, List, Optional


class timing_profiler:
    def __init__(self, enabled=False):
        self.enabled = enabled
        self.accu_timing: Dict[str, List[float]] = defaultdict(list)
        self.active_timers: Dict[str, float] = {}

    def start_timer(self, name):
        if not self.enabled:
            return

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        self.active_timers[name] = time.perf_counter()

    def stop_timer(self, name):
        if not self.enabled:
            return

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        run_time = time.perf_counter() - self.active_timers[name]
        self.accu_timing[name].append(run_time)
        del self.active_timers[name]

    def get_stats(self, name):
        times = self.accu_timing[name]
        return {
            'total': sum(times),
            'average': sum(times) / len(times),
            'count': len(times),
            'min': min(times),
            'max': max(times)
        }

    def print_timing_summary(self):
        if not self.enabled:
            return

        #print("\n" + "=" * 100)
        #print("TIMING SUMMARY")
        #print("=" * 100)

        sub_sections_timers = {}

        for name, times in self.accu_timing.items():
            sub_sections_timers[name] = times

        # Print reconstruction section
        self._print_section(sub_sections_timers, "")

    def _print_section(self, timers, prefix):
        # Group by hierarchy level
        hierarchy = {}
        for name, times in timers.items():
            # Remove prefix
            relative_name = name[len(prefix) + 1:] if name.startswith(prefix + '.') else name
            hierarchy[relative_name] = times

        if not hierarchy:
            return

        if 'total' in hierarchy:
            total_time = sum(hierarchy['total'])
        else:
            total_time = sum(sum(times) for times in hierarchy.values())
        print("-" * 84)
        print(f"{'Step':<40} {'total time':<12} {'avg (per call)':<8} {'calls':<8} {'% of total':<12}")
        print("-" * 84)

        # Sort by total time (descending)
        sorted_items = sorted(hierarchy.items(), key=lambda x: sum(x[1]), reverse=True)

        for section_name, times in sorted_items:
            total = sum(times)
            avg_per_call = (total / len(times))
            count = len(times)
            percentage = (total / total_time * 100) if total_time > 0 else 0

            print(f"{section_name:<40} {total:>11.3f} {avg_per_call:>11.2f} {count:>7} {percentage:>10.1f}%")

    def reset(self):
        self.accu_timing.clear()
        self.active_timers.clear()




import torch
import torch.nn as nn
import os
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# ============================================================================
# Global state - simple module-level variables
# ============================================================================

_models = {}  # Dictionary to store loaded models
_device = None  # Device (CPU or CUDA)
_is_initialized = False  # Track if we've loaded models
_precision_mode = "float32"  # Track current precision mode


# ============================================================================
# Simple functions to manage the cache
# ============================================================================

def initialize_model_cache(config, upsampling_factor, device=None,
                           use_pytorch_weights=False, precision_mode="float32"):
    """
    Load all models once and store them in memory

    Args:
        config: Configuration object with model paths and names
        upsampling_factor: Upsampling factor for the models
        device: Device to load models on (CPU or CUDA)
        use_pytorch_weights: If True, load .pth weights. If False, load .h5 weights
        precision_mode: "float32", "fp16", or "int8"
    """
    global _models, _device, _is_initialized, _precision_mode

    # Skip if already initialized
    if _is_initialized:
        #print("⚠️ Models already loaded, skipping initialization")
        return

    # Setup device
    if device is None:
        _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        _device = device

    # Store precision mode
    _precision_mode = precision_mode

    #print("\n" + "=" * 70)
    #print("LOADING ALL MODELS INTO CACHE")
    #print("=" * 70)
    #print(f"Device: {_device}")
    #print(f"Upsampling Factor: {upsampling_factor}")
    #print(f"Weight Format: {'PyTorch (.pth)' if use_pytorch_weights else 'Keras (.h5)'}")
    #print(f"Precision Mode: {precision_mode}")
    #print(f"Number of models: {len(config.model_names)}")

    # Import model classes
    from c_models_and_layers import CNNUpsample
    from d_reconstruction import load_model_weights

    # Determine weight file extension
    weight_extension = 'best_weights.pth' if use_pytorch_weights else 'best_weights.h5'
    weight_type = "PyTorch" if use_pytorch_weights else "Keras"

    # Load each model
    for model_num, model_name in enumerate(config.model_names):
        model_path = os.path.join(
            config.prediction_model_path,
            model_name,
            weight_extension
        )

        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Model weights not found: {model_path}\n"
                f"Expected {weight_type} weights for model: {model_name}"
            )

        #print(f"\nLoading model {model_num + 1}/{len(config.model_names)}: {model_name} ({weight_type} weights)")

        # Create model
        model = CNNUpsample(in_channels=1, upsampling_factor=upsampling_factor)
        model = model.to(_device)
        model.eval()

        # Load weights
        load_model_weights(model, model_path, verbose=False)

        # Apply precision mode
        if precision_mode == "fp16":
            if _device.type == 'cuda':
                model = model.half()  # Convert to FP16
                print(f"→ Converted to FP16 (half precision)")
            #else:
            #    print(f"⚠️ FP16 requested but CUDA not available, using float32")

        elif precision_mode == 'int8':
            if _device.type == 'cuda':
                quantize_(model, Int8WeightOnlyConfig())
                print("→ Converted to INT8 quantization")

        elif precision_mode == "float32":
            # Keep as float32 (default)
            print(f"→ Using float32 (full precision)")

        ## Optimize for inference - channels last
        #if torch.cuda.is_available() and precision_mode in ['float32', 'fp16']:
        #    try:
        #        model = model.to(memory_format=torch.channels_last)
        #    except:
        #        pass

        # Store in cache
        _models[model_num] = model

        # Print memory usage
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024 ** 2
            print(f"✓ Loaded {model_name} (GPU Memory: {allocated:.1f} MB)")
        else:
            print(f"  ✓ Loaded {model_name}")

    _is_initialized = True

    print(f"✓ ALL {len(_models)} models loaded and cached")

    if torch.cuda.is_available():
        total_allocated = torch.cuda.memory_allocated() / 1024 ** 2
        #print(f"Total GPU Memory Used: {total_allocated:.1f} MB")
        #print("=" * 70)


def get_model(model_num):
    """Get a cached model by its number"""
    if not _is_initialized:
        raise RuntimeError("Models not loaded. Call initialize_model_cache() first.")

    if model_num not in _models:
        raise KeyError(f"Model {model_num} not found. Available: {list(_models.keys())}")

    return _models[model_num]


def get_device():
    """Get the device being used (CPU or CUDA)"""
    if _device is None:
        raise RuntimeError("Model cache not initialized. Call initialize_model_cache() first.")

    return _device


def get_precision_mode():
    """Get the current precision mode"""
    return _precision_mode


def clear_cache():
    """Clear all cached models from memory"""
    global _models, _device, _is_initialized, _precision_mode

    print("\n⚠️ Clearing model cache...")

    # Move models to CPU and delete
    for model in _models.values():
        model.cpu()

    _models.clear()
    _device = None
    _is_initialized = False
    _precision_mode = "float32"

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("✓ Model cache cleared")


# ============================================================================
# Compatibility wrapper (optional - for backwards compatibility)
# ============================================================================

class ModelCacheManager:
    """Simple wrapper class for compatibility with existing code"""

    def initialize(self, config, upsampling_factor, device=None,
                   use_pytorch_weights=False, precision_mode="float32"):
        initialize_model_cache(config, upsampling_factor, device,
                               use_pytorch_weights, precision_mode)

    def get_model(self, model_num):
        return get_model(model_num)

    def get_device(self):
        return get_device()

    def get_precision_mode(self):
        return get_precision_mode()

    def clear_cache(self):
        clear_cache()


def get_model_cache():
    """Return a simple manager instance for compatibility"""
    return ModelCacheManager()


import pickle
import gzip
import os
import numpy as np
import torch
from collections import defaultdict
from typing import Optional, Tuple, Dict, Any
import threading
import queue


class MetadataManager:
    """
    Manages patch metadata with efficient asynchronous saving to disk.
    Stores all metadata for future analysis without slowing down processing.
    """

    def __init__(self, save_path: str, filename: str):
        """
        Initialize metadata manager

        Args:
            save_path: Directory to save metadata
            filename: Base filename for metadata files
        """
        self.save_path = save_path
        self.filename = filename
        os.makedirs(save_path, exist_ok=True)

        # In-memory storage (will be written to disk asynchronously)
        self._metadata = defaultdict(dict)
        self._original_frames = {}  # Store original frames for visualization

        # Asynchronous saving queue
        self._save_queue = queue.Queue()
        self._saver_thread = None
        self._stop_saving = threading.Event()

        # Statistics
        self.total_patches = 0

    def add_patch_metadata(self, frame_idx: int, patch_idx: int, metadata: Dict[str, Any]):
        """
        Add metadata for a patch (non-blocking)

        Args:
            frame_idx: Frame index
            patch_idx: Patch index within frame
            metadata: Dictionary containing all patch information:
                - valid_patch: torch.Tensor (valid area of the patch)
                - predicted_patch: np.ndarray (reconstructed patch)
                - curr_mean_noise: float
                - curr_std_noise: float
                - signal_amp: float
                - curr_emitter_density: float
                - difficulty: int (model selection)
        """
        # Store metadata in memory (lightweight operation)
        self._metadata[frame_idx][patch_idx] = metadata
        self.total_patches += 1

    def add_original_frame(self, frame_idx: int, frame_data: np.ndarray):
        """
        Store original frame for future visualization

        Args:
            frame_idx: Frame index
            frame_data: Original frame data (2D numpy array)
        """
        # Store as compressed format to save memory
        self._original_frames[frame_idx] = frame_data.astype(np.float32)

    def save_to_disk_async(self):
        """
        Start asynchronous background thread to save metadata to disk.
        This doesn't block the main processing pipeline.
        """
        if self._saver_thread is None or not self._saver_thread.is_alive():
            self._stop_saving.clear()
            self._saver_thread = threading.Thread(target=self._background_saver, daemon=True)
            self._saver_thread.start()

    def _background_saver(self):
        """Background thread function to save data to disk"""
        while not self._stop_saving.is_set():
            try:
                # Get data from queue with timeout
                save_task = self._save_queue.get(timeout=1.0)

                if save_task is None:  # Poison pill to stop thread
                    break

                # Perform the actual save operation
                filepath, data = save_task
                with gzip.open(filepath, 'wb', compresslevel=6) as f:
                    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

                self._save_queue.task_done()

            except queue.Empty:
                continue

    def save_all_metadata(self, wait_for_completion: bool = True):
        """
        Save all metadata to disk (can be done asynchronously)

        Args:
            wait_for_completion: If True, blocks until save is complete
        """
        base_name = os.path.splitext(self.filename)[0]

        # Prepare data for saving
        metadata_path = os.path.join(self.save_path, f'metadata_{base_name}.pkl.gz')
        frames_path = os.path.join(self.save_path, f'original_frames_{base_name}.pkl.gz')

        # Convert metadata to serializable format
        serializable_metadata = {}
        for frame_idx, patches in self._metadata.items():
            serializable_metadata[frame_idx] = {}
            for patch_idx, data in patches.items():
                # Convert tensors to numpy for serialization
                serialized_data = {}
                for key, value in data.items():
                    if isinstance(value, torch.Tensor):
                        serialized_data[key] = value.cpu().numpy()
                    else:
                        serialized_data[key] = value
                serializable_metadata[frame_idx][patch_idx] = serialized_data

        if wait_for_completion:
            # Synchronous save (blocking)
            print(f"\n start saving metadata:")
            with gzip.open(metadata_path, 'wb', compresslevel=6) as f:
                pickle.dump(serializable_metadata, f, protocol=pickle.HIGHEST_PROTOCOL)

            with gzip.open(frames_path, 'wb', compresslevel=6) as f:
                pickle.dump(self._original_frames, f, protocol=pickle.HIGHEST_PROTOCOL)

            print(f"✓ metadata saved: {metadata_path}")
            print(f"✓ original frames saved: {frames_path}")
        else:
            # Asynchronous save (non-blocking)
            self._save_queue.put((metadata_path, serializable_metadata))
            self._save_queue.put((frames_path, self._original_frames))

    def finalize(self):
        """Wait for all async saves to complete and cleanup"""
        if self._saver_thread is not None and self._saver_thread.is_alive():
            self._save_queue.put(None)  # Poison pill
            self._saver_thread.join(timeout=30)
            self._stop_saving.set()

    def clear_memory(self):
        """Clear in-memory data after saving to disk"""
        self._metadata.clear()
        self._original_frames.clear()
        self.total_patches = 0

    @staticmethod
    def load_metadata(save_path: str, filename: str) -> Tuple[Dict, Dict]:
        """
        Load saved metadata from disk

        Args:
            save_path: Directory containing metadata
            filename: Base filename

        Returns:
            metadata: Dictionary of all patch metadata
            original_frames: Dictionary of original frames
        """
        base_name = os.path.splitext(filename)[0]
        metadata_path = os.path.join(save_path, f'metadata_{base_name}.pkl.gz')
        frames_path = os.path.join(save_path, f'original_frames_{base_name}.pkl.gz')

        print(f"Loading metadata from {metadata_path}...")
        with gzip.open(metadata_path, 'rb') as f:
            metadata = pickle.load(f)

        print(f"Loading original frames from {frames_path}...")
        with gzip.open(frames_path, 'rb') as f:
            original_frames = pickle.load(f)

        print("✓ Metadata loaded successfully")
        return metadata, original_frames


class FrameReconstructor:
    """
    Reconstructs individual frames and patches from saved metadata.
    Only computes reconstructions when requested (lazy evaluation).
    """

    def __init__(self, metadata: Dict, original_frames: Dict,
                 num_patches: int, upsampling_factor: int):
        """
        Initialize frame reconstructor

        Args:
            metadata: Loaded metadata dictionary
            original_frames: Loaded original frames dictionary
            num_patches: Number of patches per dimension
            upsampling_factor: Upsampling factor used in reconstruction
        """
        self.metadata = metadata
        self.original_frames = original_frames
        self.num_patches = num_patches
        self.upsampling_factor = upsampling_factor

    def reconstruct_frame(self, frame_idx: int, overlap: int = 0,
                          save_path: Optional[str] = None) -> np.ndarray:
        """
        Reconstruct a complete frame from individual patch predictions

        Args:
            frame_idx: Frame index to reconstruct
            overlap: Overlap size used during patching
            save_path: If provided, save as TIFF file

        Returns:
            reconstructed_frame: Full reconstructed frame
        """
        if frame_idx not in self.metadata:
            raise ValueError(f"Frame {frame_idx} not found in metadata")

        patches_dict = self.metadata[frame_idx]

        if not patches_dict:
            raise ValueError(f"No patches found for frame {frame_idx}")

        # Get dimensions from first patch
        first_patch = list(patches_dict.values())[0]['predicted_patch']
        patch_h, patch_w = first_patch.shape

        # Calculate full frame dimensions
        frame_h = patch_h * self.num_patches
        frame_w = patch_w * self.num_patches

        # Initialize reconstruction
        reconstructed = np.zeros((frame_h, frame_w), dtype=np.float32)

        # Place each patch
        for patch_idx, patch_data in patches_dict.items():
            if patch_data['predicted_patch'] is None:
                continue

            predicted_patch = patch_data['predicted_patch']

            # Calculate position in full frame
            row = patch_idx // self.num_patches
            col = patch_idx % self.num_patches

            y1 = row * patch_h
            x1 = col * patch_w

            # Place patch (handle potential size mismatches)
            y2 = min(y1 + predicted_patch.shape[0], frame_h)
            x2 = min(x1 + predicted_patch.shape[1], frame_w)

            reconstructed[y1:y2, x1:x2] = predicted_patch[:y2 - y1, :x2 - x1]

        # Save if requested
        if save_path is not None:
            from a_file_loader import saveAsTIF
            pixel_size = 233 / self.upsampling_factor  # Default from config
            saveAsTIF(save_path, f'reconstructed_frame_{frame_idx}',
                      reconstructed, pixel_size)
            print(f"✓ Saved reconstructed frame {frame_idx} to {save_path}")

        return reconstructed

    def get_original_frame(self, frame_idx: int,
                           save_path: Optional[str] = None) -> np.ndarray:
        """
        Get original (pre-reconstruction) frame

        Args:
            frame_idx: Frame index
            save_path: If provided, save as TIFF file

        Returns:
            original_frame: Original frame data
        """
        if frame_idx not in self.original_frames:
            raise ValueError(f"Original frame {frame_idx} not found")

        original = self.original_frames[frame_idx]

        # Save if requested
        if save_path is not None:
            from a_file_loader import saveAsTIF
            pixel_size = 233  # Default from config
            saveAsTIF(save_path, f'original_frame_{frame_idx}',
                      original, pixel_size)
            print(f"✓ Saved original frame {frame_idx} to {save_path}")

        return original

    def get_patch_data(self, frame_idx: int, patch_idx: int) -> Dict[str, Any]:
        """
        Get all data for a specific patch

        Args:
            frame_idx: Frame index
            patch_idx: Patch index

        Returns:
            patch_data: Dictionary with all patch information
        """
        if frame_idx not in self.metadata:
            raise ValueError(f"Frame {frame_idx} not found")

        if patch_idx not in self.metadata[frame_idx]:
            raise ValueError(f"Patch {patch_idx} not found in frame {frame_idx}")

        return self.metadata[frame_idx][patch_idx]

    def compare_frames(self, frame_idx: int, save_path: Optional[str] = None):
        """
        Create side-by-side comparison of original and reconstructed frame

        Args:
            frame_idx: Frame index
            save_path: If provided, save comparison figure
        """
        import matplotlib.pyplot as plt

        original = self.get_original_frame(frame_idx)
        reconstructed = self.reconstruct_frame(frame_idx)

        fig, axes = plt.subplots(1, 2, figsize=(16, 8))

        axes[0].imshow(original, cmap='gray')
        axes[0].set_title(f'Original Frame {frame_idx}', fontsize=14)
        axes[0].axis('off')

        axes[1].imshow(reconstructed, cmap='gray')
        axes[1].set_title(f'Reconstructed Frame {frame_idx}', fontsize=14)
        axes[1].axis('off')

        plt.tight_layout()

        if save_path is not None:
            plt.savefig(os.path.join(save_path, f'comparison_frame_{frame_idx}.png'),
                        dpi=150, bbox_inches='tight')
            print(f"✓ Saved comparison for frame {frame_idx}")

        plt.close()


# ============================================================================
# Convenience Functions
# ============================================================================

def visualize_frame_from_file(metadata_path: str, filename: str,
                              frame_idx: int, num_patches: int = 8,
                              upsampling_factor: int = 8,
                              save_flag: bool = False,
                              output_path: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load metadata and visualize a specific frame

    Args:
        metadata_path: Path to metadata directory
        filename: Original filename
        frame_idx: Frame to visualize
        num_patches: Number of patches per dimension
        upsampling_factor: Upsampling factor
        save_flag: Whether to save output
        output_path: Where to save (if save_flag=True)

    Returns:
        original_frame: Original frame data
        reconstructed_frame: Reconstructed frame data
    """
    # Load metadata
    metadata, original_frames = MetadataManager.load_metadata(metadata_path, filename)

    # Create reconstructor
    reconstructor = FrameReconstructor(metadata, original_frames,
                                       num_patches, upsampling_factor)

    # Get frames
    original = reconstructor.get_original_frame(
        frame_idx,
        save_path=output_path if save_flag else None
    )

    reconstructed = reconstructor.reconstruct_frame(
        frame_idx,
        save_path=output_path if save_flag else None
    )

    # Create comparison if saving
    if save_flag and output_path is not None:
        reconstructor.compare_frames(frame_idx, output_path)

    return original, reconstructed

import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess
from typing import Optional, List, Tuple
from collections import defaultdict


def compare_trio_finals(image_path1: str, image_path2: str, image_path3: str,
                   heading1: str, heading2: str, heading3: str,
                   result_path: str):
    """
    Compare three final images side by side with headings

    Args:
        image_path1: Path to first image
        image_path2: Path to second image
        image_path3: Path to third image
        heading1: Heading for first image
        heading2: Heading for second image
        heading3: Heading for third image
        result_path: Path where to save the comparison figure
    """
    # Read all three images
    from PIL import Image
    img1 = np.array(Image.open(image_path1))
    img2 = np.array(Image.open(image_path2))
    img3 = np.array(Image.open(image_path3))

    # Create figure with 3 subplots side by side
    fig, axes = plt.subplots(1, 3, figsize=(21, 7))

    # Display first image
    axes[0].imshow(img1, cmap='gray' if len(img1.shape) == 2 else None)
    axes[0].set_title(heading1, fontsize=14, fontweight='bold')
    axes[0].axis('off')

    # Display second image
    axes[1].imshow(img2, cmap='gray' if len(img2.shape) == 2 else None)
    axes[1].set_title(heading2, fontsize=14, fontweight='bold')
    axes[1].axis('off')

    # Display third image
    axes[2].imshow(img3, cmap='gray' if len(img3.shape) == 2 else None)
    axes[2].set_title(heading3, fontsize=14, fontweight='bold')
    axes[2].axis('off')

    # Adjust spacing between subplots
    plt.tight_layout(pad=3.0)

    # Save the figure
    plt.savefig(result_path, dpi=150, bbox_inches='tight')

    plt.show()


def analyze_patch_statistics(result_folder: str, filename: str,
                             frame_idx: Optional[int] = None):
    """
    Analyze statistics of patches (difficulty distribution, SNR, etc.)

    Args:
        result_folder: Path to results folder
        filename: Original filename
        frame_idx: If provided, analyze only this frame. Otherwise, all frames.
    """
    print(f"\n{'=' * 70}")
    print(f"Analyzing patch statistics for {filename}")
    print(f"{'=' * 70}")

    # Load metadata
    metadata, _ = MetadataManager.load_metadata(result_folder, filename)

    # Collect statistics
    difficulties = []
    snr_values = []
    densities = []

    frames_to_analyze = [frame_idx] if frame_idx is not None else metadata.keys()

    for fid in frames_to_analyze:
        if fid not in metadata:
            continue

        for patch_data in metadata[fid].values():
            difficulties.append(patch_data['difficulty'])

            snr = patch_data['signal_amp'] / (patch_data['curr_mean_noise'] + 1e-10)
            snr_values.append(snr)

            densities.append(patch_data['curr_emitter_density'])

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Difficulty distribution
    unique, counts = np.unique(difficulties, return_counts=True)
    axes[0, 0].bar(unique, counts, color='steelblue', edgecolor='black')
    axes[0, 0].set_xlabel('Difficulty Level (Model)')
    axes[0, 0].set_ylabel('Number of Patches')
    axes[0, 0].set_title('Model Selection Distribution')
    axes[0, 0].grid(axis='y', alpha=0.3)

    # SNR distribution
    axes[0, 1].hist(snr_values, bins=50, color='coral', edgecolor='black', alpha=0.7)
    axes[0, 1].set_xlabel('Signal-to-Noise Ratio')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('SNR Distribution')
    axes[0, 1].axvline(np.median(snr_values), color='red', linestyle='--',
                       label=f'Median: {np.median(snr_values):.2f}')
    axes[0, 1].legend()
    axes[0, 1].grid(axis='y', alpha=0.3)

    # Density distribution
    axes[1, 0].hist(densities, bins=50, color='lightgreen', edgecolor='black', alpha=0.7)
    axes[1, 0].set_xlabel('Emitter Density (per μm²)')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Emitter Density Distribution')
    axes[1, 0].axvline(np.median(densities), color='darkgreen', linestyle='--',
                       label=f'Median: {np.median(densities):.2f}')
    axes[1, 0].legend()
    axes[1, 0].grid(axis='y', alpha=0.3)

    # Difficulty vs SNR scatter
    axes[1, 1].scatter(snr_values, difficulties, alpha=0.3, s=10, color='purple')
    axes[1, 1].set_xlabel('Signal-to-Noise Ratio')
    axes[1, 1].set_ylabel('Difficulty Level')
    axes[1, 1].set_title('Difficulty vs SNR')
    axes[1, 1].grid(alpha=0.3)

    plt.tight_layout()

    output_path = os.path.join(result_folder, f'patch_statistics_{os.path.splitext(filename)[0]}.png')
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved statistics to: {output_path}")

    plt.show()

    # Print summary
    print(f"\n{'=' * 70}")
    print(f"Total patches analyzed: {len(difficulties)}")
    print(f"\nDifficulty distribution:")
    for diff, count in zip(unique, counts):
        print(f"  Model {diff}: {count} patches ({100 * count / len(difficulties):.1f}%)")
    print(f"\nSNR statistics:")
    print(f"  Mean: {np.mean(snr_values):.2f}")
    print(f"  Median: {np.median(snr_values):.2f}")
    print(f"  Range: [{np.min(snr_values):.2f}, {np.max(snr_values):.2f}]")
    print(f"\nDensity statistics:")
    print(f"  Mean: {np.mean(densities):.2f} per μm²")
    print(f"  Median: {np.median(densities):.2f} per μm²")
    print(f"  Range: [{np.min(densities):.2f}, {np.max(densities):.2f}]")
    print(f"{'=' * 70}\n")  # ============================================================================


# Visualization Utilities - Easy-to-use functions for analyzing saved metadata
# ============================================================================

def patch_consistency_heatmap(result_folder: str, filename: str,
                              num_patches: int = 8,
                              model_names: List[str] = ['diff_1', 'diff_2', 'diff_3', 'diff_4'],
                              visualization_folder: str = "visualizations"):
    """
    Create heatmap showing consistency of patch-to-model assignments across all frames

    For each patch position, shows:
    - Which model was most frequently selected
    - What percentage of frames used that model (consistency %)

    Args:
        result_folder: Path to results folder
        filename: Original filename
        num_patches: Number of patches per dimension (creates num_patches x num_patches grid)
        model_names: List of model names
        visualization_folder: Subfolder name for visualizations
    """
    print(f"\n{'=' * 70}")
    print(f"Analyzing patch consistency across all frames")
    print(f"{'=' * 70}")

    # Create visualization folder
    vis_path = os.path.join(result_folder, visualization_folder)
    os.makedirs(vis_path, exist_ok=True)

    # Load metadata
    metadata, _ = MetadataManager.load_metadata(result_folder, filename)

    total_patches = num_patches * num_patches
    num_models = len(model_names)

    # Count model assignments for each patch position across all frames
    patch_model_counts = defaultdict(lambda: np.zeros(num_models, dtype=int))

    for frame_idx, patches in metadata.items():
        for patch_idx, patch_data in patches.items():
            difficulty = patch_data['difficulty']
            patch_model_counts[patch_idx][difficulty] += 1

    # Calculate consistency metrics
    dominant_models = np.zeros((num_patches, num_patches), dtype=int)
    consistency_percentages = np.zeros((num_patches, num_patches), dtype=float)

    for patch_idx in range(total_patches):
        row = patch_idx // num_patches
        col = patch_idx % num_patches

        counts = patch_model_counts[patch_idx]
        total_count = counts.sum()

        if total_count > 0:
            dominant_model = np.argmax(counts)
            consistency = (counts[dominant_model] / total_count) * 100

            dominant_models[row, col] = dominant_model
            consistency_percentages[row, col] = consistency

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # Left: Dominant model per patch
    from matplotlib.colors import ListedColormap
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'][:num_models]
    cmap = ListedColormap(colors)

    im1 = axes[0].imshow(dominant_models, cmap=cmap, vmin=0, vmax=num_models - 1)
    axes[0].set_title('Dominant Model per Patch Position', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Patch Column', fontsize=11)
    axes[0].set_ylabel('Patch Row', fontsize=11)

    # Add grid
    for i in range(num_patches + 1):
        axes[0].axhline(y=i - 0.5, color='white', linewidth=2)
        axes[0].axvline(x=i - 0.5, color='white', linewidth=2)

    # Add annotations
    for i in range(num_patches):
        for j in range(num_patches):
            patch_idx = i * num_patches + j
            model_idx = dominant_models[i, j]
            consistency = consistency_percentages[i, j]

            axes[0].text(j, i, f'{model_names[model_idx]}\n{consistency:.0f}%',
                         ha='center', va='center', fontsize=8,
                         color='white', fontweight='bold')

    # Colorbar
    cbar1 = plt.colorbar(im1, ax=axes[0], ticks=range(num_models))
    cbar1.set_ticklabels(model_names)
    cbar1.set_label('Model', rotation=270, labelpad=20, fontsize=11)

    # Right: Consistency percentage heatmap
    im2 = axes[1].imshow(consistency_percentages, cmap='RdYlGn', vmin=0, vmax=100)
    axes[1].set_title('Consistency Percentage', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Patch Column', fontsize=11)
    axes[1].set_ylabel('Patch Row', fontsize=11)

    # Add grid
    for i in range(num_patches + 1):
        axes[1].axhline(y=i - 0.5, color='gray', linewidth=1)
        axes[1].axvline(x=i - 0.5, color='gray', linewidth=1)

    # Add annotations
    for i in range(num_patches):
        for j in range(num_patches):
            pct = consistency_percentages[i, j]
            text_color = 'black' if pct > 50 else 'white'
            axes[1].text(j, i, f'{pct:.0f}%',
                         ha='center', va='center', fontsize=9,
                         color=text_color, fontweight='bold')

    # Colorbar
    cbar2 = plt.colorbar(im2, ax=axes[1])
    cbar2.set_label('Consistency (%)', rotation=270, labelpad=20, fontsize=11)

    plt.tight_layout()

    output_path = os.path.join(vis_path, 'patch_consistency_heatmap.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved to: {output_path}")

    plt.show()

    # Print statistics
    print(f"\n{'=' * 70}")
    print(f"Patch Consistency Statistics")
    print(f"{'=' * 70}")

    avg_consistency = np.mean(consistency_percentages)
    print(f"\nOverall average consistency: {avg_consistency:.1f}%")

    # Find most and least consistent patches
    flat_idx = np.argsort(consistency_percentages.ravel())

    print(f"\nMost consistent patches (always use same model):")
    for i in range(min(5, len(flat_idx))):
        idx = flat_idx[-(i + 1)]
        row, col = idx // num_patches, idx % num_patches
        patch_idx = row * num_patches + col
        consistency = consistency_percentages[row, col]
        model = model_names[dominant_models[row, col]]
        print(f"  Patch ({row}, {col}) [idx={patch_idx}]: {model} - {consistency:.1f}%")

    print(f"\nLeast consistent patches (split between models):")
    for i in range(min(5, len(flat_idx))):
        idx = flat_idx[i]
        row, col = idx // num_patches, idx % num_patches
        patch_idx = row * num_patches + col
        consistency = consistency_percentages[row, col]
        model = model_names[dominant_models[row, col]]

        # Show distribution
        counts = patch_model_counts[patch_idx]
        total = counts.sum()
        print(f"  Patch ({row}, {col}) [idx={patch_idx}]: {model} - {consistency:.1f}%")
        for m_idx, count in enumerate(counts):
            if count > 0:
                pct = (count / total) * 100
                print(f"    {model_names[m_idx]}: {pct:.1f}%")

    print(f"{'=' * 70}\n")


def compare_patches_across_frames(result_folder: str, filename: str,
                                  patch_idx: int,
                                  frame_indices: List[int],
                                  num_patches: int = 8,
                                  upsampling_factor: int = 8,
                                  visualization_folder: str = "visualizations"):
    """
    Compare a specific patch position across multiple frames

    Shows how the same spatial location evolves across different frames

    Args:
        result_folder: Path to results folder
        filename: Original filename
        patch_idx: Patch index to track (0 to num_patches^2 - 1)
        frame_indices: List of frame indices to compare
        num_patches: Number of patches per dimension
        upsampling_factor: Upsampling factor
        visualization_folder: Subfolder name for visualizations
    """
    print(f"\n{'=' * 70}")
    print(f"Comparing patch {patch_idx} across {len(frame_indices)} frames")
    print(f"{'=' * 70}")

    # Create visualization folder
    vis_path = os.path.join(result_folder, visualization_folder)
    os.makedirs(vis_path, exist_ok=True)

    # Load metadata
    metadata, _ = MetadataManager.load_metadata(result_folder, filename)

    # Calculate patch position
    row = patch_idx // num_patches
    col = patch_idx % num_patches
    print(f"Patch position: Row {row}, Column {col}")

    # Collect patch data
    n_frames = len(frame_indices)
    fig, axes = plt.subplots(2, n_frames, figsize=(4 * n_frames, 8))

    if n_frames == 1:
        axes = axes.reshape(-1, 1)

    for i, frame_idx in enumerate(frame_indices):
        if frame_idx not in metadata:
            print(f"Warning: Frame {frame_idx} not found in metadata")
            continue

        if patch_idx not in metadata[frame_idx]:
            print(f"Warning: Patch {patch_idx} not found in frame {frame_idx}")
            continue

        patch_data = metadata[frame_idx][patch_idx]

        # Get original patch (from full_patch tensor)
        original_patch = patch_data['full_patch']
        if isinstance(original_patch, np.ndarray):
            pass  # Already numpy
        else:
            original_patch = original_patch.numpy() if hasattr(original_patch, 'numpy') else np.array(original_patch)

        # Get predicted patch
        predicted_patch = patch_data['predicted_patch']
        if predicted_patch is None:
            print(f"Warning: No prediction for patch {patch_idx} in frame {frame_idx}")
            continue

        # Display original patch
        axes[0, i].imshow(original_patch, cmap='gray')
        axes[0, i].set_title(f'Frame {frame_idx}\nOriginal', fontsize=10)
        axes[0, i].axis('off')

        # Display predicted patch
        vmin, vmax = np.percentile(predicted_patch, [1, 99])
        axes[1, i].imshow(np.clip(predicted_patch, vmin, vmax), cmap='hot')

        # Add metadata to title
        difficulty = patch_data['difficulty']
        snr = patch_data['signal_amp'] / (patch_data['curr_mean_noise'] + 1e-10)
        density = patch_data['curr_emitter_density']

        axes[1, i].set_title(f'Reconstructed\nModel: {difficulty} | SNR: {snr:.1f}\nDensity: {density:.1f}',
                             fontsize=9)
        axes[1, i].axis('off')

    plt.tight_layout()

    output_path = os.path.join(vis_path, f'patch_{patch_idx}_comparison.png')
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved to: {output_path}")

    plt.show()

    # Print statistics
    print(f"\n{'=' * 70}")
    print(f"Patch {patch_idx} Statistics Across Frames")
    print(f"{'=' * 70}")

    difficulties = []
    snrs = []
    densities = []

    for frame_idx in frame_indices:
        if frame_idx in metadata and patch_idx in metadata[frame_idx]:
            patch_data = metadata[frame_idx][patch_idx]
            difficulties.append(patch_data['difficulty'])
            snr = patch_data['signal_amp'] / (patch_data['curr_mean_noise'] + 1e-10)
            snrs.append(snr)
            densities.append(patch_data['curr_emitter_density'])

    print(f"\nModel selection:")
    unique, counts = np.unique(difficulties, return_counts=True)
    for model, count in zip(unique, counts):
        print(f"  Model {model}: {count} times ({100 * count / len(difficulties):.1f}%)")

    print(f"\nSNR range: [{np.min(snrs):.2f}, {np.max(snrs):.2f}], mean: {np.mean(snrs):.2f}")
    print(f"Density range: [{np.min(densities):.2f}, {np.max(densities):.2f}], mean: {np.mean(densities):.2f}")
    print(f"{'=' * 70}\n")


def quick_visualize_frame(result_folder: str, filename: str, frame_idx: int,
                          num_patches: int = 8, upsampling_factor: int = 8,
                          save_output: bool = True, visualization_folder: str = "visualizations"):
    """
    Quick visualization of a single frame (original vs reconstructed)

    Args:
        result_folder: Path to results folder containing metadata
        filename: Original filename (e.g., 'data.tif')
        frame_idx: Frame index to visualize
        num_patches: Number of patches per dimension (default: 8)
        upsampling_factor: Upsampling factor (default: 8)
        save_output: Whether to save the visualization
        visualization_folder: Subfolder name for visualizations (default: "visualizations")
    """
    print(f"\n{'=' * 70}")
    print(f"Loading frame {frame_idx} from {filename}")
    print(f"{'=' * 70}")

    # Create visualization folder
    vis_path = os.path.join(result_folder, visualization_folder)
    os.makedirs(vis_path, exist_ok=True)

    # Load metadata
    metadata, original_frames = MetadataManager.load_metadata(result_folder, filename)

    # Create reconstructor
    reconstructor = FrameReconstructor(metadata, original_frames,
                                       num_patches, upsampling_factor)

    # Get frames
    original = reconstructor.get_original_frame(frame_idx)
    reconstructed = reconstructor.reconstruct_frame(frame_idx)

    # Create visualization
    fig, axes = plt.subplots(2, 5, figsize=(30, 12))

    # Original
    im0 = axes[0,0].imshow(original, cmap='gray')
    axes[0,0].set_title(f'Original Frame {frame_idx}', fontsize=14, fontweight='bold')
    axes[0,0].axis('off')
    fig.colorbar(im0, ax=axes[0,0])

    # Histogram of non-zero values for original
    nonzero_vals_original = original[original > 0]
    axes[1,0].hist(nonzero_vals_original, bins=50, color='orange', edgecolor='black', alpha=0.7)
    axes[1,0].set_title(f'Original Values Histogram', fontsize=14, fontweight='bold')
    axes[1,0].set_xlabel('Pixel Value')
    axes[1,0].set_ylabel('Frequency')
    axes[1,0].grid(True, alpha=0.3)

    # Reconstructed
    num_nonzero = np.count_nonzero(reconstructed)
    total_pixels = reconstructed.size
    percentage_nonzero = (num_nonzero / total_pixels) * 100

    im1 = axes[0,1].imshow(reconstructed, cmap='hot')
    axes[0,1].set_title(f'Reconstructed Frame {frame_idx} | Non-zero: {num_nonzero} ({percentage_nonzero:.2f}%)',
                      fontsize=14, fontweight='bold')
    axes[0,1].axis('off')
    fig.colorbar(im1, ax=axes[0, 1])

    # Add histogram of non-zero values
    nonzero_vals = reconstructed[reconstructed > 0]
    axes[1,1].hist(nonzero_vals, bins=50, color='orange', edgecolor='black', alpha=0.7)
    axes[1,1].set_title(f'Reconstructed Values Histogram', fontsize=14, fontweight='bold')
    axes[1,1].set_xlabel('Pixel Value')
    axes[1,1].set_ylabel('Frequency')
    axes[1,1].grid(True, alpha=0.3)

    # Normalized with gamma correction
    gamma = 0.25  # Adjust between 0.3-0.7 for different brightening levels
    reconstructed_normalized = reconstructed / reconstructed.max()  # Normalize to [0, 1]
    reconstructed_gamma = np.power(reconstructed_normalized, gamma) * reconstructed.max()  # Apply gamma and scale back

    im2 = axes[0,2].imshow(reconstructed_gamma, cmap='hot')
    axes[0,2].set_title(f'Normalized (γ={gamma})', fontsize=14, fontweight='bold')
    axes[0,2].axis('off')
    fig.colorbar(im2, ax=axes[0,2])

    # Add histogram of non-zero values for gamma-corrected
    nonzero_vals_gamma = reconstructed_gamma[reconstructed_gamma > 0]
    axes[1,2].hist(nonzero_vals_gamma, bins=50, color='orange', edgecolor='black', alpha=0.7)
    axes[1,2].set_title(f'Gamma-corrected Histogram', fontsize=14, fontweight='bold')
    axes[1,2].set_xlabel('Pixel Value')
    axes[1,2].set_ylabel('Frequency')
    axes[1,2].grid(True, alpha=0.3)

    # Normalized
    vmin, vmax = np.percentile(reconstructed, [0, 99.9])
    im4 = axes[0,3].imshow(np.clip(reconstructed, vmin, vmax), cmap='hot')
    axes[0,3].set_title(f'Normalized (1-99%)', fontsize=14, fontweight='bold')
    axes[0,3].axis('off')
    fig.colorbar(im4, ax=axes[0,3])
    plt.tight_layout()

    # Add histogram of non-zero values for reconstruction percentile correction
    reconstructed_clipped = np.clip(reconstructed, vmin, vmax)
    nonzero_vals_clipped = reconstructed_clipped[reconstructed_clipped > 0]

    axes[1,3].hist(nonzero_vals_clipped, bins=50, color='orange', edgecolor='black', alpha=0.7)
    axes[1,3].set_title(f'percentile corrected Histogram', fontsize=14, fontweight='bold')
    axes[1,3].set_xlabel('Pixel Value')
    axes[1,3].set_ylabel('Frequency')
    axes[1,3].grid(True, alpha=0.3)

    # NEW: Classic Min-Max Normalization
    reconstructed_min = reconstructed.min()
    reconstructed_max = reconstructed.max()
    reconstructed_classic = (reconstructed - reconstructed_min) / (reconstructed_max - reconstructed_min)
    reconstructed_classic = reconstructed_classic * 255  # Scale to 0-255 range

    im4 = axes[0, 4].imshow(reconstructed_classic, cmap='hot')
    axes[0, 4].set_title(f'Classic Min-Max Normalized', fontsize=14, fontweight='bold')
    axes[0, 4].axis('off')
    fig.colorbar(im4, ax=axes[0, 4])

    # Add histogram of non-zero values for classic normalization
    nonzero_vals_classic = reconstructed_classic[reconstructed_classic > 0]
    axes[1, 4].hist(nonzero_vals_classic, bins=50, color='orange', edgecolor='black', alpha=0.7)
    axes[1, 4].set_title(f'Classic Normalized Histogram', fontsize=14, fontweight='bold')
    axes[1, 4].set_xlabel('Pixel Value')
    axes[1, 4].set_ylabel('Frequency')
    axes[1, 4].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_output:
        output_path = os.path.join(vis_path, f'frame_{frame_idx}_comparison.png')
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"✓ Saved to: {output_path}")

    plt.show()

    print(f"\n{'=' * 70}")
    print(f"Original shape: {original.shape}")
    print(f"Reconstructed shape: {reconstructed.shape}")
    print(f"Reconstruction range: [{reconstructed.min():.2f}, {reconstructed.max():.2f}]")
    print(f"{'=' * 70}\n")


def analyze_patch_statistics(result_folder: str, filename: str,
                             frame_idx: Optional[int] = None):
    """
    Analyze statistics of patches (difficulty distribution, SNR, etc.)

    Args:
        result_folder: Path to results folder
        filename: Original filename
        frame_idx: If provided, analyze only this frame. Otherwise, all frames.
    """
    print(f"\n{'=' * 70}")
    print(f"Analyzing patch statistics for {filename}")
    print(f"{'=' * 70}")

    # Load metadata
    metadata, _ = MetadataManager.load_metadata(result_folder, filename)

    # Collect statistics
    difficulties = []
    snr_values = []
    densities = []

    frames_to_analyze = [frame_idx] if frame_idx is not None else metadata.keys()

    for fid in frames_to_analyze:
        if fid not in metadata:
            continue

        for patch_data in metadata[fid].values():
            difficulties.append(patch_data['difficulty'])

            snr = patch_data['signal_amp'] / (patch_data['curr_mean_noise'] + 1e-10)
            snr_values.append(snr)

            densities.append(patch_data['curr_emitter_density'])

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Difficulty distribution
    unique, counts = np.unique(difficulties, return_counts=True)
    axes[0, 0].bar(unique, counts, color='steelblue', edgecolor='black')
    axes[0, 0].set_xlabel('Difficulty Level (Model)')
    axes[0, 0].set_ylabel('Number of Patches')
    axes[0, 0].set_title('Model Selection Distribution')
    axes[0, 0].grid(axis='y', alpha=0.3)

    # SNR distribution
    axes[0, 1].hist(snr_values, bins=50, color='coral', edgecolor='black', alpha=0.7)
    axes[0, 1].set_xlabel('Signal-to-Noise Ratio')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('SNR Distribution')
    axes[0, 1].axvline(np.median(snr_values), color='red', linestyle='--',
                       label=f'Median: {np.median(snr_values):.2f}')
    axes[0, 1].legend()
    axes[0, 1].grid(axis='y', alpha=0.3)

    # Density distribution
    axes[1, 0].hist(densities, bins=50, color='lightgreen', edgecolor='black', alpha=0.7)
    axes[1, 0].set_xlabel('Emitter Density (per μm²)')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Emitter Density Distribution')
    axes[1, 0].axvline(np.median(densities), color='darkgreen', linestyle='--',
                       label=f'Median: {np.median(densities):.2f}')
    axes[1, 0].legend()
    axes[1, 0].grid(axis='y', alpha=0.3)

    # Difficulty vs SNR scatter
    axes[1, 1].scatter(snr_values, difficulties, alpha=0.3, s=10, color='purple')
    axes[1, 1].set_xlabel('Signal-to-Noise Ratio')
    axes[1, 1].set_ylabel('Difficulty Level')
    axes[1, 1].set_title('Difficulty vs SNR')
    axes[1, 1].grid(alpha=0.3)

    plt.tight_layout()

    output_path = os.path.join(result_folder, f'patch_statistics_{os.path.splitext(filename)[0]}.png')
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved statistics to: {output_path}")

    plt.show()

    # Print summary
    print(f"\n{'=' * 70}")
    print(f"Total patches analyzed: {len(difficulties)}")
    print(f"\nDifficulty distribution:")
    for diff, count in zip(unique, counts):
        print(f"  Model {diff}: {count} patches ({100 * count / len(difficulties):.1f}%)")
    print(f"\nSNR statistics:")
    print(f"  Mean: {np.mean(snr_values):.2f}")
    print(f"  Median: {np.median(snr_values):.2f}")
    print(f"  Range: [{np.min(snr_values):.2f}, {np.max(snr_values):.2f}]")
    print(f"\nDensity statistics:")
    print(f"  Mean: {np.mean(densities):.2f} per μm²")
    print(f"  Median: {np.median(densities):.2f} per μm²")
    print(f"  Range: [{np.min(densities):.2f}, {np.max(densities):.2f}]")
    print(f"{'=' * 70}\n")


def create_movie_frames(result_folder: str, filename: str,
                        frame_range: Optional[Tuple[int, int]] = None,
                        num_patches: int = 8, upsampling_factor: int = 8,
                        visualization_folder: str = "visualizations",
                        fps: int = 10):
    """
    Create MP4 movie showing cumulative reconstruction (high quality)

    Args:
        result_folder: Path to results folder
        filename: Original filename
        frame_range: Tuple (start, end) or None for all frames
        num_patches: Number of patches per dimension
        upsampling_factor: Upsampling factor
        visualization_folder: Subfolder name for visualizations
        fps: Frames per second for the video
    """
    # Create movie folder inside visualizations
    vis_path = os.path.join(result_folder, visualization_folder)
    movie_folder = os.path.join(vis_path, "movie")
    os.makedirs(movie_folder, exist_ok=True)

    print(f"\n{'=' * 70}")
    print(f"Creating movie frames for {filename}")
    print(f"{'=' * 70}")

    # Load metadata
    metadata, original_frames = MetadataManager.load_metadata(result_folder, filename)

    # Create reconstructor
    reconstructor = FrameReconstructor(metadata, original_frames,
                                       num_patches, upsampling_factor)

    # Determine frame range
    all_frame_indices = sorted(original_frames.keys())
    if frame_range is not None:
        start, end = frame_range
        frame_indices = [f for f in all_frame_indices if start <= f < end]
    else:
        frame_indices = all_frame_indices

    print(f"Processing {len(frame_indices)} frames...")

    # Initialize cumulative reconstruction
    cumulative_reconstruction = None

    # Process each frame
    for i, frame_idx in enumerate(frame_indices):
        if i % 10 == 0:
            print(f"  Frame {i + 1}/{len(frame_indices)}")

        try:
            original = reconstructor.get_original_frame(frame_idx)
            frame_reconstruction = reconstructor.reconstruct_frame(frame_idx)

            # Accumulate reconstruction
            if cumulative_reconstruction is None:
                original_reconstruction = original
                cumulative_reconstruction = frame_reconstruction.copy()
            else:
                cumulative_reconstruction += frame_reconstruction
                original_reconstruction += original

            # Create side-by-side image with cumulative reconstruction
            fig, axes = plt.subplots(1, 2, figsize=(16, 8))

            axes[0].imshow(original_reconstruction, cmap='gray')
            axes[0].set_title(f'Original - Frame {frame_idx}', fontsize=14)
            axes[0].axis('off')

            # Show cumulative reconstruction (normalized)
            vmin, vmax = np.percentile(cumulative_reconstruction, [1, 99])
            axes[1].imshow(np.clip(cumulative_reconstruction, vmin, vmax), cmap='hot')
            axes[1].set_title(f'Cumulative Reconstruction - Frame {frame_idx} ({i + 1}/{len(frame_indices)})',
                              fontsize=14)
            axes[1].axis('off')

            plt.tight_layout()

            output_path = os.path.join(movie_folder, f'frame_{i:05d}.png')
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            plt.close()

        except Exception as e:
            print(f"  Warning: Failed to process frame {frame_idx}: {e}")

    print(f"\n✓ Saved {len(frame_indices)} frames to: {movie_folder}")

    # Create MP4 using ffmpeg
    print(f"\nCreating high-quality MP4 video...")
    base_name = os.path.splitext(filename)[0]
    output_video = os.path.join(vis_path, f"{base_name}_reconstruction_movie.mp4")

    # High-quality ffmpeg command
    ffmpeg_cmd = [
        'ffmpeg',
        '-y',  # Overwrite output file
        '-framerate', str(fps),
        '-i', os.path.join(movie_folder, 'frame_%05d.png'),
        '-c:v', 'libx264',
        '-preset', 'slow',  # Better compression
        '-crf', '18',  # High quality (18 is visually lossless)
        '-pix_fmt', 'yuv420p',
        '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2',  # Ensure even dimensions
        output_video
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
        print(f"✓ Video created successfully: {output_video}")
        print(f"\nVideo details:")
        print(f"  - Quality: High (CRF 18, visually lossless)")
        print(f"  - Framerate: {fps} fps")
        print(f"  - Codec: H.264")
        print(f"  - Total frames: {len(frame_indices)}")

    except subprocess.CalledProcessError as e:
        print(f"✗ Error creating video: {e}")
        print(f"  stderr: {e.stderr.decode()}")
        print(f"\nYou can manually create the video with:")
        print(
            f"  ffmpeg -framerate {fps} -i {movie_folder}/frame_%05d.png -c:v libx264 -crf 18 -pix_fmt yuv420p {output_video}")
    except FileNotFoundError:
        print(f"✗ ffmpeg not found. Please install ffmpeg to create videos.")
        print(f"  Frames saved in: {movie_folder}")
        print(f"\nAfter installing ffmpeg, run:")
        print(
            f"  ffmpeg -framerate {fps} -i {movie_folder}/frame_%05d.png -c:v libx264 -crf 18 -pix_fmt yuv420p {output_video}")

    print(f"{'=' * 70}\n")


def compare_multiple_frames(result_folder: str, filename: str,
                            frame_indices: List[int],
                            num_patches: int = 8, upsampling_factor: int = 8,
                            visualization_folder: str = "visualizations"):
    """
    Create a grid comparison of multiple frames

    Args:
        result_folder: Path to results folder
        filename: Original filename
        frame_indices: List of frame indices to compare
        num_patches: Number of patches per dimension
        upsampling_factor: Upsampling factor
        visualization_folder: Subfolder name for visualizations
    """
    print(f"\n{'=' * 70}")
    print(f"Comparing {len(frame_indices)} frames")
    print(f"{'=' * 70}")

    # Create visualization folder
    vis_path = os.path.join(result_folder, visualization_folder)
    os.makedirs(vis_path, exist_ok=True)

    # Load metadata
    metadata, original_frames = MetadataManager.load_metadata(result_folder, filename)

    # Create reconstructor
    reconstructor = FrameReconstructor(metadata, original_frames,
                                       num_patches, upsampling_factor)

    # Create grid
    n_frames = len(frame_indices)
    fig, axes = plt.subplots(2, n_frames, figsize=(5 * n_frames, 10))

    if n_frames == 1:
        axes = axes.reshape(-1, 1)

    for i, frame_idx in enumerate(frame_indices):
        original = reconstructor.get_original_frame(frame_idx)
        reconstructed = reconstructor.reconstruct_frame(frame_idx)

        # Original
        axes[0, i].imshow(original, cmap='gray')
        axes[0, i].set_title(f'Original {frame_idx}', fontsize=12)
        axes[0, i].axis('off')

        # Reconstructed
        cap = np.percentile(reconstructed, 99.5)
        reconstructed[reconstructed > cap] = cap

        #vmin, vmax = np.percentile(reconstructed, [1, 99])
        #axes[1, i].imshow(np.clip(reconstructed, vmin, vmax), cmap='hot')
        axes[1, i].imshow(reconstructed, cmap='hot', vmin=0, vmax=cap)
        axes[1, i].set_title(f'Reconstructed {frame_idx}', fontsize=12)
        axes[1, i].axis('off')

    plt.tight_layout()

    output_path = os.path.join(vis_path, f'comparison_grid.png')
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved to: {output_path}")

    plt.show()
    print(f"{'=' * 70}\n")


import torch
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
import csv
from tqdm import tqdm
from collections import defaultdict

# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Configuration for AutoDS inference"""
    # Data paths
    Result_folder = "../Results/TOM20_10nM/v26/fp16" #@param {type:"string"}


    # --- Quiet/Preview flags ------------------------------------------------------
    QUIET = False  # no training/inference chatter unless set to False
    HEADLESS_PREVIEW = True  # set True if you want to see the preview figures

    # --- Metadata collection flag -------------------------------------------------
    use_patch_metadata = False  # Set to False to disable metadata collection

    PRECISION_MODE = 'fp16'  # Options: 'float32', 'fp16', 'int8'

    # Detection parameters
    threshold = 10 #@param {type:"number"}
    neighborhood_size = 3 #@param {type:"integer"}
    use_local_average = True #@param {type:"boolean"}


    # Patch parameters
    num_patches = 8 #@param {type:"number"}
    overlap = 4 #@param {type:"number"}
    patch_batch_size = 32 #@param {type:"number"}
    frame_batch_size = 10 #@param {type:"number"}

    # Imaging parameters
    interpolate_based_on_imaging_parameters = True #@param {type:"boolean"}
    get_pixel_size_from_file = False #@param {type:"boolean"}
    pixel_size = 233 #@param {type:"number"}
    wavelength = 233 #@param {type:"number"}
    numerical_aperture = 1.49 #@param {type:"number"}


    # Processing parameters
    chunk_size = 10000 - 16

    # Timing parameters
    enable_timing = True  # Set to True to enable detailed timing profiling #@param {type:"boolean"}

    # Model parameters
    use_pytorch_weights = False  # Set to True to use .pth weights, False to use .h5 weights #@param {type:"boolean"}

    # Model paths
    prediction_model_path = "/content/AutoDS_models"
    model_names = ['diff_1', 'diff_2', 'diff_3', 'diff_4']

    # Model manifest for downloading
    MODEL_MANIFEST = {
        "diff_1": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_1/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
        "diff_2": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_2/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
        "diff_3": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_3/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
        "diff_4": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_4/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
    }


# Setup
config = Config()

# ============================================================================
# Entry Point
# ============================================================================
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if device.type != 'cuda':
        log('You do not have GPU access.')
        log('Did you change your runtime?')
        log('If the runtime settings are correct then GPU might not be allocated to your session.')
        log('Expect slow performance. To access GPU try reconnecting later.')
    else:
        log('You have GPU access')
        #if config.PRECISION_MODE != "float32":
        #    log(f'Precision mode: {config.PRECISION_MODE}')
        #if config.use_patch_metadata:
        #    log(f'Metadata collection: Enabled')

    # Initialize timing profiler
    profiler = timing_profiler(enabled=config.enable_timing)

    config.prediction_model_path = ensure_models(config.model_names, target_root=config.prediction_model_path,
                                                 model_manifest=config.MODEL_MANIFEST)

    MAX_FILE_GB = 5.0  # warn & skip when file is larger than this

    # PSF parameters
    psf_sigma_nm = 0.21 * config.wavelength / config.numerical_aperture
    psf_sigma_pixels = psf_sigma_nm / config.pixel_size

    if config.get_pixel_size_from_file:
        pixel_size = None

    # Load model metadata
    matfile = sio.loadmat(os.path.join(config.prediction_model_path, config.model_names[0], 'model_metadata.mat'))
    try:
        model_wavelength = np.array(matfile['wavelength'].item())
    except:
        model_wavelength = None
    try:
        model_NA = np.array(matfile['numerical_aperture'].item())
    except:
        model_NA = None
    try:
        model_pixel_size = np.array(matfile['pixel_size'].item())
    except:
        model_pixel_size = None

    if os.path.isdir(Data_folder):
        # iterate both TIFF and ND2
        for filename in list_files_multi(Data_folder, extensions=['tif', 'tiff', 'nd2']):
            print(f"\nStart processing file: {filename}")

            # Install nd2 reader only when needed (optional)
            if filename.lower().endswith('.nd2'):
                try:
                    import nd2
                except Exception:
                    import subprocess
                    import sys

                    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'nd2'])
                    import nd2

            in_path = os.path.join(config.Data_folder, filename)

            # --------- file size guard ----------
            try:
                file_size_gb = os.path.getsize(in_path) / 1e9
                if file_size_gb > MAX_FILE_GB:
                    print(f"\n⚠️  {filename}: {file_size_gb:.2f} GB > {MAX_FILE_GB:.2f} GB.")
                    print("   Video size is too big, please use Google Colab Pro or run locally.")
                    continue
            except Exception:
                pass

            # --- Resolve pixel size if requested ---
            if config.get_pixel_size_from_file:
                if is_tiff(in_path):
                    with catch_oom("reading TIFF pixel size", filename):
                        pixel_size, _, _ = getPixelSizeTIFFmetadata(in_path, True)
                elif is_nd2(in_path):
                    with catch_oom("reading ND2 pixel size", filename):
                        px_nm, _, _ = getPixelSizeND2metadata(in_path, True)
                        pixel_size = px_nm if px_nm is not None else pixel_size

            # --- Common model params ---
            upsampling_factor = np.array(matfile['upsampling_factor']).item()
            try:
                L2_weighting_factor = np.array(matfile['Normalization factor']).item()
            except:
                L2_weighting_factor = 100

            # save all models to cache
            initialize_model_cache(config, upsampling_factor, device,
                                   use_pytorch_weights=config.use_pytorch_weights,
                                   precision_mode=config.PRECISION_MODE)

            # --- Choose reader & frame count ---
            number_of_frames, frame_iter = None, None
            with catch_oom("opening stack", filename):
                if is_tiff(in_path):
                    number_of_frames = count_tiff_frames(in_path)
                    frame_iter = iter_tiff_frames(in_path)
                    log(f'\nLoaded tiff stack with {number_of_frames} frames')
                elif is_nd2(in_path):
                    number_of_frames = count_nd2_frames(in_path)
                    frame_iter = iter_nd2_frames(in_path)
                    log(f'\nLoaded ND2 stack with ~{number_of_frames} planes (T*Z*C)')
                else:
                    log(f"Skipping unsupported file: {filename}")

            if frame_iter is None:
                print(f"⚠️  Skipping {filename} due to earlier error.")
                continue

            # Initialize patch lists for each model (like v18)
            patches_list = [[] for _ in config.model_names]
            patch_indices_list = [[] for _ in config.model_names]
            frame_numbers = [[] for _ in config.model_names]

            # Initialize patch manager (only if metadata is enabled)
            if config.use_patch_metadata:
                metadata_manager = MetadataManager(config.Result_folder, filename)

            # Initialize accumulator variables
            M, N = None, None
            sum_image = None
            patchwise_recon = None
            frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []
            total_selected_model_hist = np.zeros(len(config.model_names), dtype=int)

            # Progress bar for overall process
            pbar = tqdm(total=number_of_frames, desc="Processing frames")

            for frame_start in range(0, number_of_frames, config.frame_batch_size):
                frame_end = min(frame_start + config.frame_batch_size, number_of_frames)
                frame_batch_size = frame_end - frame_start

                # Collect patches from multiple frames
                all_valid_patches = []
                all_full_patches = []
                all_patches_local_indices = []
                all_frames_numbers = []
                all_patches_offset = []

                profiler.start_timer("frame reading and splitting")

                # Collect frames
                frames_list = []
                for frame_idx in range(frame_start, frame_end):
                    frame_i = next(frame_iter)

                    # Store original frame for future visualization
                    if config.use_patch_metadata:
                        metadata_manager.add_original_frame(frame_idx, frame_i.copy())

                    # Initialize sum_image and dimensions on first frame
                    if sum_image is None:
                        sum_image = np.zeros_like(frame_i, dtype=np.float32)

                        # Interpolate first to get actual dimensions
                        if config.interpolate_based_on_imaging_parameters:
                            temp_frame = interpolate_frames(
                                frame_i,
                                model_pixel_size, config.pixel_size,
                                model_wavelength, config.wavelength,
                                model_NA, config.numerical_aperture
                            )[0]
                            M, N = temp_frame.shape
                        else:
                            M, N = frame_i.shape

                    # Accumulate for preview
                    sum_image += frame_i.astype(np.float32) / number_of_frames

                    # Interpolate frame
                    if config.interpolate_based_on_imaging_parameters:
                        frame_i = interpolate_frames(
                            frame_i,
                            model_pixel_size, config.pixel_size,
                            model_wavelength, config.wavelength,
                            model_NA, config.numerical_aperture
                        )[0]
                    frames_list.append(frame_i)

                # Preprocess on GPU
                frames_torch = torch.from_numpy(np.array(frames_list)).float().to(device)
                fproc_tensor, frames_offsets = preprocess_frames_batch(frames_torch, device)

                # Split all frames to patches (GPU)
                all_patches_tensor = split_image_to_patches_batch(
                    fproc_tensor,
                    config.num_patches,
                    config.overlap,
                    device=device
                )

                for frame_idx in range(frame_start, frame_end):
                    offset = frames_offsets[frame_idx - frame_start].cpu().item()
                    patches = all_patches_tensor[frame_idx - frame_start]

                    # Process each patch
                    for m in range(config.num_patches):
                        for n in range(config.num_patches):
                            down = config.overlap if m == 0 else 0
                            up = (M // config.num_patches) - config.overlap if m == config.num_patches - 1 else (
                                    M // config.num_patches)
                            left = config.overlap if n == 0 else 0
                            right = (N // config.num_patches) - config.overlap if n == config.num_patches - 1 else (
                                    N // config.num_patches)

                            local_patch_idx = m * config.num_patches + n
                            full_patch = patches[local_patch_idx]
                            valid_patch = full_patch[down:up, left:right]

                            all_full_patches.append(full_patch)
                            all_valid_patches.append(valid_patch)
                            all_patches_local_indices.append(local_patch_idx)
                            all_patches_offset.append(offset)
                            all_frames_numbers.append(frame_idx)

                profiler.stop_timer("frame reading and splitting")

                profiler.start_timer("patch features extraction")
                # Group patches by size and extract features
                shape_groups = defaultdict(lambda: {'patches': [], 'indices': [], 'offsets': []})

                for idx, patch in enumerate(all_valid_patches):
                    shape = patch.shape
                    shape_groups[shape]['patches'].append(patch)
                    shape_groups[shape]['indices'].append(idx)
                    shape_groups[shape]['offsets'].append(all_patches_offset[idx])

                # Process each size group
                all_features = []

                for shape, group_data in shape_groups.items():
                    patches_tensor = torch.stack(group_data['patches'])
                    offsets_array = np.array(group_data['offsets'])

                    features_batch = extract_features_batch(
                        patches_tensor,
                        config.pixel_size,
                        psf_sigma_pixels,
                        offsets_array,
                        verbose=False,
                        device=device
                    )

                    for feat, idx in zip(features_batch, group_data['indices']):
                        all_features.append((feat, idx))

                profiler.stop_timer("patch features extraction")
                profiler.start_timer("patch classification")

                # Classify and accumulate patches for reconstruction
                for features, idx in all_features:
                    curr_mean_noise, curr_std_noise, signal_amp, curr_emitter_density = features

                    # Skip invalid patches
                    if signal_amp == 0 or curr_mean_noise == 0:
                        continue
                    if any(np.isnan(v) for v in (signal_amp, curr_mean_noise, curr_std_noise, curr_emitter_density)):
                        continue

                    # Choose difficulty level
                    difficulty_choice = ChooseNetByDifficulty_2025(
                        curr_emitter_density,
                        signal_amp / curr_mean_noise
                    )
                    total_selected_model_hist[difficulty_choice] += 1

                    # Store patch data
                    patches_list[difficulty_choice].append(all_full_patches[idx])
                    patch_indices_list[difficulty_choice].append(all_patches_local_indices[idx])
                    frame_numbers[difficulty_choice].append(all_frames_numbers[idx])

                    # Add to patch manager (if metadata enabled)
                    if config.use_patch_metadata:
                        # Store metadata for future analysis
                        metadata_manager.add_patch_metadata(
                            frame_idx=all_frames_numbers[idx],
                            patch_idx=all_patches_local_indices[idx],
                            metadata={
                                'valid_patch': all_valid_patches[idx].cpu().numpy(),  # Move to CPU to save GPU memory
                                'curr_mean_noise': curr_mean_noise,
                                'curr_std_noise': curr_std_noise,
                                'signal_amp': signal_amp,
                                'curr_emitter_density': curr_emitter_density,
                                'difficulty': difficulty_choice,
                                'predicted_patch': None  # Will be filled after reconstruction
                            }
                        )

                profiler.stop_timer("patch classification")

                # Initialize reconstruction array on first batch and move it to the GPU
                if patchwise_recon is None:
                    M, N = fproc_tensor.shape[1], fproc_tensor.shape[2]
                    patchwise_recon = torch.zeros(M * upsampling_factor, N * upsampling_factor,
                                                  dtype=torch.float32, device=device)

                # Process with each model
                for model_num, model_name in enumerate(config.model_names):
                    if not patches_list[model_num]:
                        continue

                    # Reconstruct using CACHED model
                    pw_recon, loc_list, predicted_patches = reconstruct_patches_2025_pytorch(
                        torch.stack(patches_list[model_num]),
                        patch_indices_list[model_num],
                        frame_numbers[model_num],
                        model_num,
                        config.num_patches,
                        config.overlap * upsampling_factor,
                        number_of_frames,
                        config.threshold,
                        neighborhood_size=config.neighborhood_size,
                        use_local_avg=config.use_local_average,
                        upsampling_factor=upsampling_factor,
                        pixel_size=config.pixel_size,
                        batch_size=config.patch_batch_size,
                        L2_weighting_factor=L2_weighting_factor,
                        profiler=profiler,
                        precision_mode=config.PRECISION_MODE,
                        use_metadata=config.use_patch_metadata
                    )

                    if config.use_patch_metadata:
                        # Store predicted patches back into patch_manager
                        for i, (frame_idx, patch_idx) in enumerate(zip(frame_numbers[model_num], patch_indices_list[model_num])):
                            predicted_patch = predicted_patches[i]
                            # Update metadata manager with prediction
                            metadata_dict = metadata_manager._metadata[frame_idx][patch_idx]
                            metadata_dict['predicted_patch'] = predicted_patch

                    # Accumulate results
                    frame_number_list.extend(loc_list[0])
                    x_nm_list.extend(loc_list[1])
                    y_nm_list.extend(loc_list[2])
                    confidence_au_list.extend(loc_list[3])

                    patchwise_recon[:M // config.num_patches * upsampling_factor * config.num_patches,
                                    :N // config.num_patches * upsampling_factor * config.num_patches] += pw_recon

                # Clear patches lists after processing to free memory (like v18)
                for i in range(len(patches_list)):
                    patches_list[i].clear()
                    patch_indices_list[i].clear()
                    frame_numbers[i].clear()

                # Update process bar
                pbar.update(frame_batch_size)

            # close progress bar
            pbar.close()

            # Create output folder if needed
            if not os.path.exists(config.Result_folder):
                print('Result folder was created.')
                os.makedirs(config.Result_folder, exist_ok=True)

            # Save results
            os.makedirs(config.Result_folder, exist_ok=True)
            ext = '_avg' if config.use_local_average else '_max'
            base = os.path.splitext(filename)[0]

            # Histogram of overall patches distribution by models
            selected_model_hist = total_selected_model_hist

            # Histogram 1: Overall patches distribution by models
            print(f"selected_model_hist = {total_selected_model_hist}")
            plt.figure(figsize=(10, 6))
            plt.bar(np.arange(len(config.model_names)), selected_model_hist, width=0.8)
            plt.xticks(np.arange(len(config.model_names)), config.model_names)
            plt.xlabel('Selected Model')
            plt.ylabel('Number of Patches')
            plt.title('Model Selection Distribution')
            plt.tight_layout()
            plt.savefig(os.path.join(config.Result_folder, f'model_selection_{os.path.splitext(filename)[0]}.png'))
            plt.close()

            #print(f"\n{'=' * 70}")
            #print(f"Total localizations found: {len(frame_number_list)}")
            #print(f"{'=' * 70}")

            # Save localizations
            with open(os.path.join(config.Result_folder, f'Localizations_{base}{ext}.csv'), "w", newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])
                sort_ind = np.argsort(frame_number_list)
                locs = list(zip(
                    list(np.array(frame_number_list)[sort_ind]),
                    list(np.array(x_nm_list)[sort_ind]),
                    list(np.array(y_nm_list)[sort_ind]),
                    list(np.array(confidence_au_list)[sort_ind])
                ))
                writer.writerows(locs)

            #print(f"Saved {len(frame_number_list)} localizations")

            # move the reconstructed image to CPU for saving
            patchwise_recon = patchwise_recon.cpu().numpy()

            # Save reconstruction
            pw_recon_tif = np.copy(patchwise_recon)
            cap = np.percentile(pw_recon_tif, 99.5)
            pw_recon_tif[pw_recon_tif > cap] = cap
            saveAsTIF(config.Result_folder, f"Predicted_patchwise_{base}", pw_recon_tif,
                      config.pixel_size / upsampling_factor)

            # Create preview
            fig, axes = plt.subplots(1, 3, figsize=(20, 16))
            axes[0].axis('off')
            axes[0].imshow(sum_image)
            axes[0].set_title('Original', fontsize=15)
            axes[1].axis('off')
            axes[1].imshow(patchwise_recon)
            axes[1].set_title('Prediction', fontsize=15)
            axes[2].axis('off')
            axes[2].imshow(np.clip(patchwise_recon,
                                   np.percentile(patchwise_recon, 1),
                                   np.percentile(patchwise_recon, 99)))
            axes[2].set_title('Normalized Prediction', fontsize=15)
            plt.tight_layout()
            plt.savefig(os.path.join(config.Result_folder, f'preview_{base}.png'), dpi=150)
            plt.close()

            if config.enable_timing:
                # Print timing summary
                profiler.print_timing_summary()
                # reset timing for next file
                profiler.reset()

            if config.use_patch_metadata:
                # Save all metadata to disk (asynchronously, won't block)
                metadata_manager.save_all_metadata(wait_for_completion=True)
                metadata_manager.clear_memory()

                # Start async saver thread (won't slow down processing)
                metadata_manager.save_to_disk_async()

            print(f"\nCompleted processing file: {filename}")



        # METADATA: At the very end, wait for all async saves to complete
        if config.use_patch_metadata:
            metadata_manager.finalize()


You have GPU access
[models] found: diff_1
[models] found: diff_2
[models] found: diff_3
[models] found: diff_4


NameError: name 'Data_folder' is not defined

# **V2: PyTorch Version**
1. full pyTorch compatibility
2. frame-wize preprocessing (instead of model-wize)

In [None]:

import os
import urllib.request
from contextlib import contextmanager

import numpy as np
import tifffile as tiff
from PIL import Image
from PIL.TiffTags import TAGS

def log(*args, **kwargs):
    if not config.QUIET:
        print(*args, **kwargs)

def list_files_multi(directory, extensions):
    exts = {('.' + e.lower()) for e in extensions}
    for f in os.listdir(directory):
        if os.path.splitext(f)[1].lower() in exts:
            yield f

@contextmanager
def catch_oom(phase: str, detail: str = "", on_oom="continue"):
    """
    Wrap any memory-heavy block. Prints a friendly message on OOM and continues.
    on_oom: "continue" (default) just prints and returns; any other value re-raises.
    """
    try:
        yield
    except Exception as e:
        if _is_oom(e):
            print(f"\n⚠️  OOM while {phase}{(' - ' + detail) if detail else ''}.")
            print("   Tip: reduce chunk_size/batch_size/upsampling, or downsample input.")
            if isinstance(e, torch.cuda.OutOfMemoryError):
                # PyTorch OOM messages are in str(e) directly
                msg_line = str(e).splitlines()[0][:200]
                print("   PyTorch says:", msg_line)
            else:
                traceback.print_exc(limit=1, file=sys.stdout)
            if on_oom != "continue":
                raise
        else:
            # Non-OOM: re-raise so real bugs are visible
            raise

# ============================================================================
# 1. TIFF File Operations
# ============================================================================

def getPixelSizeTIFFmetadata(TIFFpath, display=False):
    """Extract pixel size from TIFF metadata"""
    with Image.open(TIFFpath) as img:
        meta_dict = {TAGS[key]: img.tag[key] for key in img.tag.keys()}

    ResolutionUnit = meta_dict['ResolutionUnit'][0]
    width = meta_dict['ImageWidth'][0]
    height = meta_dict['ImageLength'][0]
    xResolution = meta_dict['XResolution'][0]

    if len(xResolution) == 1:
        xResolution = xResolution[0]
    elif len(xResolution) == 2:
        xResolution = xResolution[0] / xResolution[1]
    else:
        print('Image resolution not defined.')
        xResolution = 1

    if ResolutionUnit == 2:
        pixel_size = 0.025 * 1e9 / xResolution
    elif ResolutionUnit == 3:
        pixel_size = 0.01 * 1e9 / xResolution
    else:
        print('Resolution unit not defined. Assuming: um')
        pixel_size = 1e3 / xResolution

    if display:
        print(f'Pixel size from metadata: {pixel_size} nm')
        print(f'Image size: {width}x{height}')

    return pixel_size, width, height

def saveAsTIF(path, filename, array, pixel_size):
    """Save array as TIFF with metadata"""
    if array.dtype == np.uint16:
        mode = 'I;16'
    elif array.dtype == np.uint32:
        mode = 'I'
    else:
        mode = 'F'

    if len(array.shape) == 2:
        im = Image.fromarray(array)
        im.save(os.path.join(path, filename + '.tif'),
               mode=mode,
               resolution_unit=3,
               resolution=0.01 * 1e9 / pixel_size)
    elif len(array.shape) == 3:
        imlist = []
        for frame in array:
            imlist.append(Image.fromarray(frame))
        imlist[0].save(os.path.join(path, filename + '.tif'),
                      save_all=True,
                      append_images=imlist[1:],
                      mode=mode,
                      resolution_unit=3,
                      resolution=0.01 * 1e9 / pixel_size)

def is_tiff(path):
    """Check if file is TIFF"""
    return path.lower().endswith(('.tif', '.tiff'))

def iter_tiff_frames(path):
    """Iterate over TIFF frames"""
    with tiff.TiffFile(path) as tif:
        for page in tif.pages:
            yield page.asarray().astype(np.float32)

def count_tiff_frames(path):
    """Count frames in TIFF file"""
    with tiff.TiffFile(path) as tif:
        return len(tif.pages)

# ============================================================================
# 2. ND2 File Operations
# ============================================================================

def is_nd2(path):
    """Check if file is ND2"""
    try:
        import nd2
        return nd2.is_supported_file(path)
    except Exception:
        return path.lower().endswith(".nd2")

def count_nd2_frames(path):
    """Count frames in ND2 file"""
    import nd2
    with nd2.ND2File(path) as f:
        try:
            return len(f.loop_indices)
        except Exception:
            sz = getattr(f, "sizes", {}) or {}
            prod = 1
            for ax in ("T", "Z", "C", "V"):
                prod *= int(sz.get(ax, 1))
            return prod

def _nd2_to_2d(arr, channel=None):
    """Convert ND2 frame to 2D"""
    a = np.asarray(arr)
    if a.ndim == 2:
        return a
    if a.ndim == 3:
        if a.shape[-1] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[-1]) else 0
            return a[..., idx]
        if a.shape[0] in (1, 3, 4):
            idx = channel if (channel is not None and channel < a.shape[0]) else 0
            return a[idx, ...]
        return a.mean(axis=0)
    a = a.squeeze()
    return a if a.ndim == 2 else a.reshape(a.shape[-2], a.shape[-1])

def iter_nd2_frames(path, channel=None):
    """Iterate over ND2 frames"""
    import nd2
    n = count_nd2_frames(path)
    with nd2.ND2File(path) as f:
        for i in range(n):
            fr = f.read_frame(i)
            fr2d = _nd2_to_2d(fr, channel=channel)
            yield fr2d.astype(np.float32, copy=False)

def getPixelSizeND2metadata(path, display=False):
    """Extract pixel size from ND2 metadata"""
    import nd2
    with nd2.ND2File(path) as f:
        vox_um = getattr(f, "voxel_size", None)
        if vox_um is None:
            return None, None, None
        px_nm = vox_um[2] * 1e3
        try:
            h, w = f.shape[-2], f.shape[-1]
        except Exception:
            h = w = None
        if display:
            print(f"Pixel size (ND2): {px_nm:.2f} nm | image ~ {w}x{h}")
        return px_nm, w, h

# ============================================================================
# 3. Drift Correction Functions
# ============================================================================

def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):
    """Apply drift correction to localizations"""
    n_locs = xc_array.shape[0]
    xc_array_Corr = np.empty(n_locs)
    yc_array_Corr = np.empty(n_locs)

    for loc in range(n_locs):
        xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc] - 1]
        yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc] - 1]

    return xc_array_Corr, yc_array_Corr

def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size=(64, 64), pixel_size=100):
    """Convert localizations to histogram image"""
    w, h = image_size
    locImage = np.zeros(image_size)
    n_locs = len(xc_array)

    for e in range(n_locs):
        y_idx = int(max(min(round(yc_array[e] / pixel_size), w - 1), 0))
        x_idx = int(max(min(round(xc_array[e] / pixel_size), h - 1), 0))
        locImage[y_idx][x_idx] += 1

    return locImage

def estimate_drift_com_nm(img1, img2, pixel_size_nm, sigma=1.0, patch_radius=3):
    """Estimate drift using center of mass of cross-correlation"""
    from scipy.ndimage import gaussian_filter
    from scipy.signal import fftconvolve

    # Smooth images
    img1_smooth = gaussian_filter(img1.astype(np.float32), sigma=sigma)
    img2_smooth = gaussian_filter(img2.astype(np.float32), sigma=sigma)

    # Cross-correlation
    corr = fftconvolve(img1_smooth, img2_smooth, mode='same')

    # Center of image
    center_y, center_x = np.array(corr.shape) // 2

    # Crop around center
    y_min = max(0, center_y - patch_radius)
    y_max = min(corr.shape[0], center_y + patch_radius + 1)
    x_min = max(0, center_x - patch_radius)
    x_max = min(corr.shape[1], center_x + patch_radius + 1)

    patch = corr[y_min:y_max, x_min:x_max]

    # Center of mass
    y_grid, x_grid = np.meshgrid(
        np.arange(y_min, y_max), np.arange(x_min, x_max), indexing='ij'
    )

    total = np.sum(patch)
    if total == 0:
        return 0.0, 0.0

    y_com = np.sum(patch * y_grid) / total
    x_com = np.sum(patch * x_grid) / total

    # Drift in pixels
    dy_px = y_com - center_y
    dx_px = x_com - center_x

    if abs(dy_px) > patch_radius or abs(dx_px) > patch_radius:
        return 0.0, 0.0

    # Convert to nm
    dy_nm = dy_px * pixel_size_nm
    dx_nm = dx_px * pixel_size_nm

    return dy_nm, dx_nm

# ============================================================================
# 4. Model Download Utilities
# ============================================================================

def ensure_models(model_names, target_root="/content/AutoDS_models", model_manifest=None):
    if model_manifest is None:
        raise ValueError("model_manifest must be provided.")

    os.makedirs(target_root, exist_ok=True)

    for m in model_names:
        cfg = model_manifest[m]
        mdir = os.path.join(target_root, m)
        need_fetch = False

        req = cfg.get("contains", [])
        if not os.path.isdir(mdir):
            need_fetch = True
        else:
            for f in req:
                if not os.path.exists(os.path.join(mdir, f)):
                    need_fetch = True
                    break

        if not need_fetch:
            print(f"[models] found: {m}")
            continue

        print(f"[models] preparing: {m}")
        os.makedirs(mdir, exist_ok=True)

        if "file_urls" in cfg:
            file_urls = cfg["file_urls"]
            for fname, url in file_urls.items():
                dst = os.path.join(mdir, fname)
                print(f"[models] downloading: {url}")
                urllib.request.urlretrieve(url, dst)
        else:
            raise ValueError(f"Model {m} manifest must have 'file_urls'.")

        for f in req:
            if not os.path.exists(os.path.join(mdir, f)):
                raise FileNotFoundError(f"Model {m} missing required file: {f}")

        print(f"[models] ready: {m}")

    return target_root


import numpy as np
import scipy.optimize as opt
from numpy.lib.stride_tricks import sliding_window_view
from scipy.ndimage import gaussian_filter, zoom
from scipy.ndimage import gaussian_laplace, maximum_filter, binary_dilation
import torch
import torch.nn.functional as F

# ============================================================================
# 1. Image Preprocessing Functions
# ============================================================================

def normalize_im_01(im):
    """Normalize image to [0, 1]"""
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val) / (max_val - min_val)

def normalize_im_01_ret_vals(im):
    """Normalize and return normalization parameters"""
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val) / (max_val - min_val), min_val, max_val

def normalize_im(im, dmean, dstd):
    """Normalize image with given mean and std"""
    im = np.squeeze(im)
    return (im - dmean) / dstd

def subtract_smooth_background(im, sigma=3):
    """Subtract smoothed background"""
    return im - gaussian_filter(im, sigma)

def remove_zero_padding(image):
    """Remove zero padding from image"""
    image_array = np.array(image)
    non_zero_rows = np.where(image_array.sum(axis=1) != 0)
    non_zero_cols = np.where(image_array.sum(axis=0) != 0)
    cropped_image = image_array[non_zero_rows[0][0]:non_zero_rows[0][-1]+1,
                                non_zero_cols[0][0]:non_zero_cols[0][-1]+1]
    return cropped_image

# ============================================================================
# 2. Patch Splitting
# ============================================================================

def split_image_to_patches(img, num_patches, overlap):
    """
    Split image into overlapping patches

    Args:
        img: Input image (H, W)
        num_patches: Number of patches per dimension
        overlap: Overlap size in pixels

    Returns:
        List of patches
    """
    H, W = img.shape
    patch_h = H // num_patches
    patch_w = W // num_patches

    # Pad image for border patches
    padded_img = np.pad(img, ((overlap, overlap), (overlap, overlap)), mode='reflect')

    # Window shape including overlap
    window_shape = (patch_h + 2 * overlap, patch_w + 2 * overlap)

    # Create sliding window view
    patches_view = sliding_window_view(padded_img, window_shape)

    # Sample at regular intervals
    patches_array = patches_view[0::patch_h, 0::patch_w, :, :]

    # Flatten to list
    num_rows, num_cols, ph, pw = patches_array.shape
    patches_list = [patches_array[i, j].copy()
                   for i in range(num_rows)
                   for j in range(num_cols)]

    return patches_list

# ============================================================================
# 3. Interpolation and Scaling
# ============================================================================

def gaussian_interpolation_batch(data_batch, scale, sigma=1):
    """Apply Gaussian interpolation to batch of images"""
    upsampled_data_batch = []

    for data in data_batch:
        smoothed_data = gaussian_filter(data, sigma=sigma)
        upsampled_data = zoom(smoothed_data, scale, order=3)
        upsampled_data_batch.append(upsampled_data)

    return np.array(upsampled_data_batch)

def interpolate_frames(tiff_stack, model_pixel_size, current_pixel_size,
                      model_wavelength, current_wavelength,
                      model_NA, current_NA):
    """Interpolate frames to match model parameters"""
    # Set defaults
    if model_pixel_size is None:
        model_pixel_size = current_pixel_size
    if model_wavelength is None:
        model_wavelength = current_wavelength
    if model_NA is None:
        model_NA = current_NA
    if current_wavelength is None:
        current_wavelength = model_wavelength = 1
    if current_NA is None:
        current_NA = model_NA = 1

    if len(tiff_stack.shape) == 2:
        tiff_stack = tiff_stack[None, :, :]

    # Compute scaling ratio based on optical parameters
    scale_ratio_sq = ((0.21 * model_wavelength / model_NA) ** 2 -
                     (0.21 * current_wavelength / current_NA) ** 2)

    if scale_ratio_sq > 0:
        scale_ratio = np.sqrt(scale_ratio_sq) / model_pixel_size
        interpolated_stack = np.stack([
            gaussian_filter(tiff_stack[i], scale_ratio)
            for i in range(tiff_stack.shape[0])
        ])
    else:
        zoom_factors = (1,
                       model_pixel_size / current_pixel_size,
                       model_pixel_size / current_pixel_size)
        interpolated_stack = zoom(tiff_stack.astype(np.float32),
                                 zoom_factors, order=3)

    return interpolated_stack.astype(np.float32, copy=False)

# ============================================================================
# 4. Feature Extraction
# ============================================================================

def gauss2d(xy, offset, amp, x0, y0, sigma):
    """2D Gaussian function for fitting"""
    x, y = xy
    return offset + (amp * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2) -
                                  ((y - y0) ** 2) / (2 * sigma ** 2)))

def extract_features_frame(OrigImage, pixel_size, psf_sigma, offset=None, verbose=False):
    """
    Extract features from a single frame

    Returns:
        ADC_offset: Mean background
        ReadOutNoise_ADC: Std of background
        Signal_amp: Mean signal amplitude
        emitter_density: Density of emitters (per μm²)
    """
    M, N = OrigImage.shape

    # Subtract smooth background
    Image = OrigImage - gaussian_filter(OrigImage, sigma=5)

    # Check if SNR is sufficient
    if offset is not None:
        if (np.percentile(gaussian_filter(Image, 2), 99) < 2 * Image.mean() or
            np.percentile(OrigImage, 99) < 2 * offset):
            if verbose:
                print("SNR too low - ignoring patch")
            return np.mean(OrigImage), np.std(OrigImage), 0, 0

    # Laplacian of Gaussian for blob detection
    log_image = -gaussian_laplace(Image, sigma=psf_sigma)

    # Local maxima filtering
    neighborhood_size = 3
    local_max = (log_image == maximum_filter(log_image, size=neighborhood_size))

    # Intensity threshold
    amp_threshold = np.mean(Image) + 0.5 * (np.percentile(Image, 99) - np.mean(Image))
    pcntl_threshold = np.percentile(Image, 85)

    # Binary mask for emitters
    binary_mask = np.logical_and(local_max,
                                 Image > np.max([amp_threshold, pcntl_threshold]))

    # Dilate and create noise mask
    dilated_mask = binary_dilation(binary_mask, structure=np.ones((5, 5)))
    noise_mask = np.ones_like(binary_mask)
    noise_mask[dilated_mask] = 0

    if np.sum(binary_mask) > 0:
        ADC_offset = np.mean(OrigImage[noise_mask])
        ReadOutNoise_ADC = np.std(OrigImage[noise_mask])
        Signal_amp = np.mean(OrigImage[binary_mask == 1])
        emitter_density = (10 ** 6) * float(np.sum(binary_mask)) / (M * N * pixel_size ** 2)
    else:
        if verbose:
            print("Didn't find any emitters")
        return np.mean(OrigImage), np.std(OrigImage), 0, 0

    # Additional SNR check
    if Signal_amp / ADC_offset < 2.5:
        if emitter_density > 2:
            if verbose:
                print("SNR too low for emitter density estimation")
            return ADC_offset, ReadOutNoise_ADC, Signal_amp, 0

    return ADC_offset, ReadOutNoise_ADC, Signal_amp, emitter_density

# ============================================================================
# 5. Model Selection
# ============================================================================

def ChooseNetByDifficulty_2025(density, SNR):
    """ Choose network based on density and SNR """
    num_models = 4
    norm_density = np.max([np.min([int(np.round(2 * density)), num_models - 1]), 0])
    norm_SNR = num_models - 1 - np.max([np.min([SNR // 2, num_models - 1]), 0])
    return int(np.round((norm_SNR + norm_density) / 2))

# ============================================================================
# Module-level kernel cache (shared across all calls)
_kernel_cache = {}

def _get_gaussian_kernel(sigma, device):
    """Generate Gaussian kernel for smoothing"""
    key = f'gauss_{sigma}_{device}'
    if key not in _kernel_cache:
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        _kernel_cache[key] = kernel.view(1, 1, kernel_size, kernel_size)
    return _kernel_cache[key]


def _get_log_kernel(sigma, device):
    """Generate Laplacian of Gaussian kernel for blob detection"""
    key = f'log_{sigma}_{device}'
    if key not in _kernel_cache:
        kernel_size = int(2 * np.ceil(3 * sigma) + 1)
        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        r2 = xx ** 2 + yy ** 2
        kernel = -(1 / (np.pi * sigma ** 4)) * (1 - r2 / (2 * sigma ** 2)) * torch.exp(-r2 / (2 * sigma ** 2))
        _kernel_cache[key] = kernel.view(1, 1, kernel_size, kernel_size)
    return _kernel_cache[key]


def percentile_batch(tensor, percentile):
    """Calculate percentile for batched tensors"""
    flat = tensor.flatten(1)
    result = torch.quantile(flat, percentile / 100.0, dim=1)
    return result

def extract_features_batch(patches_tensor, pixel_size, psf_sigma, offset_array=None,
                           verbose=False, device='cuda'):
    """Fully GPU-accelerated batch feature extraction"""
    B, H, W = patches_tensor.shape
    device = patches_tensor.device

    # Add channel dimension for conv operations: [B, 1, H, W]
    patches_4d = patches_tensor.unsqueeze(1)

    # 1. Gaussian filtering
    gauss_kernel = _get_gaussian_kernel(5, device)
    padding = gauss_kernel.shape[-1] // 2
    smooth_bg = F.conv2d(patches_4d, gauss_kernel, padding=padding)
    Image = patches_4d - smooth_bg # [B, 1, H, W]

    # 2. LoG filtering
    log_kernel = _get_log_kernel(psf_sigma, device)
    padding = log_kernel.shape[-1] // 2
    log_image = -F.conv2d(Image, log_kernel, padding=padding) # [B, 1, H, W]

    # 3. Local maxima
    local_max = F.max_pool2d(log_image, kernel_size=3, stride=1, padding=1) # [B, 1, H, W]

    # 4. Thresholding (all with size [B, 1, 1, 1])
    img_mean = Image.mean(dim=(2, 3), keepdim=True)
    img_99 = percentile_batch(Image.squeeze(1), 99).view(B, 1, 1, 1)
    img_85 = percentile_batch(Image.squeeze(1), 85).view(B, 1, 1, 1)
    threshold = torch.max(img_mean + 0.5 * (img_99 - img_mean), img_85)

    # 5. Binary masks (batch-wise)
    binary_mask = torch.logical_and(log_image == local_max, Image >= threshold)
    mask_float = binary_mask.float()
    dilated = F.max_pool2d(mask_float, kernel_size=5, stride=1, padding=2)
    noise_mask = (dilated < 0.5)

    # 6. PRE-COMPUTE SNR check data on GPU as a batch
    gauss_kernel_2 = _get_gaussian_kernel(2, device)
    padding_2 = (gauss_kernel_2.shape[-1] // 2)
    gauss_smooth = F.conv2d(Image, gauss_kernel_2, padding=padding_2)

    # Pre-compute percentiles on GPU (batch-wise)
    gauss_99 = percentile_batch(gauss_smooth.squeeze(1), 99) #[B]
    patch_99 = percentile_batch(patches_tensor, 99)  # [B]
    img_mean_flat = img_mean.squeeze()  # [B]

    # 7. Statistics on CPU
    patches_cpu = patches_tensor.cpu().numpy()
    binary_mask_cpu = binary_mask.squeeze(1).cpu().numpy()
    noise_mask_cpu = noise_mask.squeeze(1).cpu().numpy()

    # Move pre-computed values to CPU
    gauss_99_cpu = gauss_99.cpu().numpy()
    patch_99_cpu = patch_99.cpu().numpy()
    img_mean_cpu = img_mean_flat.cpu().numpy()

    results = []
    pixel_area = pixel_size * pixel_size

    for i in range(B):
        patch = patches_cpu[i]
        emitter_mask = binary_mask_cpu[i]
        noise_m = noise_mask_cpu[i]
        patch_offset = offset_array[i]

        if patch_offset is not None:
            if (gauss_99_cpu[i] < 2 * img_mean_cpu[i] or
                patch_99_cpu[i] < 2 * patch_offset):
                if verbose:
                    print(f"Patch {i}: SNR too low - ignoring patch")
                results.append((patch.mean(), patch.std(), 0.0, 0.0))
                continue

        num_emitters = emitter_mask.sum()
        if num_emitters == 0:
            if verbose:
                print(f"Patch {i}: Didn't find any emitters")
            results.append((patch.mean(), patch.std(), 0.0, 0.0))
            continue

        ADC_offset = patch[noise_m].mean()
        ReadOutNoise_ADC = patch[noise_m].std()
        Signal_amp = patch[emitter_mask].mean()
        emitter_density = 1e6 * float(num_emitters) / (H * W * pixel_area)

        # Additional SNR check
        if Signal_amp / (ADC_offset + 1e-8) < 2.5:
            if emitter_density > 2:
                if verbose:
                    print(f"Patch {i}: SNR too low for emitter density estimation")
                results.append((float(ADC_offset), float(ReadOutNoise_ADC),
                                float(Signal_amp), 0.0))
                continue

        results.append((float(ADC_offset), float(ReadOutNoise_ADC),
                        float(Signal_amp), float(emitter_density)))

    return results


def preprocess_frames_batch(frames_batch, device='cuda'):
    """GPU-accelerated batch preprocessing of frames"""
    B, H, W = frames_batch.shape

    # Calculate 35th percentile for each frame (on GPU)
    frames_flat = frames_batch.reshape(B, -1)
    p35 = torch.quantile(frames_flat, 0.35, dim=1, keepdim=True)
    p35 = p35.view(B, 1, 1)

    # Subtract 35th percentile
    frames_processed = frames_batch - p35

    # Subtract minimum
    frames_min = frames_processed.reshape(B, -1).min(dim=1, keepdim=True)[0]
    frames_min = frames_min.view(B, 1, 1)
    frames_processed = frames_processed - frames_min

    # Calculate mean and std for normalization
    frames_mean = frames_processed.reshape(B, -1).double().mean(dim=1).float()
    frames_std = frames_processed.reshape(B, -1).double().std(dim=1).float() + 1e-6
    frames_mean_batch = frames_mean.view(B, 1, 1)
    frames_std_batch = frames_std.view(B, 1, 1)

    # Normalize
    frames_processed = (frames_processed - frames_mean_batch) / frames_std_batch

    # Calculate offsets
    offsets = frames_processed.reshape(B, -1).mean(dim=1)

    return frames_processed, offsets


def interpolate_frames_batch(frames_batch, model_pixel_size, current_pixel_size,
                                  model_wavelength, current_wavelength,
                                  model_NA, current_NA, device='cuda'):
    """GPU-accelerated batch interpolation for multiple frames"""
    # Handle None values
    if model_pixel_size is None: model_pixel_size = current_pixel_size
    if model_wavelength is None: model_wavelength = current_wavelength
    if model_NA is None: model_NA = current_NA
    if current_wavelength is None: current_wavelength = model_wavelength = 1
    if current_NA is None: current_NA = model_NA = 1

    # Calculate scale ratio
    scale_ratio_sq = (0.21 * model_wavelength / model_NA) ** 2 - \
                     (0.21 * current_wavelength / current_NA) ** 2

    if scale_ratio_sq > 0:
        # Gaussian smoothing path
        scale_ratio = np.sqrt(scale_ratio_sq) / model_pixel_size
        kernel = _get_gaussian_kernel(scale_ratio, device)

        # Apply Gaussian filter to all frames at once
        frames_4d = frames_batch.unsqueeze(1)  # (B, 1, H, W)
        padding = kernel.shape[-1] // 2
        interpolated = F.conv2d(frames_4d, kernel, padding=padding).squeeze(1)
    else:
        # Zoom/resize path
        zoom_factor = model_pixel_size / current_pixel_size

        if zoom_factor != 1.0:
            # Use bilinear interpolation on GPU
            new_h = int(frames_batch.shape[1] * zoom_factor)
            new_w = int(frames_batch.shape[2] * zoom_factor)

            frames_4d = frames_batch.unsqueeze(1)
            interpolated = F.interpolate(frames_4d, size=(new_h, new_w),
                                        mode='bicubic', align_corners=False).squeeze(1)
        else:
            interpolated = frames_batch

    return interpolated

def split_image_to_patches_batch(img_batch, num_patches, overlap, device='cuda'):
    """ Split tensor of images into overlapping patches """
    # Handle both 2D and 3D input
    if img_batch.dim() == 2:
        img_batch = img_batch.unsqueeze(0)  # (H, W) -> (1, H, W)

    # Determine the non-overlapping patch size
    B, H, W = img_batch.shape
    patch_h = H // num_patches
    patch_w = W // num_patches

    # Pad image for border patches (reflection padding as in the original)
    padded = F.pad(img_batch.unsqueeze(1), # (B, 1, H, W)
                    (overlap, overlap, overlap, overlap),
                    mode='reflect').squeeze(1) # (B, H+2*overlap, W+2*overlap)

    # Calculate window shape including overlap
    window_h = patch_h + 2 * overlap
    window_w = patch_w + 2 * overlap

    # create sliding windows along height and then along width with the patch_h and patch_w as the step
    patches = padded.unfold(1, window_h, patch_h).unfold(2, window_w, patch_w)
    # Shape: (B, num_patches, num_patches, window_h, window_w)

    # Reshape to (B, num_patches * num_patches, window_h, window_w)
    # Flatten the 2D grid of patches for every frame (row-major order).
    B, num_rows, num_cols, ph, pw = patches.shape
    patches = patches.reshape(B, num_rows * num_cols, ph, pw)

    return patches

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# ============================================================================
# 1. Basic CNN Model (without upsampling)
# ============================================================================

class CNNModel(nn.Module):
    def __init__(self, in_channels=1):
        super(CNNModel, self).__init__()

        # Encoder
        self.features1 = ConvBNReLU(in_channels, 32, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.features2 = ConvBNReLU(32, 64, 3)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.features3 = ConvBNReLU(64, 128, 3)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.features4 = ConvBNReLU(128, 512, 3)

        # Decoder
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features5 = ConvBNReLU(512, 128, 3)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features6 = ConvBNReLU(128, 64, 3)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.features7 = ConvBNReLU(64, 32, 3)

        # Prediction head
        self.prediction = nn.Conv2d(32, 1, 1, stride=1, padding=0, bias=False)
        nn.init.orthogonal_(self.prediction.weight)

    def forward(self, x):
        # Encoder
        x = self.features1(x)
        x = self.pool1(x)

        x = self.features2(x)
        x = self.pool2(x)

        x = self.features3(x)
        x = self.pool3(x)

        # Bottleneck
        x = self.features4(x)

        # Decoder
        x = self.upsample1(x)
        x = self.features5(x)

        x = self.upsample2(x)
        x = self.features6(x)

        x = self.upsample3(x)
        x = self.features7(x)

        # Prediction
        x = self.prediction(x)
        return x


# ============================================================================
# 2. CNN Building Blocks - optimized with fused Conv+BN+ReL operations
# ============================================================================

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None):
        super(ConvBNReLU, self).__init__()

        if padding is None:
            padding = kernel_size // 2

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Initialize with Orthogonal (similar to Keras)
        nn.init.orthogonal_(self.conv.weight)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# ============================================================================
# 3. CNN Model with Upsampling - optimized with fused Conv+BN+ReL
# ============================================================================

class CNNUpsample(nn.Module):
    def __init__(self, in_channels=1, upsampling_factor=8):
        super(CNNUpsample, self).__init__()
        self.upsampling_factor = upsampling_factor

        # Encoder with fused blocks
        self.conv_bn_relu1 = ConvBNReLU(in_channels, 32, 3, 1)
        self.conv_bn_relu2 = ConvBNReLU(32, 64, 3, 1)
        self.conv_bn_relu3 = ConvBNReLU(64, 128, 3, 1)
        self.conv_bn_relu4 = ConvBNReLU(128, 256, 3, 1)

        # Decoder with fused blocks
        self.conv_bn_relu5 = ConvBNReLU(256, 128, 3, 1)
        self.conv_bn_relu6 = ConvBNReLU(128, 64, 3, 1)

        # OPTIMIZED: Upsampling blocks with 3x3 kernels + fused Conv+BN+ReLU
        num_upsample_blocks = int(np.log2(upsampling_factor))
        self.upsample_blocks = nn.ModuleList()

        for i in range(num_upsample_blocks):
            in_ch = 64 if i == 0 else 32
            block = nn.ModuleDict({
                'upsample': nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                'conv_bn_relu': ConvBNReLU(in_ch, 32, 5, 1)
            })
            self.upsample_blocks.append(block)

        # Prediction head
        self.prediction = nn.Conv2d(32, 1, 1, stride=1, padding=0, bias=False)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # Encoder
        x = self.conv_bn_relu1(x)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.conv_bn_relu4(x)

        # Decoder
        x = self.conv_bn_relu5(x)
        x = self.conv_bn_relu6(x)

        # Upsampling
        for block in self.upsample_blocks:
            x = block['upsample'](x)
            x = block['conv_bn_relu'](x)

        # Prediction
        x = self.prediction(x)
        return x


# ============================================================================
# 1. Gaussian Filter for Loss Computation
# ============================================================================

def matlab_style_gauss2D(shape=(7, 7), sigma=1):
    """Create 2D Gaussian kernel matching MATLAB style"""
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]
    h = np.exp(-(x*x + y*y) / (2. * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    h = h * 2.0
    return h.astype(np.float32)

# Create Gaussian filter as a tensor
psf_heatmap = matlab_style_gauss2D(shape=(7, 7), sigma=1)
# Shape: [out_channels, in_channels, height, width] -> [1, 1, 7, 7]
gfilter = torch.from_numpy(psf_heatmap).view(1, 1, 7, 7)

# ============================================================================
# 2. Custom Loss Functions
# ============================================================================

class L1L2Loss(nn.Module):
    """Combined L1 + L2 loss with Gaussian filtering"""
    def __init__(self, input_shape):
        super(L1L2Loss, self).__init__()
        self.input_shape = input_shape
        # Register Gaussian filter as buffer (moves with model to GPU)
        self.register_buffer('gfilter', gfilter)

    def forward(self, spikes_pred, heatmap_true):
        # Apply Gaussian convolution to predictions
        heatmap_pred = F.conv2d(spikes_pred, self.gfilter, padding=3)

        # MSE loss on heatmaps
        loss_heatmaps = F.mse_loss(heatmap_pred, heatmap_true)

        # L1 loss on spikes (sparsity)
        loss_spikes = torch.mean(torch.abs(spikes_pred))

        return loss_heatmaps + loss_spikes

class CustomLoss(nn.Module):
    """Custom loss for upsampling model"""
    def __init__(self, input_shape):
        super(CustomLoss, self).__init__()
        self.input_shape = input_shape
        self.register_buffer('gfilter', gfilter)

    def forward(self, y_pred, y_true):
        # Apply Gaussian convolution
        heatmap_pred = F.conv2d(y_pred, self.gfilter, padding=3)

        # MSE on heatmaps
        loss_heatmaps = torch.mean((y_true - heatmap_pred) ** 2)

        # L1 on predictions (sparsity)
        loss_spikes = torch.mean(torch.abs(y_pred))

        return loss_heatmaps + loss_spikes

# ============================================================================
# 3. Maxima Finder Layer (Peak Detection)
# ============================================================================

class MaximaFinder(nn.Module):
    """Find local maxima in predicted density maps"""
    def __init__(self, thresh=0.1, neighborhood_size=3, use_local_avg=False):
        super(MaximaFinder, self).__init__()
        self.thresh = thresh
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

        if use_local_avg:
            # Sobel-like kernels for local averaging
            kernel_x = torch.tensor([[[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_y = torch.tensor([[[-1, -1, -1], [0, 0, 0], [1, 1, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_sum = torch.ones(1, 1, 3, 3, dtype=torch.float32)

            self.register_buffer('kernel_x', kernel_x)
            self.register_buffer('kernel_y', kernel_y)
            self.register_buffer('kernel_sum', kernel_sum)

    def forward(self, inputs):
        # Max pooling to find local maxima
        max_pool = F.max_pool2d(inputs, kernel_size=self.nhood,
                               stride=1, padding=self.nhood//2)

        # Condition: value is local max AND above threshold
        cond = (max_pool > self.thresh) & (max_pool == inputs)

        # Get indices where condition is True
        indices = torch.nonzero(cond, as_tuple=False)  # (N, 4): [batch, channel, y, x]

        bind = indices[:, 0]  # batch indices
        yind = indices[:, 2]  # y coordinates
        xind = indices[:, 3]  # x coordinates

        # Gather confidence values
        confidence = inputs[bind, indices[:, 1], yind, xind]

        # Convert to float for potential subpixel refinement
        xind = xind.float()
        yind = yind.float()

        # Subpixel refinement using local averaging
        if self.use_local_avg:
            # Ensure kernels match input dtype
            kernel_x = self.kernel_x.to(inputs.dtype)
            kernel_y = self.kernel_y.to(inputs.dtype)
            kernel_sum = self.kernel_sum.to(dtype=inputs.dtype)

            # Compute gradients
            # Sobel-like kernels for local averaging
            x_image = F.conv2d(inputs, kernel_x, padding=1)
            y_image = F.conv2d(inputs, kernel_y, padding=1)
            sum_image = F.conv2d(inputs, kernel_sum, padding=1)

            # Gather at detected locations
            gathered_sum = sum_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_x = x_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_y = y_image[bind, indices[:, 1], yind.long(), xind.long()]

            # Compute local offsets
            x_local = gathered_x / (gathered_sum + 1e-6)
            y_local = gathered_y / (gathered_sum + 1e-6)

            # Update positions and confidence
            xind = xind + x_local
            yind = yind + y_local
            confidence = gathered_sum

        return bind, xind, yind, confidence

# ============================================================================
# 6. Maxima Finder Layer (Peak Detection)
# ============================================================================

class MaximaFinder(nn.Module):
    """Find local maxima in predicted density maps"""
    def __init__(self, thresh=0.1, neighborhood_size=3, use_local_avg=False):
        super(MaximaFinder, self).__init__()
        self.thresh = thresh
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

        if use_local_avg:
            # Sobel-like kernels for local averaging
            kernel_x = torch.tensor([[[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_y = torch.tensor([[[-1, -1, -1], [0, 0, 0], [1, 1, 1]]],
                                    dtype=torch.float32).view(1, 1, 3, 3)
            kernel_sum = torch.ones(1, 1, 3, 3, dtype=torch.float32)

            self.register_buffer('kernel_x', kernel_x)
            self.register_buffer('kernel_y', kernel_y)
            self.register_buffer('kernel_sum', kernel_sum)

    def forward(self, inputs):
        # Max pooling to find local maxima
        max_pool = F.max_pool2d(inputs, kernel_size=self.nhood,
                               stride=1, padding=self.nhood//2)

        # Condition: value is local max AND above threshold
        cond = (max_pool > self.thresh) & (max_pool == inputs)

        # Get indices where condition is True
        indices = torch.nonzero(cond, as_tuple=False)  # (N, 4): [batch, channel, y, x]

        bind = indices[:, 0]  # batch indices
        yind = indices[:, 2]  # y coordinates
        xind = indices[:, 3]  # x coordinates

        # Gather confidence values
        confidence = inputs[bind, indices[:, 1], yind, xind]

        # Convert to float for potential subpixel refinement
        xind = xind.float()
        yind = yind.float()

        # Subpixel refinement using local averaging
        if self.use_local_avg:
            # Ensure kernels match input dtype
            kernel_x = self.kernel_x.to(inputs.dtype)
            kernel_y = self.kernel_y.to(inputs.dtype)
            kernel_sum = self.kernel_sum.to(dtype=inputs.dtype)

            # Compute gradients
            x_image = F.conv2d(inputs, kernel_x, padding=1)
            y_image = F.conv2d(inputs, kernel_y, padding=1)
            sum_image = F.conv2d(inputs, kernel_sum, padding=1)

            # Gather at detected locations
            gathered_sum = sum_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_x = x_image[bind, indices[:, 1], yind.long(), xind.long()]
            gathered_y = y_image[bind, indices[:, 1], yind.long(), xind.long()]

            # Compute local offsets
            x_local = gathered_x / (gathered_sum + 1e-6)
            y_local = gathered_y / (gathered_sum + 1e-6)

            # Update positions and confidence
            xind = xind + x_local
            yind = yind + y_local
            confidence = gathered_sum

        return bind, xind, yind, confidence

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import time
import h5py
import re
import os



# ============================================================================
# 1. Model Builder Function
# ============================================================================

def build_model_upsample(input_shape, lr=0.001, upsampling_factor=8):
    """
    Build upsampling model for PyTorch

    Args:
        input_shape: Tuple (H, W, C) - note: will be converted to (C, H, W)
        lr: Learning rate
        upsampling_factor: Upsampling factor

    Returns:
        model: PyTorch model
        optimizer: Adam optimizer
        criterion: Loss function
    """

    # Convert from (H, W, C) to (C, H, W)
    in_channels = input_shape[2] if len(input_shape) == 3 else 1

    model = CNNUpsample(in_channels=in_channels,
                        upsampling_factor=upsampling_factor)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = CustomLoss(input_shape)

    return model, optimizer, criterion

# ============================================================================
# 2. Weight Loading - Support both PyTorch and Keras formats
# ============================================================================

def load_model_weights(model, weights_path, verbose=True):
    """
    Load model weights from either PyTorch (.pth) or Keras (.h5) format

    Args:
        model: PyTorch model
        weights_path: Path to weights file (.pth or .h5)
        verbose: Print loading progress
    """
    if weights_path.endswith('.pth'):
        load_pytorch_weights(model, weights_path, verbose=verbose)
    elif weights_path.endswith('.h5'):
        load_keras_weights_to_pytorch(model, weights_path, verbose=verbose)
    else:
        raise ValueError(f"Unsupported weights format: {weights_path}. "
                        f"Expected .pth or .h5 file")


def load_pytorch_weights(model, pth_path, verbose=True):
    """
    Load PyTorch native weights from .pth file

    Args:
        model: PyTorch model
        pth_path: Path to .pth weights file
        verbose: Print loading progress
    """
    if verbose:
        print(f"Loading PyTorch weights from {pth_path}")

    # Get device from model
    device = next(model.parameters()).device

    # Load state dict
    state_dict = torch.load(pth_path, map_location=device)

    # Load weights into model
    model.load_state_dict(state_dict)

    if verbose:
        print("✓ PyTorch weights loaded successfully!")


def load_keras_weights_to_pytorch(model, h5_path, verbose=True):
    """
    Load Keras weights from H5 file to PyTorch model

    Supports fused Conv+BN+ReLU blocks while maintaining compatibility.

    Args:
        model: PyTorch model (CNNUpsample with fused blocks)
        h5_path: Path to Keras H5 weights file
        verbose: Print loading progress
    """
    if verbose:
        print(f"Loading Keras weights from {h5_path}")

    # Get device from model
    device = next(model.parameters()).device

    with h5py.File(h5_path, 'r') as f:
        # Get all layer names from the H5 file
        if 'model_weights' in f:
            weight_group = f['model_weights']
        else:
            weight_group = f

        # Extract layer names
        if hasattr(weight_group, 'attrs') and 'layer_names' in weight_group.attrs:
            layer_names = [n.decode('utf8') if isinstance(n, bytes) else n
                           for n in weight_group.attrs['layer_names']]
        else:
            layer_names = list(weight_group.keys())

        if verbose:
            print(f"Found {len(layer_names)} layers in H5 file")

        # Create a dictionary to store weights
        keras_weights = {}

        for layer_name in layer_names:
            if layer_name not in weight_group:
                continue

            layer_group = weight_group[layer_name]

            if not hasattr(layer_group, 'keys'):
                continue

            # Get weight names for this layer
            if hasattr(layer_group, 'attrs') and 'weight_names' in layer_group.attrs:
                weight_names = [n.decode('utf8') if isinstance(n, bytes) else n
                                for n in layer_group.attrs['weight_names']]
            else:
                weight_names = list(layer_group.keys())

            # Extract weights
            layer_weights = {}
            for weight_name in weight_names:
                if '/' in weight_name:
                    weight_key = weight_name.split('/')[-1]
                else:
                    weight_key = weight_name

                try:
                    weight_value = layer_group[weight_name][()]
                    layer_weights[weight_key] = weight_value
                except:
                    try:
                        weight_value = layer_group[weight_key][()]
                        layer_weights[weight_key] = weight_value
                    except:
                        if verbose:
                            print(f"  Warning: Could not load {weight_name} from {layer_name}")

            if layer_weights:
                keras_weights[layer_name] = layer_weights

        if verbose:
            print(f"Extracted weights from {len(keras_weights)} layers")

        # Assign to PyTorch model with fused blocks
        _assign_weights_to_model(model, keras_weights, device, verbose=verbose)

    if verbose:
        print("✓ Keras weights loaded successfully!")


def _assign_weights_to_model(model, keras_weights, device, verbose=True):
    """Helper function to assign Keras weights to PyTorch model with fused blocks"""

    # Mapping from Keras layer names to PyTorch fused block names
    name_mapping = {
        'F1': 'conv_bn_relu1',
        'BN_1': 'conv_bn_relu1',
        'F2': 'conv_bn_relu2',
        'BN_2': 'conv_bn_relu2',
        'F3': 'conv_bn_relu3',
        'BN_3': 'conv_bn_relu3',
        'F4': 'conv_bn_relu4',
        'BN_4': 'conv_bn_relu4',
        'F5': 'conv_bn_relu5',
        'BN_5': 'conv_bn_relu5',
        'F6': 'conv_bn_relu6',
        'BN_6': 'conv_bn_relu6',
        'Prediction': 'prediction',
    }

    model_dict = dict(model.named_modules())
    loaded_count = 0

    # Load encoder and decoder layers (now fused blocks)
    for keras_name, pytorch_name in name_mapping.items():
        if keras_name not in keras_weights:
            continue

        if pytorch_name not in model_dict:
            continue

        module = model_dict[pytorch_name]
        weights = keras_weights[keras_name]

        # Check if this is a fused ConvBNReLU block
        if hasattr(module, 'conv') and hasattr(module, 'bn'):
            # This is a fused block - load into conv and bn sub-modules

            # Load Conv2d weights
            if 'kernel:0' in weights:
                kernel = weights['kernel:0']
                kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                module.conv.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                loaded_count += 1
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name}.conv (Conv2d)")

            # Load BatchNorm weights
            if 'gamma:0' in weights:
                module.bn.weight.data = torch.from_numpy(weights['gamma:0']).float().to(device)
            if 'beta:0' in weights:
                module.bn.bias.data = torch.from_numpy(weights['beta:0']).float().to(device)
            if 'moving_mean:0' in weights:
                module.bn.running_mean.data = torch.from_numpy(weights['moving_mean:0']).float().to(device)
            if 'moving_variance:0' in weights:
                module.bn.running_var.data = torch.from_numpy(weights['moving_variance:0']).float().to(device)

            if any(k in weights for k in ['gamma:0', 'beta:0']):
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name}.bn (BatchNorm)")

        # Load prediction layer (not fused)
        elif isinstance(module, nn.Conv2d):
            if 'kernel:0' in weights:
                kernel = weights['kernel:0']
                kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                module.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                loaded_count += 1
                if verbose:
                    print(f"  ✓ Loaded {keras_name} -> {pytorch_name} (Conv2d)")

            if 'bias:0' in weights and module.bias is not None:
                bias = weights['bias:0']
                module.bias.data = torch.from_numpy(bias).float().to(device)

    # Load upsampling blocks (now with fused conv_bn_relu)
    for keras_name in keras_weights.keys():
        if 'conv_upsample' in keras_name or 'BN_upsample' in keras_name:
            match = re.search(r'(\d+)', keras_name)
            if match:
                idx = int(match.group(1)) - 1

                if idx >= len(model.upsample_blocks):
                    continue

                weights = keras_weights[keras_name]

                if 'conv_upsample' in keras_name:
                    # Access the fused block's conv layer
                    fused_block = model.upsample_blocks[idx]['conv_bn_relu']

                    if 'kernel:0' in weights and hasattr(fused_block, 'conv'):
                        kernel = weights['kernel:0']
                        kernel_torch = np.transpose(kernel, (3, 2, 0, 1))
                        fused_block.conv.weight.data = torch.from_numpy(kernel_torch).float().to(device)
                        loaded_count += 1
                        if verbose:
                            print(f"  ✓ Loaded {keras_name} -> upsample_blocks[{idx}]['conv_bn_relu'].conv")

                elif 'BN_upsample' in keras_name:
                    # Access the fused block's bn layer
                    fused_block = model.upsample_blocks[idx]['conv_bn_relu']

                    if hasattr(fused_block, 'bn'):
                        if 'gamma:0' in weights:
                            fused_block.bn.weight.data = torch.from_numpy(weights['gamma:0']).float().to(device)
                        if 'beta:0' in weights:
                            fused_block.bn.bias.data = torch.from_numpy(weights['beta:0']).float().to(device)
                        if 'moving_mean:0' in weights:
                            fused_block.bn.running_mean.data = torch.from_numpy(weights['moving_mean:0']).float().to(
                                device)
                        if 'moving_variance:0' in weights:
                            fused_block.bn.running_var.data = torch.from_numpy(weights['moving_variance:0']).float().to(
                                device)
                        loaded_count += 1
                        if verbose:
                            print(f"  ✓ Loaded {keras_name} -> upsample_blocks[{idx}]['conv_bn_relu'].bn")

    if verbose:
        print(f"\n✓ Successfully loaded {loaded_count} layer weights")

# ============================================================================
# 3. Main Reconstruction Function with Global Profiling
# ============================================================================

def reconstruct_patches_2025_pytorch(
        patches, patch_indices, frame_numbers,
        model_num,
        num_patches, overlap,
        number_of_frames, threshold, neighborhood_size=3,
        use_local_avg=True, upsampling_factor=8,
        pixel_size=233, batch_size=32, L2_weighting_factor=100,
        profiler=None):

    profiler.start_timer("reconstruction.total")

    pixel_size_hr = pixel_size / upsampling_factor

    # Convert patches to float32
    device = get_device()
    #patches = torch.stack(patches).float().to(device)
    patches = patches.float().to(device)

    if patches.ndim == 2:
        patches = patches.unsqueeze(0)  # Ensure 3D shape
    K_frames, M, N = patches.shape

    # Determine dimensions of each predicted (cropped) patch
    upsampled_patch_h = M * upsampling_factor - 2 * overlap
    upsampled_patch_w = N * upsampling_factor - 2 * overlap

    # Create full image tensor on GPU
    reconstructed_image = torch.zeros((upsampled_patch_h * num_patches, upsampled_patch_w * num_patches), dtype=torch.float32, device=device)

    # Prepare lists for detections
    recon_xind, recon_yind, frame_index, confidence_list = [], [], [], []

    with torch.cuda.device(0):
        # Get model from cache
        model = get_model(model_num)
        model.eval()

        # Create the post-processing layer
        profiler.start_timer("reconstruction.maxima_finder_init")
        max_layer = MaximaFinder(threshold, neighborhood_size, use_local_avg).to(device)
        profiler.stop_timer("reconstruction.maxima_finder_init")

        # Process in batches
        n_batches = int(np.ceil(K_frames / batch_size))

        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(K_frames, start_idx + batch_size)
            nF = end_idx - start_idx

            # --- Move input batch to GPU ---
            batch_imgs = patches[start_idx:end_idx].to(device)  # Shape: (nF, M, N)

            # add channel dim to match conv2D
            batch_imgs = batch_imgs.unsqueeze(1) # Shape: (nF, 1, M, N)

            # --- Run prediction on GPU ---
            profiler.start_timer("reconstruction.model_forward_pass")
            # Enables Automatic Mixed Precision (AMP) for CUDA (use float16 instead of float32)
            with torch.no_grad():
                with torch.amp.autocast('cuda'):
                    predicted_density = model(batch_imgs)
            profiler.stop_timer("reconstruction.model_forward_pass")

            # Post-processing
            predicted_density = torch.relu(predicted_density - 0.5)  # Faster than `predicted_density[predicted_density < 0] = 0`

            # Crop off extra overlap
            cropped_pred = predicted_density[:, 0, overlap:-overlap, overlap:-overlap]

            # --- Post-processing on GPU ---
            # Maxima detection
            profiler.start_timer("reconstruction.maxima_detection")
            bind, xind, yind, conf = max_layer(predicted_density[:, :, overlap:-overlap, overlap:-overlap])

            # Convert tensors to NumPy (only when needed)
            bind_np = bind.cpu().numpy()
            xind_np = xind.cpu().numpy()
            yind_np = yind.cpu().numpy()
            conf_np = conf.cpu().numpy() / L2_weighting_factor
            profiler.stop_timer("reconstruction.maxima_detection")

            profiler.start_timer("reconstruction.reconstruct_image")
            # --- Place each patch in reconstructed image ---
            for i in range(nF):
                p_ind = patch_indices[start_idx + i]
                y1 = upsampled_patch_h * (p_ind // num_patches)
                x1 = upsampled_patch_w * (p_ind % num_patches)

                # Use PyTorch addition instead of NumPy
                reconstructed_image[y1:y1 + upsampled_patch_h,
                    x1:x1 + upsampled_patch_w].add_(cropped_pred[i] / number_of_frames)

                # Collect detections (CPU operations)
                det_idx = np.where(bind_np == i)[0]
                if det_idx.size:
                    recon_xind.extend((x1 + xind_np[det_idx]).tolist())
                    recon_yind.extend((y1 + yind_np[det_idx]).tolist())
                    frame_index.extend([frame_numbers[start_idx + i] + 1] * det_idx.size)
                    confidence_list.extend(conf_np[det_idx].tolist())

            profiler.stop_timer("reconstruction.reconstruct_image")

    # Convert coordinates to physical units
    xind_final = (np.array(recon_xind) * pixel_size_hr).tolist()
    yind_final = (np.array(recon_yind) * pixel_size_hr).tolist()


    profiler.stop_timer("reconstruction.total")

    return reconstructed_image, [frame_index, xind_final, yind_final, confidence_list]


# ============================================================================
# 4. Weight Validation Function
# ============================================================================

def validate_model_weights(model, verbose=True):
    """
    Validate that model weights are loaded correctly

    Args:
        model: PyTorch model with loaded weights
        verbose: Print validation details

    Returns:
        bool: True if weights appear valid
    """
    if verbose:
        print("\n" + "=" * 70)
        print("VALIDATING MODEL WEIGHTS")
        print("=" * 70)

    issues = []

    # Check encoder/decoder fused blocks
    for i in range(1, 7):
        block_name = f'conv_bn_relu{i}'
        if hasattr(model, block_name):
            block = getattr(model, block_name)

            # Check conv weights
            conv_weights = block.conv.weight.data
            if torch.all(conv_weights == 0):
                issues.append(f"{block_name}.conv weights are all zeros")
            elif torch.isnan(conv_weights).any():
                issues.append(f"{block_name}.conv weights contain NaN")

            # Check BN parameters
            if torch.all(block.bn.weight.data == 1) and torch.all(block.bn.bias.data == 0):
                issues.append(f"{block_name}.bn parameters are uninitialized (gamma=1, beta=0)")

            if verbose:
                print(f"  {block_name}.conv: shape={tuple(conv_weights.shape)}, "
                      f"mean={conv_weights.mean().item():.6f}, std={conv_weights.std().item():.6f}")
                print(f"  {block_name}.bn: gamma_mean={block.bn.weight.mean().item():.6f}, "
                      f"beta_mean={block.bn.bias.mean().item():.6f}")

    # Check upsampling blocks
    if verbose:
        print(f"\n  Upsampling blocks: {len(model.upsample_blocks)} blocks")

    for idx, block_dict in enumerate(model.upsample_blocks):
        fused_block = block_dict['conv_bn_relu']

        conv_weights = fused_block.conv.weight.data
        expected_kernel_size = 5
        actual_kernel_size = conv_weights.shape[2]

        if actual_kernel_size != expected_kernel_size:
            issues.append(f"upsample_blocks[{idx}] has {actual_kernel_size}x{actual_kernel_size} kernel, "
                         f"expected {expected_kernel_size}x{expected_kernel_size}")

        if torch.all(conv_weights == 0):
            issues.append(f"upsample_blocks[{idx}].conv weights are all zeros")

        if verbose:
            print(f"  upsample_blocks[{idx}].conv: shape={tuple(conv_weights.shape)}, "
                  f"kernel_size={actual_kernel_size}x{actual_kernel_size}, "
                  f"mean={conv_weights.mean().item():.6f}")

    # Check prediction layer
    pred_weights = model.prediction.weight.data
    if torch.all(pred_weights == 0):
        issues.append("prediction layer weights are all zeros")

    if verbose:
        print(f"\n  prediction: shape={tuple(pred_weights.shape)}, "
              f"mean={pred_weights.mean().item():.6f}")

    # Report results
    if verbose:
        print("\n" + "=" * 70)
        if issues:
            print("⚠️ VALIDATION WARNINGS:")
            for issue in issues:
                print(f"  - {issue}")
        else:
            print("✓ ALL WEIGHTS VALIDATED SUCCESSFULLY")
        print("=" * 70)

    return len(issues) == 0

import torch
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
import csv
from tqdm import tqdm
from collections import defaultdict

# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Configuration for AutoDS inference"""
    # Data paths
    Data_folder = Data_folder
    Result_folder = "/content/gdrive/MyDrive/AutoDS/Results/TOM20_10nM/V2"  #@param {type:"string"}

    # --- Quiet/Preview flags ------------------------------------------------------
    QUIET = False  # no training/inference chatter unless set to False
    HEADLESS_PREVIEW = True  # set True if you want to see the preview figures

    # Detection parameters
    threshold = 10 #@param {type:"number"}
    neighborhood_size = 3 #@param {type:"integer"}
    use_local_average = True #@param {type:"boolean"}

    # Patch parameters
    num_patches = 8 #@param {type:"number"}
    overlap = 4 #@param {type:"number"}
    patch_batch_size = 32 #@param {type:"number"}
    frame_batch_size = 10  #@param {type:"number"}

    interpolate_based_on_imaging_parameters = True #@param {type:"boolean"}
    get_pixel_size_from_file = False #@param {type:"boolean"}
    pixel_size = 233 #@param {type:"number"}
    wavelength = 233 #@param {type:"number"}
    numerical_aperture = 1.49 #@param {type:"number"}

    chunk_size = 10000 #@param {type:"number"}

    # Timing parameters
    enable_timing = True  # Set to True to enable detailed timing profiling

    # Model parameters
    use_pytorch_weights = False  # Set to True to use .pth weights, False to use .h5 weights

    # Model paths
    prediction_model_path = "/content/AutoDS_models"
    model_names = ['diff_1', 'diff_2', 'diff_3', 'diff_4']

    # Model manifest for downloading
    MODEL_MANIFEST = {
        "diff_1": {
            "file_urls": {
                # Add your PyTorch weights URL here if using use_pytorch_weights=True:
                # "best_weights.pth": "https://your-url/diff_1/best_weights.pth",
                # Or keep Keras weights if using use_pytorch_weights=False:
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_1/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]  # Weights file checked based on use_pytorch_weights
        },
        "diff_2": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_2/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_2/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
        "diff_3": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_3/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_3/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
        "diff_4": {
            "file_urls": {
                # "best_weights.pth": "https://your-url/diff_4/best_weights.pth",
                "best_weights.h5": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/best_weights.h5",
                "model_metadata.mat": "https://github.com/alonsaguy/One-click-image-reconstruction-in-single-molecule-localization-microscopy-via-deep-learning/raw/main/AutoDS/models/diff_4/model_metadata.mat",
            },
            "contains": ["model_metadata.mat"]
        },
    }


import torch
import torch.nn as nn
import os


# ============================================================================
# Global state - simple module-level variables
# ============================================================================

_models = {}  # Dictionary to store loaded models
_device = None  # Device (CPU or CUDA)
_is_initialized = False  # Track if we've loaded models


# ============================================================================
# Simple functions to manage the cache
# ============================================================================

def initialize_model_cache(config, upsampling_factor, device=None, use_pytorch_weights=False):
    """
    Load all models once and store them in memory

    Args:
        config: Configuration object with model paths and names
        upsampling_factor: Upsampling factor for the models
        device: Device to load models on (CPU or CUDA)
        use_pytorch_weights: If True, load .pth weights. If False, load .h5 weights
    """
    global _models, _device, _is_initialized

    # Skip if already initialized
    if _is_initialized:
        print("⚠️ Models already loaded, skipping initialization")
        return

    # Setup device
    if device is None:
        _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        _device = device

    print(f"Weight Format: {'PyTorch (.pth)' if use_pytorch_weights else 'Keras (.h5)'}")

    # Determine weight file extension
    weight_extension = 'best_weights.pth' if use_pytorch_weights else 'best_weights.h5'
    weight_type = "PyTorch" if use_pytorch_weights else "Keras"

    # Load each model
    for model_num, model_name in enumerate(config.model_names):
        model_path = os.path.join(
            config.prediction_model_path,
            model_name,
            weight_extension
        )

        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Model weights not found: {model_path}\n"
                f"Expected {weight_type} weights for model: {model_name}"
            )

        print(f"\nLoading model {model_num + 1}/{len(config.model_names)}: {model_name} ({weight_type} weights)")

        # Create model
        model = CNNUpsample(in_channels=1, upsampling_factor=upsampling_factor)
        model = model.to(_device)
        model.eval()

        # Load weights
        load_model_weights(model, model_path, verbose=False)

        # Optimize for inference
        if torch.cuda.is_available():
            try:
                model = model.to(memory_format=torch.channels_last)
            except:
                pass

        # Store in cache
        _models[model_num] = model

        # Print memory usage
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**2
            print(f"  ✓ Loaded {model_name} (GPU Memory: {allocated:.1f} MB)")
        else:
            print(f"  ✓ Loaded {model_name}")

    _is_initialized = True

    if torch.cuda.is_available():
        total_allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"Total GPU Memory Used: {total_allocated:.1f} MB")
        print("=" * 70)


def get_model(model_num):
    """Get a cached model by its number"""
    if not _is_initialized:
        raise RuntimeError("Models not loaded. Call initialize_model_cache() first.")

    if model_num not in _models:
        raise KeyError(f"Model {model_num} not found. Available: {list(_models.keys())}")

    return _models[model_num]


def get_device():
    """Get the device being used (CPU or CUDA)"""
    if _device is None:
        raise RuntimeError("Model cache not initialized. Call initialize_model_cache() first.")

    return _device


def clear_cache():
    """Clear all cached models from memory"""
    global _models, _device, _is_initialized

    print("\n⚠️ Clearing model cache...")

    # Move models to CPU and delete
    for model in _models.values():
        model.cpu()

    _models.clear()
    _device = None
    _is_initialized = False

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("✓ Model cache cleared")


# ============================================================================
# Compatibility wrapper (optional - for backwards compatibility)
# ============================================================================

class ModelCacheManager:
    """Simple wrapper class for compatibility with existing code"""

    def initialize(self, config, upsampling_factor, device=None, use_pytorch_weights=False):
        initialize_model_cache(config, upsampling_factor, device, use_pytorch_weights)

    def get_model(self, model_num):
        return get_model(model_num)

    def get_device(self):
        return get_device()

    def clear_cache(self):
        clear_cache()


def get_model_cache():
    """Return a simple manager instance for compatibility"""
    return ModelCacheManager()


# Setup
config = Config()

# ============================================================================
# Entry Point
# ============================================================================
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if device.type != 'cuda':
        log('You do not have GPU access.')
        log('Did you change your runtime?')
        log('If the runtime settings are correct then GPU might not be allocated to your session.')
        log('Expect slow performance. To access GPU try reconnecting later.')
    else:
        log('You have GPU access')
        log('PyTorch version is ' + str(torch.__version__)+'\n')

    # Initialize timing profiler
    profiler = timing_profiler(enabled=config.enable_timing)

    config.prediction_model_path = ensure_models(config.model_names, target_root=config.prediction_model_path,
                                                 model_manifest=config.MODEL_MANIFEST)

    MAX_FILE_GB = 5.0  # warn & skip when file is larger than this

    # PSF parameters
    psf_sigma_nm = 0.21 * config.wavelength / config.numerical_aperture
    psf_sigma_pixels = psf_sigma_nm / config.pixel_size

    if config.get_pixel_size_from_file:
        pixel_size = None

    # Load model metadata
    matfile = sio.loadmat(os.path.join(config.prediction_model_path, config.model_names[0], 'model_metadata.mat'))
    try:
        model_wavelength = np.array(matfile['wavelength'].item())
    except:
        model_wavelength = None
    try:
        model_NA = np.array(matfile['numerical_aperture'].item())
    except:
        model_NA = None
    try:
        model_pixel_size = np.array(matfile['pixel_size'].item())
    except:
        model_pixel_size = None

    if os.path.isdir(config.Data_folder):
        # iterate both TIFF and ND2
        for filename in list_files_multi(config.Data_folder, extensions=['tif', 'tiff', 'nd2']):

            # Install nd2 reader only when needed (optional)
            if filename.lower().endswith('.nd2'):
                try:
                    import nd2  # already installed?
                except Exception:
                    # Jupyter-only magic; remove if not on Colab
                    # get_ipython().system('pip install -q nd2')
                    import subprocess
                    import sys

                    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'nd2'])
                    import nd2  # now import it

            in_path = os.path.join(config.Data_folder, filename)

            # --------- file size guard ----------
            try:
                file_size_gb = os.path.getsize(in_path) / 1e9
                if file_size_gb > MAX_FILE_GB:
                    print(f"\n⚠️  {filename}: {file_size_gb:.2f} GB > {MAX_FILE_GB:.2f} GB.")
                    print("   Video size is too big, please use Google Colab Pro or run locally.")
                    continue
            except Exception:
                pass

            # --- Resolve pixel size if requested ---
            if config.get_pixel_size_from_file:
                if is_tiff(in_path):
                    with catch_oom("reading TIFF pixel size", filename):
                        pixel_size, _, _ = getPixelSizeTIFFmetadata(in_path, True)
                elif is_nd2(in_path):
                    with catch_oom("reading ND2 pixel size", filename):
                        px_nm, _, _ = getPixelSizeND2metadata(in_path, True)
                        pixel_size = px_nm if px_nm is not None else pixel_size  # leave unchanged if unknown

            # --- Common model params ---
            upsampling_factor = np.array(matfile['upsampling_factor']).item()
            try:
                L2_weighting_factor = np.array(matfile['Normalization factor']).item()
            except:
                L2_weighting_factor = 100

            # save all models to cache
            initialize_model_cache(config, upsampling_factor, device,
                                   use_pytorch_weights=config.use_pytorch_weights)

            # --- Choose reader & frame count ---
            number_of_frames, frame_iter = None, None
            with catch_oom("opening stack", filename):
                if is_tiff(in_path):
                    number_of_frames = count_tiff_frames(in_path)
                    frame_iter = iter_tiff_frames(in_path)
                    log(f'\nLoaded tiff stack with {number_of_frames} frames')
                elif is_nd2(in_path):
                    number_of_frames = count_nd2_frames(in_path)
                    frame_iter = iter_nd2_frames(in_path)
                    log(f'\nLoaded ND2 stack with ~{number_of_frames} planes (T*Z*C)')
                else:
                    log(f"Skipping unsupported file: {filename}")

            if frame_iter is None:
                print(f"⚠️  Skipping {filename} due to earlier error.")
                continue

            # Initialize patch lists for each model
            patches_list = [[] for _ in config.model_names]
            patch_indices_list = [[] for _ in config.model_names]
            frame_numbers = [[] for _ in config.model_names]

            # Initialize accumulator variables
            M, N = None, None
            sum_image = None
            patchwise_recon = None
            frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []
            total_selected_model_hist = np.zeros(len(config.model_names), dtype=float)

            # Progress bar for overall process
            pbar = tqdm(total=number_of_frames, desc="Processing frames")

            for frame_start in range(0, number_of_frames, config.frame_batch_size):
                frame_end = min(frame_start + config.frame_batch_size, number_of_frames)

                # Collect patches from multiple frames
                all_valid_patches = []
                all_full_patches = []
                all_patches_local_indices = []
                all_frames_numbers = []
                all_patches_offset = []

                # for DEBUG
                fproc_list = []
                offset_list = []

                frames_list = []
                # Process frames in this batch

                for frame_idx in range(frame_start, frame_end):
                    # Start preprocessing timing for EACH frame
                    profiler.start_timer("preprocessing.frame_reading_and_splitting")

                    # Get frame
                    frame_i = next(frame_iter)

                    # Initialize sum_image and dimensions on first frame
                    if sum_image is None:
                        sum_image = np.zeros_like(frame_i, dtype=np.float32)

                        # Interpolate first to get actual dimensions
                        if config.interpolate_based_on_imaging_parameters:
                            temp_frame = interpolate_frames(
                                frame_i,
                                model_pixel_size, config.pixel_size,
                                model_wavelength, config.wavelength,
                                model_NA, config.numerical_aperture
                            )[0]
                            M, N = temp_frame.shape
                        else:
                            M, N = frame_i.shape

                    # Accumulate for preview
                    sum_image += frame_i.astype(np.float32) / number_of_frames

                    # Interpolate frame
                    if config.interpolate_based_on_imaging_parameters:
                        frame_i = interpolate_frames(
                            frame_i,
                            model_pixel_size, config.pixel_size,
                            model_wavelength, config.wavelength,
                            model_NA, config.numerical_aperture
                        )[0]
                    frames_list.append(frame_i)

                # Preprocess on GPU
                frames_torch = torch.from_numpy(np.array(frames_list)).float().to(device)
                fproc_tensor, frames_offsets = preprocess_frames_batch(frames_torch, device)
                #fproc_tensor, frames_offsets = preprocess_frames_batch(frames_tensor, device)


                # Split all frames to patches (GPU)
                all_patches_tensor = split_image_to_patches_batch(
                    fproc_tensor,
                    config.num_patches,
                    config.overlap,
                    device=device
                )

                for frame_idx in range(frame_start, frame_end):
                    #fproc = fproc_list[frame_idx - frame_start]
                    offset = frames_offsets[frame_idx - frame_start].cpu().item()
                    #fproc = fproc_tensor[frame_idx - frame_start].cpu().numpy()
                    # Split into patches
                    patches = all_patches_tensor[frame_idx - frame_start]

                    # Process each patch
                    for m in range(config.num_patches):
                        for n in range(config.num_patches):
                            down = config.overlap if m == 0 else 0
                            up = (M // config.num_patches) - config.overlap if m == config.num_patches - 1 else (
                                    M // config.num_patches)
                            left = config.overlap if n == 0 else 0
                            right = (N // config.num_patches) - config.overlap if n == config.num_patches - 1 else (
                                    N // config.num_patches)

                            local_patch_idx = m * config.num_patches + n
                            full_patch = patches[local_patch_idx]
                            valid_patch = full_patch[down:up, left:right]

                            all_full_patches.append(full_patch)
                            all_valid_patches.append(valid_patch)
                            all_patches_local_indices.append(local_patch_idx)
                            all_patches_offset.append(offset)
                            all_frames_numbers.append(frame_idx)


                    pbar.update(1)

                profiler.stop_timer("preprocessing.frame_reading_and_splitting")
                profiler.start_timer("preprocessing.feature_extraction")

                # Group patches by size and extract features
                shape_groups = defaultdict(lambda: {'patches': [], 'indices': [], 'offsets': []})

                for idx, patch in enumerate(all_valid_patches):
                    shape = patch.shape
                    shape_groups[shape]['patches'].append(patch)
                    shape_groups[shape]['indices'].append(idx)
                    shape_groups[shape]['offsets'].append(all_patches_offset[idx])

                # Process each size group
                all_features = []

                for shape, group_data in shape_groups.items():
                    patches_tensor  = torch.stack(group_data['patches'])
                    offsets_array = np.array(group_data['offsets'])

                    features_batch = extract_features_batch(
                        patches_tensor,
                        config.pixel_size,
                        psf_sigma_pixels,
                        offsets_array,
                        verbose=False,
                        device=device
                    )

                    for feat, idx in zip(features_batch, group_data['indices']):
                        all_features.append((feat, idx))

                profiler.stop_timer("preprocessing.feature_extraction")
                profiler.start_timer("preprocessing.patch_classification")

                # Classify and accumulate patches for reconstruction
                for features, idx in all_features:
                    curr_mean_noise, curr_std_noise, signal_amp, curr_emitter_density = features

                    # Skip invalid patches
                    if signal_amp == 0 or curr_mean_noise == 0:
                        continue
                    if any(np.isnan(v) for v in (signal_amp, curr_mean_noise, curr_std_noise, curr_emitter_density)):
                        continue

                    # Choose difficulty level
                    difficulty_choice = ChooseNetByDifficulty_2025(
                        curr_emitter_density,
                        signal_amp / curr_mean_noise
                    )

                    # Store patch data
                    patches_list[difficulty_choice].append(all_full_patches[idx])
                    patch_indices_list[difficulty_choice].append(all_patches_local_indices[idx])
                    frame_numbers[difficulty_choice].append(all_frames_numbers[idx])

                profiler.stop_timer("preprocessing.patch_classification")

                # Initialize reconstruction array on first batch and move it to the GPU
                if patchwise_recon is None:
                    M, N = fproc_tensor.shape[1], fproc_tensor.shape[2]
                    patchwise_recon = torch.zeros(M * upsampling_factor, N * upsampling_factor,
                                                  dtype=torch.float32, device=device)

                # Check if there are any patches to process
                total_patches = sum(len(patches) for patches in patches_list)
                if total_patches > 0:
                    # Process with each model
                    for model_num, model_name in enumerate(config.model_names):
                        if not patches_list[model_num]:
                            continue

                        # Process in chunks
                        t_chunks = (len(patches_list[model_num]) // config.chunk_size) + 1

                        for chunk_num in range(t_chunks):
                            chunk_start = chunk_num * config.chunk_size
                            chunk_end = min((chunk_num + 1) * config.chunk_size, len(patches_list[model_num]))

                            if chunk_start >= chunk_end:
                                continue

                            # Reconstruct using CACHED model
                            pw_recon, loc_list = reconstruct_patches_2025_pytorch(
                                torch.stack(patches_list[model_num][chunk_start:chunk_end]),
                                patch_indices_list[model_num][chunk_start:chunk_end],
                                frame_numbers[model_num][chunk_start:chunk_end],
                                model_num,
                                config.num_patches,
                                config.overlap * upsampling_factor,
                                number_of_frames,
                                config.threshold,
                                neighborhood_size=config.neighborhood_size,
                                use_local_avg=config.use_local_average,
                                upsampling_factor=upsampling_factor,
                                pixel_size=config.pixel_size,
                                batch_size=config.patch_batch_size,
                                L2_weighting_factor=L2_weighting_factor,
                                profiler=profiler
                            )

                            # Accumulate results
                            frame_number_list.extend(loc_list[0])
                            x_nm_list.extend(loc_list[1])
                            y_nm_list.extend(loc_list[2])
                            confidence_au_list.extend(loc_list[3])

                            patchwise_recon[:M // config.num_patches * upsampling_factor * config.num_patches,
                                            :N // config.num_patches * upsampling_factor * config.num_patches] += pw_recon

                    # Clear patches lists after processing to free memory
                    for i in range(len(patches_list)):
                        patches_list[i].clear()
                        patch_indices_list[i].clear()
                        frame_numbers[i].clear()

                    ## Force garbage collection
                    #if torch.cuda.is_available():
                    #    torch.cuda.empty_cache()

            # close progress bar
            pbar.close()

            if not os.path.exists(config.Result_folder):
                print('Result folder was created.')
                os.makedirs(config.Result_folder, exist_ok=True)

            print(f"\n{'=' * 70}")
            print(f"Streaming processing complete for {filename}")
            print(f"Total localizations found: {len(frame_number_list)}")
            print(f"{'=' * 70}")

            # Save results
            os.makedirs(config.Result_folder, exist_ok=True)
            ext = '_avg' if config.use_local_average else '_max'
            base = os.path.splitext(filename)[0]

            # Save localizations
            with open(os.path.join(config.Result_folder, f'Localizations_{base}{ext}.csv'), "w", newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])
                sort_ind = np.argsort(frame_number_list)
                locs = list(zip(
                    list(np.array(frame_number_list)[sort_ind]),
                    list(np.array(x_nm_list)[sort_ind]),
                    list(np.array(y_nm_list)[sort_ind]),
                    list(np.array(confidence_au_list)[sort_ind])
                ))
                writer.writerows(locs)

            print(f"Saved {len(frame_number_list)} localizations")

            # move image to CPU for saving
            patchwise_recon = patchwise_recon.cpu().numpy()

            # Save reconstruction
            pw_recon_tif = np.copy(patchwise_recon)
            cap = np.percentile(pw_recon_tif, 99.5)
            pw_recon_tif[pw_recon_tif > cap] = cap
            saveAsTIF(config.Result_folder, f"Predicted_patchwise_{base}", pw_recon_tif, config.pixel_size / upsampling_factor)

            # Create preview
            fig, axes = plt.subplots(1, 3, figsize=(20, 16))
            axes[0].axis('off')
            axes[0].imshow(sum_image)
            axes[0].set_title('Original', fontsize=15)
            axes[1].axis('off')
            axes[1].imshow(patchwise_recon)
            axes[1].set_title('Prediction', fontsize=15)
            axes[2].axis('off')
            axes[2].imshow(np.clip(patchwise_recon,
                                   np.percentile(patchwise_recon, 1),
                                   np.percentile(patchwise_recon, 99)))
            axes[2].set_title('Normalized Prediction', fontsize=15)
            plt.tight_layout()
            plt.savefig(os.path.join(config.Result_folder, f'preview_{base}.png'), dpi=150)
            plt.close()

            print(f"\nCompleted processing: {filename}")

            # Print timing summary at the end
            profiler.print_timing_summary()

import gc
import torch

# Clear GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()



