In [1]:
"""
make_lowlevel_tables.py

Builds low level fidelity tables for each dataset.
Rows = Method with an Implant column
Columns = metrics
One LaTeX table per dataset.

Edit DATASET_PATHS near the bottom before running.
"""

import os
import sys
import math
import json
import pickle
from typing import Dict, Union, List, Tuple

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder, CocoDetection
from torchvision.datasets.folder import default_loader

import cv2
import piq

# Local project imports
sys.path.append('./..')
sys.path.append('./../..')
from dynaphos import utils
from dynaphos.simulator import GaussianSimulator as PhospheneSimulator
from phosphene.uniformity import DynamicAmplitudeNormalizer
from phosphene.density import VisualFieldMapper
from spatial_frequency.components.SeparableModulated2d import SeparableModulatedConv2d
from utils import robust_percentile_normalization  # uses your existing implementation

In [2]:
# ------------- Determinism and device -------------
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
torch.use_deterministic_algorithms(False)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# ------------- Transforms -------------
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722]).view(3, 1, 1)
to_weighted_grayscale = transforms.Lambda(lambda img: (img * rgb_weights).sum(dim=0, keepdim=True))

T_IMG = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    to_weighted_grayscale,   # -> [1, H, W], in [0, 1]
])

IMG_EXTS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"}

In [3]:
class FlatImageDataset(Dataset):
    """
    Recursively loads all image files under root. Ignores labels.
    Returns (tensor, 0).
    """
    def __init__(self, root: str, transform=None):
        self.root = root
        self.transform = transform
        self.paths = []
        for dirpath, _, fnames in os.walk(root):
            for f in fnames:
                if os.path.splitext(f)[1].lower() in IMG_EXTS:
                    self.paths.append(os.path.join(dirpath, f))
        if len(self.paths) == 0:
            raise FileNotFoundError(f"No images found under {root}")
        self.loader = default_loader

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = self.loader(self.paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        return img, 0


def make_coco_loader(data_root: str, split: str = "val2017",
                     batch_size: int = 1, shuffle: bool = False) -> DataLoader:
    img_dir = os.path.join(data_root, split)
    ann_file = os.path.join(data_root, "annotations", f"instances_{split}.json")
    if not os.path.exists(img_dir):
        raise FileNotFoundError(f"COCO images not found at {img_dir}")
    if not os.path.exists(ann_file):
        raise FileNotFoundError(f"COCO annotations not found at {ann_file}")
    ds = CocoDetection(root=img_dir, annFile=ann_file, transform=T_IMG)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def make_lapa_loader(root: str, batch_size: int = 1, shuffle: bool = False) -> DataLoader:
    try:
        ds = ImageFolder(root, transform=T_IMG)
        if len(ds) == 0:
            raise FileNotFoundError("ImageFolder found no images")
    except Exception:
        ds = FlatImageDataset(root, transform=T_IMG)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


def make_sun_loader(root: str, batch_size: int = 1, shuffle: bool = False) -> DataLoader:
    try:
        ds = ImageFolder(root, transform=T_IMG)
        if len(ds) == 0:
            raise FileNotFoundError("ImageFolder found no images")
    except Exception:
        ds = FlatImageDataset(root, transform=T_IMG)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

def build_dataset_loaders(paths: Dict[str, str]) -> Dict[str, DataLoader]:
    loaders = {}
    if "COCO" in paths:
        loaders["COCO"] = make_coco_loader(paths["COCO"], split="val2017", batch_size=1, shuffle=False)
    if "LaPa" in paths:
        loaders["LaPa"] = make_lapa_loader(paths["LaPa"], batch_size=1, shuffle=False)
    if "SUN" in paths:
        loaders["SUN"] = make_sun_loader(paths["SUN"], batch_size=1, shuffle=False)
    return loaders


In [4]:
# ------------- Implant schemes -------------
SCHEMES: List[Tuple[str, str, float]] = [
    ("1 Utah", "../electrode_schemes/1utaharray.pickle",                 0.4),
    ("Utah RFs", "../electrode_schemes/utahRFs.pickle",                  6.0),
    ("4 Utah", "../electrode_schemes/4utaharrays.pickle",               16.0),
    ("Uniform 1024", "../electrode_schemes/defaultcoordinatemap_1024.pickle", 16.0),
    ("Neuralink", "../electrode_schemes/neuralink.pickle",              25.0),
]

In [5]:
# ------------- SCAPE and simulation helpers -------------
def dilation3x3(img: torch.Tensor, kernel: torch.Tensor = None) -> torch.Tensor:
    if kernel is None:
        kernel = torch.tensor([[[[0, 1, 0],
                                 [1, 1, 1],
                                 [0, 1, 0]]]], device=img.device, dtype=img.dtype)
    return torch.clamp(torch.nn.functional.conv2d(img, kernel, padding=1), 0, 1)


def rand_perlin_2d(shape: Tuple[int, int], res: Tuple[int, int],
                   fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3) -> torch.Tensor:
    H, W = shape
    delta = (res[0] / H, res[1] / W)
    d = (H // res[0], W // res[1])
    grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]),
                                      torch.arange(0, res[1], delta[1]),
                                      indexing='ij'), dim=-1) % 1
    angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tile_grads(s1, s2):
        return gradients[s1[0]:s1[1], s2[0]:s2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)

    def dot(grad, shift):
        g = grad[:H, :W]
        return (torch.stack((grid[:H, :W, 0] + shift[0], grid[:H, :W, 1] + shift[1]), dim=-1) * g).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[:H, :W])
    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])


