<div style="background-color: darkred; padding: 10px; color: white;">

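# Test Soma Classifier

Evaluate a trained soma-proposal classifier on the held-out test proposals: load the annotated data, run the model, report metrics at a chosen confidence threshold, and visualize correct and incorrect predictions.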

</div>

<div style="background-color: darkblue; padding: 10px; color: white;">

## Section 1: Load Testing Data
    
</div>

### Imports

In [1]:
from scipy.optimize import OptimizeWarning
from tqdm import tqdm

import ast
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import warnings

from aind_exaspim_soma_detection import soma_proposal_classification as spc
from aind_exaspim_soma_detection.utils import data_util, img_util, ml_util, util
from aind_exaspim_soma_detection.machine_learning import data_handling
from aind_exaspim_soma_detection.machine_learning.data_handling import MultiThreadedDataLoader, ProposalDataset

# Suppress OptimizeWarning (from scipy.optimize) and every RuntimeWarning for
# the rest of the kernel session to keep cell output readable.
# NOTE(review): the RuntimeWarning filter is global, so numerical warnings
# (e.g. divide-by-zero / invalid-value) raised in later cells are hidden too.
warnings.filterwarnings("ignore", category=OptimizeWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Render matplotlib figures inline below each cell
%matplotlib inline


KeyboardInterrupt



### Initializations

In [None]:
# Parameters
anisotropy = [0.748, 0.748, 1.0]  # NOTE(review): not referenced by the cells below
multiscale = 1  # passed to the data_util fetchers when loading proposals
patch_shape = (102, 102, 102)  # passed to ProposalDataset and spc.load_model

# Paths: every input lives under one data directory, which can be redirected
# with the SOMA_DATA_DIR environment variable (defaults to the capsule layout).
data_dir = os.environ.get("SOMA_DATA_DIR", "/root/capsule/data")
img_lookup_path = os.path.join(data_dir, "exaspim_image_prefixes.json")
smartsheet_path = os.path.join(data_dir, "Neuron_Reconstructions.xlsx")
exaspim_soma_path = os.path.join(data_dir, "exaspim_somas_2024")

### Load Data

In [None]:
# Soma proposals listed in the neuron-reconstruction spreadsheet
smartsheet_data = data_util.fetch_smartsheet_somas(
    smartsheet_path, img_lookup_path, multiscale
)

# Soma proposals from the 2024 exaSPIM soma annotations
exapsim_data = data_util.fetch_exaspim_somas_2024(
    exaspim_soma_path, img_lookup_path, multiscale
)

# Curated example lists, loaded in order: proposals to skip, then the
# held-out test positives and test negatives
ignore, test_positives, test_negatives = map(
    data_util.load_examples,
    (
        "/root/capsule/data/ignore.txt",
        "/root/capsule/data/test_positives.txt",
        "/root/capsule/data/test_negatives.txt",
    ),
)


### Create Dataset

In [None]:
# Build a single dataset from every annotated proposal source
dataset = ProposalDataset(patch_shape)
for source_args in smartsheet_data + exapsim_data:
    dataset.ingest_proposals(*source_args)

# Restrict to held-out test proposals: a proposal is kept only when it is a
# test positive/negative AND is not on the ignore list; everything else
# (i.e. training examples) is removed.
# NOTE(review): epsilon=10 is forwarded to remove_proposal; its meaning is
# defined in ProposalDataset — confirm there.
test_proposals = set(test_positives + test_negatives)
keys = list(dataset.proposals.keys())
for key in keys:
    keep = key in test_proposals and key not in ignore
    if not keep:
        dataset.remove_proposal(key, epsilon=10)

# Report dataset composition
for label, count in (
    ("# Examples:", len(dataset)),
    ("# Positive Examples:", dataset.n_positives()),
    ("# Negative Examples:", dataset.n_negatives()),
):
    print(label, count)

<div style="background-color: darkblue; padding: 10px; color: white;">

## Section 2: Run Model
    
</div>

In [None]:
# Parameters
batch_size = 16
# Fall back to CPU so the notebook still runs on machines without a GPU.
# NOTE(review): assumes spc.load_model accepts "cpu" as a device — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "/root/capsule/results/soma_classifiers_2025-01-11 21:07:42.683209/model_8640_f1=0.93.pth"

# Initializations
dataloader = MultiThreadedDataLoader(dataset, batch_size)
model = spc.load_model(model_path, patch_shape, device)
if isinstance(model, torch.nn.Module):
    # Ensure inference behaviour (dropout off, batch-norm running stats) in
    # case load_model hands back the model in training mode.
    model.eval()

# Main
# Ceiling division so the progress bar also counts a final partial batch.
# NOTE(review): assumes MultiThreadedDataLoader does not drop that batch.
total = (len(dataset) + batch_size - 1) // batch_size
keys, y, hat_y = list(), list(), list()
with torch.no_grad():
    for keys_i, x_i, y_i in tqdm(dataloader, total=total):
        # Forward pass; sigmoid squashes the model output into (0, 1)
        hat_y_i = torch.sigmoid(model(x_i.to(device)))

        # Store keys, labels and predictions on the host
        keys.extend(keys_i)
        y.append(np.array(y_i))
        hat_y.append(hat_y_i.detach().cpu().numpy())

# Reformat predictions: stack the per-batch arrays and keep the first column,
# giving 1-D arrays whose i-th entry corresponds to keys[i]
hat_y = np.vstack(hat_y)[:, 0]
y = np.vstack(y)[:, 0]
assert len(keys) == len(y) == len(hat_y), "keys, labels and predictions are misaligned"

### Visualize Predictions

In [None]:
# Histogram of every predicted probability, regardless of ground-truth label
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(hat_y, bins=30, alpha=0.9, color="tab:blue", edgecolor="black")
ax.set_title("Distribution of Predictions")
ax.set_xlabel("Prediction")
ax.set_ylabel("Count")
plt.show()

In [None]:
# Extract positives and negatives
hat_y_positives = np.array([hat_y_i for i, hat_y_i in enumerate(hat_y) if y[i] == 1])
hat_y_negatives = np.array([hat_y_i for i, hat_y_i in enumerate(hat_y) if y[i] == 0])

# Distribution of positives vs. negatives
plt.figure(figsize=(10, 5))
plt.hist(hat_y_positives, alpha=0.9, bins=30, color="tab:blue", edgecolor='black', label="Positives")
plt.hist(hat_y_negatives, alpha=0.9, bins=30, color="tab:orange", edgecolor='black', label="Negatives")
plt.title('Distribution of Postive vs. Negative Predictions', fontsize=14)
plt.xlabel('Predictions', fontsize=13)
plt.ylabel('Counts', fontsize=13)
plt.legend(fontsize=11)
plt.grid(True)
plt.show()

<div style="background-color: darkblue; padding: 10px; color: white;">

## Section 3: Quantitative Results
    
</div>

### Subroutines

In [None]:
def _binarize_predictions(keys, y, hat_y, threshold):
    """Flatten labels/scores to 1-D and threshold scores into 0/1 predictions.

    Parameters
    ----------
    keys : sequence
        Example identifiers; must have one entry per label.
    y : array-like
        Ground-truth labels (1 = positive, 0 = negative). Any shape whose
        flattened length matches ``keys`` is accepted, e.g. (N,), (N, 1), (1, N).
    hat_y : array-like
        Predicted scores, same flattened length as ``y``.
    threshold : float
        Scores strictly greater than this are predicted positive.

    Returns
    -------
    tuple of numpy.ndarray
        ``(y, hat_y)`` as 1-D arrays, with ``hat_y`` converted to 0/1 ints.

    Raises
    ------
    ValueError
        If ``keys``, ``y`` and ``hat_y`` do not have the same length.
    """
    y = np.ravel(y)
    hat_y = (np.ravel(hat_y) > threshold).astype(int)
    if not len(keys) == len(y) == len(hat_y):
        raise ValueError(
            f"Length mismatch: {len(keys)} keys, {len(y)} labels, "
            f"{len(hat_y)} predictions"
        )
    return y, hat_y


def get_incorrect(keys, y, hat_y, threshold):
    """Collect the keys of misclassified examples and print their counts.

    Parameters
    ----------
    keys : sequence
        Example identifiers aligned with ``y`` and ``hat_y``.
    y : array-like
        Ground-truth labels (1 = positive, 0 = negative).
    hat_y : array-like
        Predicted scores; values strictly greater than ``threshold`` count as
        positive predictions.
    threshold : float
        Decision threshold.

    Returns
    -------
    dict
        ``{"false_negatives": [...], "false_positives": [...]}`` lists of keys.
    """
    y, hat_y = _binarize_predictions(keys, y, hat_y, threshold)
    incorrect = {"false_negatives": list(), "false_positives": list()}
    for key, y_i, hat_y_i in zip(keys, y, hat_y):
        if y_i == 1 and hat_y_i == 0:
            incorrect["false_negatives"].append(key)
        elif y_i == 0 and hat_y_i == 1:
            incorrect["false_positives"].append(key)

    # Report results
    print(f"# False Positives: {len(incorrect['false_positives'])}")
    print(f"# False Negatives: {len(incorrect['false_negatives'])}")
    return incorrect


def get_correct(keys, y, hat_y, threshold):
    """Collect the keys of correctly classified examples and print their counts.

    Parameters
    ----------
    keys : sequence
        Example identifiers aligned with ``y`` and ``hat_y``.
    y : array-like
        Ground-truth labels (1 = positive, 0 = negative).
    hat_y : array-like
        Predicted scores; values strictly greater than ``threshold`` count as
        positive predictions.
    threshold : float
        Decision threshold.

    Returns
    -------
    dict
        ``{"true_negatives": [...], "true_positives": [...]}`` lists of keys.
    """
    y, hat_y = _binarize_predictions(keys, y, hat_y, threshold)
    correct = {"true_negatives": list(), "true_positives": list()}
    for key, y_i, hat_y_i in zip(keys, y, hat_y):
        if y_i == 0 and hat_y_i == 0:
            correct["true_negatives"].append(key)
        elif y_i == 1 and hat_y_i == 1:
            correct["true_positives"].append(key)

    # Report results
    print(f"# True Positives: {len(correct['true_positives'])}")
    print(f"# True Negatives: {len(correct['true_negatives'])}")
    return correct


### Compute Performance Metrics

In [None]:
# Parameters
confidence_threshold = 0.3

# Column views (N, 1) of the 1-D labels/scores built by the inference cell.
# Re-stacking them with np.vstack would produce a (1, N) row and break the
# per-example loops; reshaping into new names also keeps this cell idempotent
# and leaves the 1-D `y` / `hat_y` used by the plotting cells untouched.
# NOTE(review): assumes ml_util.report_metrics expects (N, 1) arrays — confirm.
y_col = np.reshape(y, (-1, 1))
hat_y_col = np.reshape(hat_y, (-1, 1))

# Results
ml_util.report_metrics(y_col, hat_y_col, confidence_threshold)
correct = get_correct(keys, y_col, hat_y_col, confidence_threshold)
incorrect = get_incorrect(keys, y_col, hat_y_col, confidence_threshold)


### Visualize Correct Predictions

In [None]:
# Draw one true positive at random and visualize it
true_positives = correct["true_positives"]
if not true_positives:
    print("No true positives at this threshold.")
else:
    key = util.sample_once(true_positives)
    print("Example ID:", key)
    dataset.visualize_proposal(key)

### Visualize Incorrect Predictions

In [None]:
# Which false negative to inspect: an integer index pins a specific example,
# None draws one at random.
false_negative_index = 18
false_negatives = incorrect["false_negatives"]
if not false_negatives:
    print("No false negatives at this threshold.")
else:
    if false_negative_index is None:
        key = util.sample_once(false_negatives)
    else:
        key = false_negatives[false_negative_index]
    print("Example ID:", key)
    dataset.visualize_proposal(key)