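# Semi-synthetic parameter sensitivity analysis & benchmark

Target sections are simulated from source point sets with `simutome` under a grid of perturbations (mis-alignment, cell exclusion, cell displacement, cell division and cell swapping). The simulated sections are then used to (1) sweep Spellmatch hyper-parameters (results written to `psa.csv`) and (2) benchmark ICP, rigid CPD and Spellmatch (results written to `benchmark.csv`).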

In [3]:
# %pip install numpy pandas scikit-learn simutome spellmatch tqdm  # uncomment to install into the running kernel's environment

In [2]:
from pathlib import Path
from timeit import default_timer as timer

import numpy as np
import pandas as pd
from simutome import Simutome
from sklearn.model_selection import ParameterGrid
from spellmatch.assignment import assign, AssignmentStrategy
from spellmatch.matching.algorithms.icp import IterativeClosestPoints
from spellmatch.matching.algorithms.probreg import RigidCoherentPointDrift
from spellmatch.matching.algorithms.spellmatch import Spellmatch
from tqdm.auto import tqdm

# Single seeded generator that is passed as `seed` to every Simutome instance, so the
# simulated sections are reproducible when the notebook is run top to bottom.
rng = np.random.default_rng(123)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
# Input directories: one CSV file per source dataset, paired across the three
# directories by sorted filename order.
source_points_dir = "source_points"
source_clusters_dir = "source_clusters"
source_intensities_dir = "source_intensities"

# Perturbations applied when simulating target sections. Every top-level key is one
# perturbation group whose options are dicts of Simutome keyword arguments;
# ParameterGrid takes the Cartesian product over the groups
# (4 x 2 x 2 x 2 x 2 = 64 combinations). Option order matters: it fixes the order in
# which the shared random generator is consumed.
target_data_param_grid = ParameterGrid(
    dict(
        # simulate minor mis-alignments (rotations listed in degrees, stored in radians)
        image_transform=[
            dict(
                image_scale=(1.0, 1.0),
                image_rotation=rotation_deg * np.pi / 180,
                image_shear=0.0,
                image_translation=translation,
            )
            for rotation_deg in [0.0, 2.0]
            for translation in [(0.0, 0.0), (5.0, 5.0)]
        ],
        # see ../kuett_catena_2022/parameters.ipynb
        cell_exclusion=[
            dict(exclude_cells=False),
            dict(exclude_cells=True, cell_diameter_mean=7.931, cell_diameter_std=1.768),
        ],
        # see ../kuett_catena_2022/parameters.ipynb
        cell_displacement=[
            dict(displace_cells=False),
            dict(displace_cells=True, cell_displacement_mean=0, cell_displacement_var=1.0),
        ],
        # mis-segmentation / physical separation of "U-shaped" cells
        cell_division=[
            dict(cell_division_probab=0.0),
            dict(cell_division_probab=0.05, cell_division_dist_mean=6.0, cell_division_dist_std=1.0),
        ],
        # permute a fraction of cells to see intensity-related effects
        cell_swapping=[
            dict(cell_swapping_probab=0.0),
            dict(cell_swapping_probab=0.2),
        ],
    )
)

# Settings shared by every simulation: the section thickness passed to
# Simutome.generate_sections and the number of sections generated per source dataset.
section_thickness = 2.0
n_sections = 1

In [None]:
# Collect the per-dataset CSV files. The three lists are paired purely by sorted
# position (see evaluate_algorithm), so they must have the same length, and an empty
# list would silently produce no results at all.
source_points_files = sorted(Path(source_points_dir).glob("*.csv"))
source_clusters_files = sorted(Path(source_clusters_dir).glob("*.csv"))
source_intensities_files = sorted(Path(source_intensities_dir).glob("*.csv"))
assert source_points_files, f"no CSV files found in {source_points_dir!r}"
assert len(source_points_files) == len(source_clusters_files) == len(source_intensities_files), (
    "source directories contain different numbers of CSV files: "
    f"{len(source_points_files)} points, {len(source_clusters_files)} clusters, "
    f"{len(source_intensities_files)} intensities"
)

def _load_source_dataset(points_file, clusters_file, intensities_file):
    """Read one source dataset and align its clusters/intensities to the point order.

    Returns ``(source_points, source_intensities, aligned_intensities, aligned_clusters)``,
    where both ``aligned_*`` NumPy arrays have one row per row of ``source_points``,
    in the same order (all three CSVs are indexed by their ``cell`` column).
    """
    source_points = pd.read_csv(points_file, index_col="cell")
    source_clusters = pd.read_csv(clusters_file, index_col="cell")
    source_intensities = pd.read_csv(intensities_file, index_col="cell")
    aligned_intensities = source_intensities.loc[source_points.index, :].to_numpy()
    aligned_clusters = source_clusters.loc[source_points.index, :].iloc[:, 0].to_numpy()
    return source_points, source_intensities, aligned_intensities, aligned_clusters


def evaluate_algorithm(algorithm, assignment_dict, metric_dict, **tqdm_kwargs):
    """Run ``algorithm`` on every simulated section and collect one record per metric.

    For every parameter set in ``target_data_param_grid`` and every source dataset
    (``source_points_files`` / ``source_clusters_files`` / ``source_intensities_files``,
    paired by sorted position), ``n_sections`` target sections are simulated with
    ``Simutome`` and matched against the source with ``algorithm.match_points``. Each
    resulting score object is turned into an assignment by every function in
    ``assignment_dict`` and evaluated by every function in ``metric_dict``.

    Args:
        algorithm: Object with a ``match_points(source_points, target_points,
            source_intensities=..., target_intensities=...)`` method.
        assignment_dict: Mapping ``name -> callable(scores) -> assignment``.
        metric_dict: Mapping ``name -> callable(scores, assignment) -> value``.
        **tqdm_kwargs: Forwarded to ``tqdm`` (e.g. ``desc``).

    Returns:
        list[dict]: One record per (parameter set, dataset, section, assignment, metric)
        containing the flattened Simutome keyword arguments, the source file names, the
        section number, the wall-clock seconds spent in ``match_points`` (shared by all
        records of that section) and the metric value.
    """
    # The source files do not depend on the simulation parameters, so read each
    # dataset once instead of once per parameter set.
    source_datasets = [
        (files, _load_source_dataset(*files))
        for files in zip(source_points_files, source_clusters_files, source_intensities_files)
    ]
    results = []
    n_total = len(target_data_param_grid) * len(source_datasets) * n_sections
    # Context manager guarantees the progress bar is closed even if matching raises.
    with tqdm(total=n_total, **tqdm_kwargs) as pbar:
        for target_data_params in target_data_param_grid:
            # Each grid entry maps a group name to a dict of Simutome kwargs; flatten them.
            simutome_kwargs = {k: v for p in target_data_params.values() for k, v in p.items()}
            simutome = Simutome(**simutome_kwargs, shuffle_cells=True, seed=rng)
            for files, dataset in source_datasets:
                source_points_file, source_clusters_file, source_intensities_file = files
                source_points, source_intensities, aligned_intensities, aligned_clusters = dataset
                section_generator = simutome.generate_sections(
                    source_points.to_numpy(),
                    section_thickness,
                    cell_intensities=aligned_intensities,
                    cell_clusters=aligned_clusters,
                    n=n_sections,
                )
                for section_number, (cell_indices, cell_coords, cell_intensities) in enumerate(
                    section_generator
                ):
                    # cell_indices refer to rows of the arrays passed to generate_sections,
                    # all of which are ordered like source_points; label both target frames
                    # with source_points.index (source_intensities may be ordered differently).
                    target_cells = source_points.index[cell_indices]
                    target_points = pd.DataFrame(
                        cell_coords,
                        index=target_cells,
                        columns=source_points.columns,
                    )
                    target_intensities = pd.DataFrame(
                        cell_intensities,
                        index=target_cells,
                        columns=source_intensities.columns,
                    )
                    start = timer()
                    scores = algorithm.match_points(
                        source_points,
                        target_points,
                        source_intensities=source_intensities,
                        target_intensities=target_intensities,
                    )
                    seconds = timer() - start
                    for assignment_name, assignment_func in assignment_dict.items():
                        assignment = assignment_func(scores)
                        for metric_name, metric_func in metric_dict.items():
                            metric_value = metric_func(scores, assignment)
                            results.append(
                                {
                                    **simutome_kwargs,
                                    "source_points_file": source_points_file.name,
                                    "source_clusters_file": source_clusters_file.name,
                                    "source_intensities_file": source_intensities_file.name,
                                    "section_number": section_number,
                                    "seconds": seconds,
                                    "assignment_name": assignment_name,
                                    "metric_name": metric_name,
                                    "metric_value": metric_value,
                                }
                            )
                    # One tick per simulated section, matching n_total above.
                    pbar.update()
    return results

## Parameter sensitivity analysis

In [None]:
# Spellmatch hyper-parameter groups. ParameterGrid takes the Cartesian product over the
# three groups (2 x 4 x 2 = 16 configurations); options whose weight is 0.0 omit the
# companion threshold / interpolation parameters.
spellmatch_param_grid = ParameterGrid(
    dict(
        degrees=[
            dict(degree_weight=0.0),
            dict(degree_weight=1.0, degree_cdiff_thres=3),
        ],
        intensities=[
            dict(intensity_weight=0.0),
            dict(intensity_weight=1.0, intensity_interp_lmd=0.0),
            dict(intensity_weight=1.0, intensity_interp_lmd=0.5),
            dict(intensity_weight=1.0, intensity_interp_lmd=1.0),
        ],
        distances=[
            dict(distance_weight=0.0),
            dict(distance_weight=1.0, distance_cdiff_thres=5),
        ],
    ),
)

# TODO: assignment_dict and metric_dict are unfinished stubs (kept commented out below),
# but both are required by the parameter sensitivity analysis and benchmark cells --
# define them before running those cells. Note that `partial` (functools) is not
# imported yet.
# assignment_dict = {
#     "threshold"
#     "forward_max": partial(assign, strategy=AssignmentStrategy.FORWARD_MAX, score_thres=0),
# }

# metric_dict = {
#     "precision": lambda scores, assignment:,
#     "recall": lambda scores, assignment:,
#     "f1score": lambda scores, assignment:,
#     "accuracy": lambda scores, assignment:,
#     "uncertainty_mean": lambda scores, assignment:,
#     "uncertainty_std": lambda scores, assignment:,
#     "margin_mean": lambda scores, assignment:,
#     "margin_std": lambda scores, assignment:,
#     "entropy_mean": lambda scores, assignment:,
#     "entropy_std": lambda scores, assignment:,
# }

In [None]:
# Parameter sensitivity analysis: evaluate every Spellmatch configuration from
# spellmatch_param_grid on all simulated sections and save one row per metric record.
# Requires assignment_dict and metric_dict to be defined (see the TODO in the cell above).
psa_records = []
for i, spellmatch_params in enumerate(spellmatch_param_grid, start=1):
    # Each grid entry maps a group name to a dict of Spellmatch kwargs; flatten them.
    spellmatch_kwargs = {k: v for p in spellmatch_params.values() for k, v in p.items()}
    # Hyper-parameters below are held fixed across all configurations.
    spellmatch = Spellmatch(
        adj_radius=15,
        alpha=0.8,
        intensity_transform=np.log1p,
        spatial_cdist_prior_thres=25,
        max_spatial_cdist=50,
        scores_tol=1e-6,
        filter_outliers=False,
        **spellmatch_kwargs,
    )
    current_results = evaluate_algorithm(
        spellmatch,
        assignment_dict,
        metric_dict,
        # 1-based so the last configuration reads "16/16", not "15/16".
        desc=f"{i}/{len(spellmatch_param_grid)}",
    )
    for result in current_results:
        psa_records.append({**spellmatch_kwargs, **result})
psa_results = pd.DataFrame(data=psa_records)
psa_results.to_csv("psa.csv", index=False)

## Benchmark

In [None]:
# Algorithms compared in the benchmark; the keys end up in the "algorithm" column.
# TODO: every algorithm is still constructed with default arguments -- configure them.
algorithm_dict = dict(
    ICP=IterativeClosestPoints(),
    RigidCPD=RigidCoherentPointDrift(),
    Spellmatch=Spellmatch(),
)

In [None]:
# Benchmark: evaluate each algorithm on all simulated sections, tag every metric record
# with the algorithm name and save the combined table.
benchmark_records = []
for algorithm_name, algorithm in algorithm_dict.items():
    benchmark_records.extend(
        {"algorithm": algorithm_name, **result}
        for result in evaluate_algorithm(algorithm, assignment_dict, metric_dict, desc=algorithm_name)
    )
benchmark_results = pd.DataFrame(data=benchmark_records)
benchmark_results.to_csv("benchmark.csv", index=False)