In [6]:
def build_methods(orig_01: torch.Tensor, mod_layer: SeparableModulatedConv2d) -> Dict[str, torch.Tensor]:
    """
    orig_01: [1, 1, H, W] in [0, 1]
    returns dict of method name -> stimulus [1, 1, H, W] in [0, 1]
    """
    # SCAPE DoG
    dog = mod_layer(orig_01).detach()
    dog = (dog - dog.amin()) / (dog.amax() - dog.amin() + 1e-12)

    # Canny
    npimg = (orig_01[0, 0].detach().cpu().numpy() * 255).astype(np.uint8)
    ce = cv2.Canny(npimg, 150, 280).astype(np.float32)
    ce = torch.from_numpy(ce).to(orig_01.device).unsqueeze(0).unsqueeze(0)
    ce = dilation3x3(ce)
    ce = ce / (ce.amax() + 1e-12)

    # Random Perlin
    rp = rand_perlin_2d(orig_01.shape[-2:], (4, 4)).to(orig_01.device).float()
    rp = (rp - rp.amin()) / (rp.amax() - rp.amin() + 1e-12)
    rp = rp.unsqueeze(0).unsqueeze(0)

    return {
        "Grayscale": orig_01,
        "Canny": ce,
        "DoG": dog,
        "Random": rp,
    }

In [7]:
@torch.no_grad()
def simulate(simulator: PhospheneSimulator,
             stim_img: torch.Tensor,
             amplitude: float,
             threshold: float,
             weights: torch.Tensor) -> torch.Tensor:
    """
    stim_img: [1, 1, H, W] in [0, 1]
    returns phosphenes [1, 1, H, W]
    """
    torch.use_deterministic_algorithms(False)
    simulator.reset()
    elec = simulator.sample_stimulus(stim_img, rescale=True)  # [N]
    elec = robust_percentile_normalization(
        elec, amplitude, threshold=threshold,
        low_perc=5, high_perc=90, gamma=1/3
    )
    elec = elec * weights
    phos = simulator(elec)
    if phos.dim() == 2:
        phos = phos.unsqueeze(0).unsqueeze(0)
    elif phos.dim() == 3:
        phos = phos.unsqueeze(0)
    return phos

