This notebook evaluates the localization and diffusion-coefficient predictions of our trained UNet_locD model.

## Table of Contents

- [1. Import & Model Setup](#Import-&-Model-Setup)  
- [2. Data Setup](#Data-Setup)  
- [3. Localize](#Localize)  
- [4. KNN Matching](#KNN-Matching)  
- [5. Evaluation](#Evaluation)  
- [6. Visualizations](#Visualizations)  


# Import & Model Setup

In [None]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score
from scipy.stats import gaussian_kde

from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
import wandb
import matplotlib.pyplot as plt
from model import unet_locD
from utils.data_loader import data_loader
from utils.loss_calculator import calculate_loss

import time


# Data Setup

In [None]:
# ---- Data locations ------------------------------------------------------
# Replace the placeholder 'path/' prefix with the real data directory.
dir_img = 'path/imgpadding/'        # input images (default folder: imgpadding)
dir_label = 'path/labelpading/'     # label files (default folder: labelpading)
dir_pair = 'path/pairpading/'       # particle pair files (default folder: pairpading)

# ---- Diffusion encoding --------------------------------------------------
# [D_min, D_max] used to map diffusion coefficients onto label channels
# (value used for the 1216data set).
D_range = [0.01, 2]

# Suffixes handed to data_loader (presumably used to locate the label and
# pair file that belong to each image — confirm in utils.data_loader).
label_suffix = '_loc'
pair_suffix = '_pair'

# ---- Trained model -------------------------------------------------------
dir_model = 'path/model_checkpoint_3.pth'  # checkpoint (state_dict) to evaluate


In [None]:
# Build the network and restore the trained weights. The constructor
# arguments must match the configuration the checkpoint was trained with.
# map_location places every checkpoint tensor on the CPU, so loading also
# works on machines without a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = unet_locD(n_channels=3, n_classes=1, bilinear=False)
model.load_state_dict(
    torch.load(dir_model, map_location=torch.device('cpu'))
)
# Put the network in evaluation (inference) mode.
model.eval()

# Choose a compute device.
# NOTE(review): `device` is never used later in this notebook; the model and
# inputs appear to stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Wrap the project dataset in a DataLoader with default settings
# (batch_size=1, no shuffling), so every batch holds a single frame.
data = data_loader(dir_img, dir_label, dir_pair, label_suffix, pair_suffix)
dataset = DataLoader(dataset=data)

In [None]:
# Pull one batch out of `dataset` for manual inspection.
# `example_frame` is 1-based: the loop stops on the example_frame-th batch.
# If the loader is shorter than that, the last batch is kept.
i = 0
example_frame = 3
for i, batch in enumerate(dataset, start=1):
    img, labell, pair = batch['image'], batch['label'], batch['pair']
    if i == example_frame:
        break


In [None]:
# Print the shapes of the example batch (each includes the leading batch dimension).
print(f'The label size is {labell.shape}.')
print(f'The image size is {img.shape}')
print(f'The particle label pair size is {pair.shape}')

# Expected per-sample shapes (after the batch dimension):
#   label:               14 x 64 x 64
#   image:                3 x 64 x 64
#   particle label pair:  3 x (number of particles, e.g. 10)
# If a shape does not match, adjust the permute in the next cell.

In [None]:
# Swap the last two axes of both tensors so they line up with the image:
# label (N, C, a, b) -> (N, C, b, a); pair (N, particles, 3) -> (N, 3, particles).
labell = labell.permute(0, 1, 3, 2)
pair = pair.permute(0, 2, 1)

In [None]:
# Run the model on the example frame.
# NOTE(review): autograd is enabled here; wrap in torch.no_grad() to avoid building a graph.
result = model(img)

# Convert to numpy, drop the batch dimension, and sum over the channel axis so
# label, prediction and input can be overlaid as 2D maps.

label_data = labell.detach().numpy()[0,:,:,:]
label_map = np.sum(label_data, 0)
img_data = img.detach().numpy()[0,0,:,:]            # first input channel only
img_data_sum = np.sum(img.detach().numpy()[0],0)    # all input channels summed

# Sigmoid maps the raw network output into (0, 1).
output = torch.sigmoid(result).detach().numpy()[0,:,:,:]
output_map = np.sum(output,0)


# Overlay the three layers; comment lines out to inspect one layer at a time.
plt.imshow(output_map,alpha = 0.5 )  # model output (summed over channels)
plt.imshow(label_map, alpha = 0.5)   # label (summed over channels)
plt.imshow(img_data_sum,alpha = 0.5) # input image (summed over channels)


# Ground-truth particle positions in red. The "- 1" shifts the pair-file
# coordinates to 0-based pixel indices (presumably they are 1-based/MATLAB — confirm).
pair_data = pair.detach().numpy()[0,:,:]
for emitter in range(pair_data.shape[1]):
    x, y, D= pair_data[:,emitter]
    plt.scatter(x - 1, y - 1, color = 'red', s = 5)

# If the label and the input do not line up, go back and adjust the permute.

In [None]:
# Same check along the diffusion axis: sum over array axis 1 (presumably y,
# given the x-vs-channel overlay below), giving an image with rows = channel
# (encoded D) and columns = x.
label_mapX = np.sum(label_data,1)

outputX = np.sum(output,1)   # same projection of the prediction (its plot line is disabled)
#plt.imshow(outputX,alpha = 0.5 )
plt.imshow(label_mapX, alpha = 0.5)
#plt.imshow(img_data,alpha = 1)


# Overlay the ground truth: D is mapped linearly from D_range onto channel
# coordinates 2..11 (D_range[0] -> 2, D_range[1] -> 11).
pair_data = pair.detach().numpy()[0,:,:]
for emitter in range(pair_data.shape[1]):
    x, y, D= pair_data[:,emitter]
    D_cal = (D  - D_range[0])/(D_range[1] - D_range[0]) * 9 + 2
    plt.scatter( x-1, D_cal,color = 'red', s = 5)


# Localize

In [None]:
import numpy as np
from scipy.ndimage import maximum_filter, label, center_of_mass

def find_particle_locations(gaussian_map, neighborhood_size=3, threshold=0.1):
    """
    Recover particle locations from a 3D probability map.

    Voxels that are local maxima within a cubic ``neighborhood_size`` window and
    exceed ``threshold`` are grouped into connected regions. Each region is
    reported at its intensity-weighted center of mass. With
    ``neighborhood_size=1`` every voxel counts as a local maximum, so this
    reduces to thresholding followed by blob centroids.

    Args:
    - gaussian_map: 3D numpy array indexed as (row, column, channel).
    - neighborhood_size: Edge length of the cubic window used for local maxima.
    - threshold: A voxel must be strictly greater than this to be kept.

    Returns:
    - refined_coords: List of [x, y, z] lists (subvoxel floats), where
      x = column, y = row and z = channel coordinate. Entries are ordered by
      region label (raster-scan order). Empty list if nothing passes.
    """
    # Step 1: local maxima within a cubic neighbourhood.
    neighborhood = np.ones((neighborhood_size,) * 3)
    local_max = gaussian_map == maximum_filter(gaussian_map, footprint=neighborhood)

    # Step 2: discard maxima at or below the noise threshold.
    valid_peaks = local_max & (gaussian_map > threshold)

    # Step 3: group touching peak voxels into regions labelled 1..num_features.
    labeled_array, num_features = label(valid_peaks)
    if num_features == 0:
        return []

    # Step 4: a single vectorised center-of-mass call for all regions, instead
    # of one full-array pass per region.
    centers = center_of_mass(gaussian_map, labeled_array, list(range(1, num_features + 1)))

    # center_of_mass reports coordinates in axis order (row, column, channel);
    # reorder them to [x, y, z] = [column, row, channel].
    return [[x, y, z] for y, x, z in centers]


In [None]:

# Parameters
neighborhood_size = 1  # local-maximum window; 1 makes every voxel above threshold a candidate. Adjust for PSF size
threshold = 0.1  # minimum sigmoid output for a voxel to count. Adjust based on noise level

# Find particle locations. `output` is assumed to be (channel, row, col);
# the transpose gives (row, col, channel), the order the detector unpacks.
particle_locations = find_particle_locations(output.transpose((1,2,0)), neighborhood_size, threshold)


# Background: label map (hot colormap) over the first input channel (gray).
plt.imshow(label_map,alpha = 0.5, cmap = 'hot')
plt.imshow(img_data, alpha = 0.5, cmap = 'gray')


# Print and plot each detection in black. `d` is the channel coordinate
# (the encoded diffusion value), not the diffusion coefficient itself.
for i, (x,y,d) in enumerate(particle_locations):
    print(f"Particle {i+1}: x = {x:.2f}, y = {y:.2f}, d = {d:.2f}")
    plt.scatter(x,y , s = 5,color = 'black')


# Ground-truth positions in red, shifted by -1 to 0-based pixel coordinates.

pair_data = pair.detach().numpy()[0,:,:]
for emitter in range(pair_data.shape[1]):
    x, y, D = pair_data[:,emitter]
    plt.scatter(x-1,y-1, s = 5, color = 'red')



In [None]:
import matplotlib.patches as patches

# Mark each detected particle on the summed input image and label it with its
# decoded diffusion coefficient.
for i, (x, y, d) in enumerate(particle_locations):
    plt.scatter(x, y, s=200, edgecolor='red', facecolor='none', linewidth=1)
    # d is a channel coordinate. Ground-truth D is encoded as
    # (D - D_range[0]) / (D_range[1] - D_range[0]) * 9 + 2, so remove the +2
    # offset before rescaling; without it the displayed D is too large.
    D_value = ((d - 2) / 9) * (D_range[1] - D_range[0]) + D_range[0]
    plt.text(x - 2, y + 3, f"{D_value:.2f}", color='red')

# Sum the input channels of the first sample for display.
img_show = np.sum(img.detach().numpy()[0], 0)
plt.imshow(img_show, cmap='gray')

# Remove axis labels and ticks.
plt.axis('off')

# Scale bar: 50 px wide, 5 px tall, lower-left corner at (10, 240) in data
# coordinates. The length conversion assumes 0.1 um per pixel.
# NOTE(review): y = 240 lies outside the frame if images are 64 x 64 — adjust
# the position to the image size in use.
scale_bar_length_pixels = 50
scale_bar_length_um = scale_bar_length_pixels * 0.1  # only used by the disabled label below
scale_bar = patches.Rectangle((10, 240), scale_bar_length_pixels, 5, linewidth=0,
                              edgecolor=None, facecolor='white')
plt.gca().add_patch(scale_bar)

# Scale-bar label (disabled):
# plt.text(10 + scale_bar_length_pixels / 2, 235, f"{scale_bar_length_um} μm", color='white',
#          ha='center', fontsize=10)

plt.show()

In [None]:

# Parameters
neighborhood_size = 1 # local-maximum window (1 = every voxel above threshold is a candidate); adjust for PSF size
threshold = 0.1  # minimum sigmoid output to keep a voxel; adjust based on noise level

# Find particle locations (swap in the commented line to run the detector on the label instead of the prediction).
# particle_locations = find_particle_locations(label_data.transpose((1,2,0)), neighborhood_size, threshold)
particle_locations = find_particle_locations(output.transpose((1,2,0)), neighborhood_size, threshold)


# Background: the label summed over array axis 1 (rows = channel, columns = x).
plt.imshow(label_mapX,alpha = 0.5, cmap = 'hot')
# plt.imshow(img_data, alpha = 0.5, cmap = 'gray')


# Print each detection and plot it at (x, channel coordinate d) in black.
for i, (x, y, d) in enumerate(particle_locations):
    print(f"Particle {i+1}: x = {x:.2f}, y = {y:.2f}")
    plt.scatter(x,d, s = 5,color = 'black')


# Ground truth in red: x shifted by -1 to 0-based, D encoded onto the channel
# axis as (D - D_range[0]) / (D_range[1] - D_range[0]) * 9 + 2.


pair_data = pair.detach().numpy()[0,:,:]
for emitter in range(pair_data.shape[1]):
    x, y, D = pair_data[:,emitter]
    D = (D - D_range[0])/(D_range[1] - D_range[0]) * 9 + 2
    plt.scatter(x-1, D, s = 5, color = 'red')

# KNN Matching

In [None]:

def evaluate_location(predicted_pair, true_pair, D_range, threshold=1):
    """
    Match predicted particles to ground truth and collect error statistics.

    Predictions and true particles are paired one-to-one by a minimum-cost
    (Hungarian) assignment on the Euclidean x/y distance. Only pairs closer
    than ``threshold`` count as matches. A true particle is matched at most
    once, so duplicate detections of the same particle count as false
    positives instead of inflating the true-positive count.

    Args:
    - predicted_pair: (N, 3) array-like of [x, y, z] detections, where z is the
      channel coordinate encoding D as (D - D_range[0]) / span * 9 + 2.
    - true_pair: (M, 3) array-like of [x, y, D] ground-truth rows.
    - D_range: [D_min, D_max] used by the channel encoding.
    - threshold: matches require an x/y distance strictly below this.

    Returns:
    A tuple (true_positive, false_positive, false_negative,
    dx, dy, dd, d, D_pred, D_true). The counts are ints. The remaining entries
    are lists over matched pairs, ordered by prediction index:
    - dx, dy: prediction minus truth.
    - dd: decoded D_pred minus the true D.
    - d: x/y distance of the matched pair.
    - D_pred, D_true: the decoded predicted D and the true D.
    Empty inputs are allowed: every prediction then counts as a false positive
    and every true particle as a false negative.
    """
    from scipy.optimize import linear_sum_assignment

    pred = np.asarray(predicted_pair, dtype=float)
    true = np.asarray(true_pair, dtype=float)
    if pred.size == 0 or true.size == 0:
        return (0, len(pred), len(true), [], [], [], [], [], [])

    # Pairwise x/y distances: rows = predictions, columns = true particles.
    dist = np.linalg.norm(pred[:, None, :2] - true[None, :, :2], axis=-1)

    # Pairs at or beyond the threshold get a cost larger than any possible sum
    # of valid distances. The assignment therefore first maximises the number
    # of valid matches and then minimises their total distance.
    invalid_cost = threshold * (min(dist.shape) + 1) + 1.0
    cost = np.where(dist < threshold, dist, invalid_cost)
    rows, cols = linear_sum_assignment(cost)  # rows come back sorted
    keep = dist[rows, cols] < threshold
    rows, cols = rows[keep], cols[keep]

    true_positive = len(rows)
    false_positive = len(pred) - true_positive
    false_negative = len(true) - true_positive

    matched_pred = pred[rows]
    matched_true = true[cols]

    dx = matched_pred[:, 0] - matched_true[:, 0]
    dy = matched_pred[:, 1] - matched_true[:, 1]

    # Decode the predicted channel coordinate back to a diffusion coefficient
    # (inverse of z = (D - D_min) / (D_max - D_min) * 9 + 2).
    D_pred = ((matched_pred[:, 2] - 2) / 9.0) * (D_range[1] - D_range[0]) + D_range[0]
    dd = D_pred - matched_true[:, 2]

    d = dist[rows, cols]

    return (true_positive,
            false_positive,
            false_negative,
            dx.tolist(),
            dy.tolist(),
            dd.tolist(),
            d.tolist(),
            D_pred.tolist(),
            matched_true[:, 2].tolist())


# Evaluation

In [None]:
# Accumulators over the whole validation set.
true_positive_sum = 0
false_positve_sum = 0
false_negative_sum = 0
dx_list_all = []
dy_list_all = []
dd_list_all = []
d_list_all = []

D_pred_all = []
D_true_all = []

# Peak-finding parameters (see find_particle_locations).
neighborhood_size = 1  # Adjust for PSF size
threshold = 0.1  # Minimum sigmoid output for a detection; 0.1 for diffusion prediction
# Maximum x/y distance (pixels) for a detection to match a true particle.
match_threshold = 1.5

# Rebuild the dataset and a 90/10 split with a fixed seed (0).
# NOTE(review): presumably this reproduces the training run's validation split — confirm.
dataset = data_loader(dir_img, dir_label, dir_pair, label_suffix, pair_suffix)
val_percent = 0.1
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
val_loader = DataLoader(val_set, shuffle=False, drop_last=True)

# Inference only: disable autograd so no computation graph is built per batch.
with torch.no_grad():
    for batch in val_loader:
        img = batch['image']
        # Same axis permutes as used in the single-frame inspection cells.
        labell = batch['label'].permute((0, 1, 3, 2))
        pair = batch['pair'].permute((0, 2, 1))

        result = model(img)

        label_data = labell.detach().numpy()[0]
        output = torch.sigmoid(result).detach().numpy()[0]
        pair_data = pair.detach().numpy()[0]

        # Detections as [x, y, channel] rows; reshape keeps an empty result 2D.
        predicted_pair = np.array(
            find_particle_locations(output.transpose((1, 2, 0)), neighborhood_size, threshold)
        ).reshape(-1, 3)
        # Ground truth as [x, y, D] rows. The -1 shifts the pair-file
        # coordinates to 0-based pixel indices (presumably 1-based/MATLAB).
        true_pair = np.array([[x - 1, y - 1, d] for x, y, d in pair_data.T]).reshape(-1, 3)

        # Nothing to match: count the remaining items directly. Padding with
        # dummy points at (0, 0) would add fake false positives (or fake matches
        # near the origin).
        if len(predicted_pair) == 0 or len(true_pair) == 0:
            false_positve_sum += len(predicted_pair)
            false_negative_sum += len(true_pair)
            continue

        (true_positive, false_positve, false_negative,
         dx_list, dy_list, dd_list, d_list, diff_pred, diff_true) = evaluate_location(
            predicted_pair, true_pair, D_range, threshold=match_threshold)

        true_positive_sum += true_positive
        false_positve_sum += false_positve
        false_negative_sum += false_negative
        dx_list_all.extend(dx_list)
        dy_list_all.extend(dy_list)
        d_list_all.extend(d_list)
        dd_list_all.extend(dd_list)
        D_pred_all.extend(diff_pred)
        D_true_all.extend(diff_true)
    
    

# Visualizations

In [None]:
# F1 = 2TP / (2TP + FP + FN). It is undefined (reported as nan) when there
# were no detections and no true particles at all.
f1_denominator = 2 * true_positive_sum + false_positve_sum + false_negative_sum
f1_score = (2 * true_positive_sum) / f1_denominator if f1_denominator else float('nan')
print(f'The f1 score is {f1_score:.2f}')

In [None]:
# Histogram of x localization errors, scaled x100 to nm (assumes 100 nm per pixel).
dx_array = np.array(dx_list_all) * 100   # nm

# Fixed-width bins spanning the observed range.
bin_width = 20
bins = np.arange(min(dx_array), max(dx_array) + bin_width, bin_width)

# Raw counts per bin.
hist, bin_edges = np.histogram(dx_array, bins=bins, density=False)

# Scale so the tallest bin equals 1.
hist_normalized = hist / hist.max()

# align='edge' starts each bar at its left bin edge. The default centre
# alignment would shift every bar left by half a bin.
plt.bar(bin_edges[:-1], hist_normalized, width=bin_width, align='edge', edgecolor='black', alpha=0.7)
plt.xlabel('dx (nm)',fontsize=18)
plt.ylabel('Normalized Frequency',fontsize=18)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()


In [None]:
# Root-mean-square x error over all matched particles, reported in nm
# (pixel errors scaled x100).
dx_array = np.asarray(dx_list_all)
rmse_x = np.sqrt(np.square(dx_array).mean())
print(f"The x localization error is {rmse_x * 100: .0f} nm.")


In [None]:
# Histogram of y localization errors, scaled x100 to nm (assumes 100 nm per pixel).
dy_array = np.array(dy_list_all) * 100  # nm

# Fixed-width bins spanning the observed range.
bin_width = 20
bins = np.arange(min(dy_array), max(dy_array) + bin_width, bin_width)

# Raw counts per bin.
hist, bin_edges = np.histogram(dy_array, bins=bins, density=False)

# Scale so the tallest bin equals 1.
hist_normalized = hist / hist.max()

# align='edge' starts each bar at its left bin edge. The default centre
# alignment would shift every bar left by half a bin.
plt.bar(bin_edges[:-1], hist_normalized, width=bin_width, align='edge', edgecolor='black', facecolor = 'orange', alpha=0.7)
plt.xlabel('dy (nm)',fontsize=18)
plt.ylabel('Normalized Frequency',fontsize=18)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

In [None]:
# Root-mean-square y error over all matched particles, reported in nm
# (pixel errors scaled x100).
dy_array = np.asarray(dy_list_all)
rmse_y = np.sqrt(np.square(dy_array).mean())
print(f"The y localization error is {rmse_y * 100: .0f} nm.")

In [None]:
# Density-coloured scatter of predicted vs. true D, returning R^2.
def scatter_with_gaussian_kde(pair, s, diffusion_range=None):
    """
    Scatter (true, predicted) diffusion pairs coloured by point density and
    return the coefficient of determination of prediction vs. truth.

    Args:
    - pair: sequence of (D_true, D_pred) tuples.
    - s: marker size passed to plt.scatter.
    - diffusion_range: [D_min, D_max]. The y axis is shown from 0 to D_max;
      ``None`` keeps the previous fixed upper limit of 2.

    Returns:
    - R^2 from sklearn.metrics.r2_score(D_true, D_pred).
    """
    xy = np.asarray(pair, dtype=float).T   # shape (2, N): row 0 = truth, row 1 = prediction
    density = gaussian_kde(xy)(xy)          # estimated density at each point, used as colour
    R2_sklearn = r2_score(xy[0], xy[1])

    # The y-axis limit now follows diffusion_range instead of a hard-coded 2.
    y_max = 2 if diffusion_range is None else diffusion_range[1]
    plt.scatter(xy[0], xy[1], c=density, s=s)
    plt.ylim([0, y_max])
    plt.xlabel('Ground Truth diffusion coeficient ($μm^2$/s)')
    plt.ylabel('Predicted Diffusion coeficient($μm^2$/s) ')
    return R2_sklearn

In [None]:
# Diffusion pair: 
diffusion_pair = [(a,b) for a,b in zip( D_true_all,D_pred_all)]

R2 = scatter_with_gaussian_kde(diffusion_pair,10,D_range)

plt.plot([0,2],[0,2], color = 'red')

In [None]:
# Report the R^2 (sklearn r2_score of predicted vs. true D) computed by the scatter cell.
print(f"The R^2 value is {R2:.2f}.")