In [8]:
def build_sigma_map(simulator: PhospheneSimulator) -> torch.Tensor:
    """
    Builds sigma map in pixels using KDE with k=16, alpha=1.0 and beta=0.55.
    """
    mapper = VisualFieldMapper(simulator=simulator)
    dens = mapper.build_density_map_kde(k=16, alpha=1.0, total_phosphenes=simulator.num_phosphenes)
    try:
        sigma_pix = mapper.build_sigma_map_from_density(dens, space="pixel", beta=0.55)
    except TypeError:
        # If your mapper has beta baked in, call without it
        sigma_pix = mapper.build_sigma_map_from_density(dens, space="pixel")
    sigma_map_tensor = torch.tensor(sigma_pix, device=DEVICE, dtype=torch.float32)
    return sigma_map_tensor


In [9]:
# ------------- Metrics -------------
def make_metric_objects(keys: Tuple[str, ...]) -> Dict[str, torch.nn.Module]:
    objs = {}
    for key in keys:
        if key == 'fsim':
            objs[key] = piq.FSIMLoss(chromatic=False, min_length=7, scales=4).to(DEVICE)
        elif key == 'pieapp':
            objs[key] = piq.PieAPP().to(DEVICE)
        elif key == 'content':
            objs[key] = piq.ContentLoss(feature_extractor='vgg19',
                                        normalize_features=False,
                                        layers=['relu2_2'],
                                        distance='swd').to(DEVICE)
        elif key == 'srsim':
            objs[key] = piq.SRSIMLoss().to(DEVICE)
        elif key == 'vsi':
            objs[key] = piq.VSILoss().to(DEVICE)
        elif key == 'mdsi':
            objs[key] = piq.MDSILoss().to(DEVICE)
        else:
            raise ValueError(f"Unknown metric key: {key}")
    return objs


def safe_loss(loss_fn: torch.nn.Module, x: torch.Tensor, y: torch.Tensor) -> float:
    x = x.clamp(0, 1)
    y = y.clamp(0, 1)
    try:
        v = loss_fn(x, y).detach().item()
        return float(v)
    except Exception as e:
        print(f"[warn] metric {type(loss_fn).__name__} failed: {e}")
        return float('nan')

In [10]:
# ------------- Evaluation core -------------
def eval_dataset_to_table(dataset_name: str,
                          data_loader: DataLoader,
                          schemes: List[Tuple[str, str, float]],
                          params_path: str = "../config/params.yaml",
                          max_images: int = 200,
                          metrics: Tuple[str, ...] = ('ssim', 'fsim', 'mdsi', 'vsi', 'srsim', 'content'),
                          include_se: bool = True,
                          verbose: bool = True) -> Tuple[pd.DataFrame, str]:
    """
    Runs SCAPE for each scheme, simulates phosphenes for four methods, compares to original with metrics.
    Returns a tall DataFrame (index = Method, Implant) and a LaTeX table string.
    """
    params0 = utils.load_params(params_path)  # keep as template
    metric_objs = make_metric_objects(metrics)
    METHODS_ORDER = ["Grayscale", "Canny", "DoG", "Random"]

    # Preload a fixed list of images
    images: List[torch.Tensor] = []
    for i, (img, _) in enumerate(data_loader):
        img = img.to(DEVICE)
        images.append(img)
        if max_images is not None and len(images) >= max_images:
            break
    if len(images) == 0:
        raise RuntimeError(f"No images in loader for dataset {dataset_name}")

    rows = []

    for scheme_name, pkl_path, view_angle in schemes:
        if not os.path.exists(pkl_path):
            print(f"[skip] {scheme_name}: missing {pkl_path}")
            continue

        # Load electrode coordinates
        with open(pkl_path, 'rb') as f:
            coords = pickle.load(f)

        # Build simulator
        params = utils.load_params(params_path)
        params['run']['view_angle'] = float(view_angle)
        simulator = PhospheneSimulator(params, coords)
        amplitude = params['sampling']['stimulus_scale']
        threshold = params['thresholding']['rheobase']

        # Amplitude normalization per scheme
        stim_init = amplitude * torch.ones(simulator.num_phosphenes, device=DEVICE)
        normalizer = DynamicAmplitudeNormalizer(
            simulator=simulator,
            base_size=3, scale=1e-4,
            A_min=0, A_max=amplitude,
            learning_rate=0.002, steps=2000, target=None
        )
        _ = normalizer.run(stim_init, verbose=False)
        weights = normalizer.weights.detach()

        # SCAPE DoG
        sigma_map = build_sigma_map(simulator)
        mod_layer = SeparableModulatedConv2d(in_channels=1, sigma_map=sigma_map).to(DEVICE).eval()

        # Accumulators for mean and standard error
        n = 0
        sums = {mkey: {meth: 0.0 for meth in METHODS_ORDER} for mkey in metrics}
        sumsq = {mkey: {meth: 0.0 for meth in METHODS_ORDER} for mkey in metrics}

        for img in images:
            orig = img  # [1,1,H,W] in [0,1]

            # Build method stimuli
            stim_dict = build_methods(orig, mod_layer)

            # Simulate phosphenes
            phos_dict = {meth: simulate(simulator, stim_dict[meth], amplitude, threshold, weights)
                         for meth in METHODS_ORDER}

            # Metrics vs original
            for mkey, loss_fn in metric_objs.items():
                for meth in METHODS_ORDER:
                    val = safe_loss(loss_fn, orig, phos_dict[meth])
                    if not math.isnan(val):
                        sums[mkey][meth] += val
                        sumsq[mkey][meth] += val * val

            n += 1

        # Compute mean and se
        for meth in METHODS_ORDER:
            row = {"Method": meth, "Implant": scheme_name}
            for mkey in metrics:
                mean = sums[mkey][meth] / max(1, n)
                if include_se and n > 1:
                    var_n = max(0.0, (sumsq[mkey][meth] - n * mean * mean) / (n - 1))
                    se = math.sqrt(var_n / n)
                    row[mkey] = (mean, se)
                else:
                    row[mkey] = (mean, None) if include_se else mean
            rows.append(row)

        if verbose:
            print(f"[done] {dataset_name} × {scheme_name}: {n} images")

    # Build DataFrame
    df = pd.DataFrame(rows)
    df = df.set_index(["Method", "Implant"]).sort_index()

    # If include_se, format as mean ± se strings for display
    if include_se:
        disp = pd.DataFrame(index=df.index)
        for col in [c for c in df.columns if c not in ("Method", "Implant")]:
            def fmt(x):
                mean, se = x
                if se is None:
                    return f"{mean:.3f}"
                return f"{mean:.3f} ± {se:.3f}"
            disp[col] = df[col].map(fmt)
        df_for_tex = disp
    else:
        df_for_tex = df.applymap(lambda x: f"{x:.3f}")

    latex = df_for_tex.to_latex(
        escape=True,
        caption=f"Low level fidelity on {dataset_name}. Distances, lower is better.",
        label=f"tab:low_{dataset_name.lower()}",
        index=True,
        multicolumn=False,
        multicolumn_format='c',
        column_format='ll' + 'c' * len(df_for_tex.columns)
    )

    return df, latex

In [11]:
def main():
    # Edit these paths
    DATASET_PATHS = {
        "COCO": "/projects/prjs0344/Dynaphos/data/coco",
        "LaPa": "/projects/prjs0344/Dynaphos/data/example_faces_LaPa",
        "SUN":  "/projects/prjs0344/Dynaphos/data/SUN/SUN397",
    }

    loaders = build_dataset_loaders(DATASET_PATHS)

    # Metrics you want as columns
    METRICS = ('fsim', 'pieapp', 'content', 'srsim', 'vsi', 'mdsi')

    results: Dict[str, Tuple[pd.DataFrame, str]] = {}
    for ds_name, loader in loaders.items():
        df_table, latex = eval_dataset_to_table(
            dataset_name=ds_name,
            data_loader=loader,
            schemes=SCHEMES,
            params_path="../config/params.yaml",
            max_images=200,             # set None to use full split
            metrics=METRICS,
            include_se=True,
            verbose=True
        )
        results[ds_name] = (df_table, latex)

        csv_path = f"lowlevel_{ds_name.lower()}.csv"
        tex_path = f"lowlevel_{ds_name.lower()}.tex"
        df_table.to_csv(csv_path)
        with open(tex_path, "w") as f:
            f.write(latex)
        print(f"Wrote {csv_path} and {tex_path}")

    # Optional: print LaTeX to console
    for ds_name, (_, latex) in results.items():
        print("\n" + "=" * 80)
        print(f"{ds_name} LaTeX table")
        print("=" * 80)
        print(latex)

In [12]:
main()

loading annotations into memory...
Done (t=0.75s)
creating index...
index created!




[done] COCO × 1 Utah: 200 images
[done] COCO × Utah RFs: 200 images
[done] COCO × 4 Utah: 200 images
[done] COCO × Uniform 1024: 200 images
[done] COCO × Neuralink: 200 images
Wrote lowlevel_coco.csv and lowlevel_coco.tex




[done] LaPa × 1 Utah: 15 images
[done] LaPa × Utah RFs: 15 images
[done] LaPa × 4 Utah: 15 images
[done] LaPa × Uniform 1024: 15 images
[done] LaPa × Neuralink: 15 images
Wrote lowlevel_lapa.csv and lowlevel_lapa.tex




[done] SUN × 1 Utah: 200 images
[done] SUN × Utah RFs: 200 images
[done] SUN × 4 Utah: 200 images
[done] SUN × Uniform 1024: 200 images
[done] SUN × Neuralink: 200 images
Wrote lowlevel_sun.csv and lowlevel_sun.tex

COCO LaTeX table
\begin{table}
\caption{Low level fidelity on COCO. Distances, lower is better.}
\label{tab:low_coco}
\begin{tabular}{llcccccc}
\toprule
 &  & fsim & pieapp & content & srsim & vsi & mdsi \\
Method & Implant &  &  &  &  &  &  \\
\midrule
\multirow[t]{5}{*}{Canny} & 1 Utah & 0.499 ± 0.006 & 1.052 ± 0.091 & 6.303 ± 0.125 & 0.434 ± 0.004 & 0.221 ± 0.003 & 0.409 ± 0.002 \\
 & 4 Utah & 0.525 ± 0.007 & 1.548 ± 0.112 & 5.846 ± 0.131 & 0.480 ± 0.005 & 0.259 ± 0.004 & 0.372 ± 0.001 \\
 & Neuralink & 0.500 ± 0.007 & 2.631 ± 0.118 & 6.350 ± 0.124 & 0.421 ± 0.005 & 0.244 ± 0.003 & 0.381 ± 0.001 \\
 & Uniform 1024 & 0.457 ± 0.006 & 4.698 ± 0.135 & 4.854 ± 0.074 & 0.356 ± 0.004 & 0.230 ± 0.003 & 0.407 ± 0.001 \\
 & Utah RFs & 0.530 ± 0.008 & 1.711 ± 0.121 & 4.876 ± 0.133 

In [18]:
import pandas as pd
import numpy as np



def make_latex_table(csv_path: str,
                     dataset_name: str,
                     metrics: list,
                     caption: str = None,
                     label: str = None,
                     round_digits: int = 3) -> str:
    """
    Convert CSV of results into a LaTeX table string.
    Layout: Implant (once, multirow) | Method | Metrics.
    Bold = best (lowest) per implant and metric.
    """

    # Load
    df = pd.read_csv(csv_path)

    # Parse values
    def parse_val(v):
        if isinstance(v, str):
            v = v.strip()
            if v.startswith("(") and "," in v and v.endswith(")"):
                mean_str = v[1:-1].split(",")[0]
                try:
                    return float(mean_str)
                except:
                    return np.nan
            try:
                return float(v)
            except:
                return np.nan
        return v

    for m in metrics:
        df[m] = df[m].map(parse_val)

    # Defaults
    if caption is None:
        caption = f"Low-level fidelity metrics on {dataset_name}. Values are distances, lower is better."
    if label is None:
        label = f"tab:low_{dataset_name.lower()}"

    # Start LaTeX
    latex = []
    latex.append("\\begin{table}[t]")
    latex.append("\\centering")
    latex.append("\\small")
    latex.append(f"\\caption{{{caption}}}")
    latex.append(f"\\label{{{label}}}")
    colspec = "ll" + "c" * len(metrics)
    latex.append(f"\\begin{{tabular}}{{{colspec}}}")
    latex.append("\\toprule")
    latex.append("Implant & Method & " + " & ".join(m.upper() for m in metrics) + " \\\\")
    latex.append("\\midrule")

    # Per implant
    for implant, group in df.groupby("Implant"):
        # best values per metric across methods
        bests = {m: group[m].min() for m in metrics}

        methods = group["Method"].tolist()
        n_methods = len(methods)

        for idx, (_, row) in enumerate(group.iterrows()):
            row_str = []
            if idx == 0:
                row_str.append(f"\\multirow{{{n_methods}}}*{{{implant}}}")
            else:
                row_str.append("")  # empty for subsequent rows

            row_str.append(row["Method"])

            for m in metrics:
                val = row[m]
                if pd.isna(val):
                    s = "-"
                else:
                    sval = f"{val:.{round_digits}f}"
                    if np.isclose(val, bests[m], rtol=1e-5, atol=1e-8):
                        sval = f"\\textbf{{{sval}}}"
                    s = sval
                row_str.append(s)

            latex.append(" & ".join(row_str) + " \\\\")

        latex.append("\\midrule")

    latex.append("\\bottomrule")
    latex.append("\\end{tabular}")
    latex.append("\\end{table}")

    return "\n".join(latex)



# Example usage:
if __name__ == "__main__":
    table_tex = make_latex_table(
        csv_path="lowlevel_lapa.csv",
        dataset_name="LaPa",
        metrics=["fsim", "pieapp", "content", "srsim", "vsi", "mdsi"],
        round_digits=3
    )
    with open("table_low_lapa.tex", "w") as f:
        f.write(table_tex)
    print(table_tex)


\begin{table}[t]
\centering
\small
\caption{Low-level fidelity metrics on LaPa. Values are distances, lower is better.}
\label{tab:low_lapa}
\begin{tabular}{llcccccc}
\toprule
Implant & Method & FSIM & PIEAPP & CONTENT & SRSIM & VSI & MDSI \\
\midrule
\multirow{4}*{1 Utah} & Canny & \textbf{0.404} & \textbf{0.456} & \textbf{5.118} & \textbf{0.359} & 0.176 & \textbf{0.385} \\
 & DoG & 0.412 & 0.903 & 5.524 & 0.373 & \textbf{0.168} & 0.427 \\
 & Grayscale & 0.413 & 0.797 & 5.646 & 0.382 & 0.170 & 0.431 \\
 & Random & 0.410 & 0.881 & 5.247 & 0.375 & 0.169 & 0.427 \\
\midrule
\multirow{4}*{4 Utah} & Canny & 0.430 & \textbf{0.939} & \textbf{4.542} & \textbf{0.389} & \textbf{0.199} & \textbf{0.369} \\
 & DoG & 0.421 & 2.445 & 7.636 & 0.401 & 0.204 & 0.411 \\
 & Grayscale & \textbf{0.420} & 2.379 & 6.979 & 0.399 & 0.202 & 0.405 \\
 & Random & 0.421 & 2.426 & 7.190 & 0.404 & 0.204 & 0.409 \\
\midrule
\multirow{4}*{Neuralink} & Canny & \textbf{0.415} & \textbf{1.466} & \textbf{4.316} & \textbf{