In [None]:
# # Setup
# ! sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6
# ! pip install open-iris==1.0.0 faiss-cpu seaborn

# Imports

In [None]:
import boto3
from io import BytesIO
import pickle
import iris
import scipy
import psutil
import time
from datetime import datetime
import sys
import threading
from itertools import combinations, product
from functools import reduce
from operator import mul

In [None]:
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp, ttest_ind

In [None]:
n_jobs = 6  # Fit to CPU
DIM = (2, 32, 200)  # (parts, rows per part, columns); synthetic codes have reduce(mul, DIM[:2]) rows
X, Y = DIM[1:]
MAX_ROT = 15  # largest column shift tried when aligning / rotating iris codes
SEED = 12345
# Seed the global NumPy RNG: np.random.choice and pandas .sample() (random_state=None)
# below draw from it, so this makes sampling reproducible under Restart & Run All.
np.random.seed(SEED)

# Data Loading

## Real Irises

In [None]:
shape = (16, 200)
# Choose the bucket suffix: "-dev" selects the test data, "" the real data.
# DEV = "-dev" # Access test data.
DEV = "" # Access real data.
if DEV:
    print("Working on simulated data")
else:
    print("Working on real data")

In [None]:
# S3 bucket and IAM role for the shared dataset; DEV appends the "-dev" suffix for test data.
bucket_name = f'wld-inversed-data-sharing{DEV}'
role_arn = f'arn:aws:iam::387760840988:role/worldcoin-data{DEV}'
metadata_path = 'metadata.csv'

def memoize(func):
    """Cache ``func``'s results keyed by its positional arguments.

    Every result is cached for the lifetime of the kernel, including ``None``
    (so a failed load stays cached; call ``wrapper.cache.clear()`` to retry).
    Arguments must be hashable; keyword arguments are not supported.
    """
    import functools

    cache = {}

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def memoized_func(*args):
        if args in cache:
            return cache[args]
        result = func(*args)
        cache[args] = result
        return result

    memoized_func.cache = cache  # exposed so stale entries can be inspected or cleared
    return memoized_func

def assume_role(role_arn, session_name="S3ReadSession"):
    """Assume ``role_arn`` through STS and return an S3 client using the temporary credentials."""
    creds = boto3.client('sts').assume_role(
        RoleArn=role_arn,
        RoleSessionName=session_name,
    )['Credentials']
    return boto3.client(
        's3',
        aws_access_key_id=creds['AccessKeyId'],
        aws_secret_access_key=creds['SecretAccessKey'],
        aws_session_token=creds['SessionToken'],
    )

# Assume the role once; every S3 read below goes through this client.
s3 = assume_role(role_arn, session_name="S3ReadSession")

def read_s3_file(bucket_name, file_key):
    """Download ``file_key`` from ``bucket_name`` (via the global ``s3`` client) into an in-memory buffer."""
    response = s3.get_object(Bucket=bucket_name, Key=file_key)
    payload = response['Body'].read()
    return BytesIO(payload)

def _load_pickle(folder, image_id):
    """Unpickle ``<folder>/<image_id>.pickle`` from ``bucket_name``; return None if that fails.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the file;
    only use this with a trusted bucket.
    """
    path = folder + "/" + image_id + ".pickle"
    try:
        return pickle.load(read_s3_file(bucket_name, path))
    except Exception as err:
        # Best effort: callers skip images whose object is missing or unreadable.
        print(f"Could not load {path}: {err}")
        return None

@memoize
def load_response(image_id):
    """Return the IrisFilterResponse stored for ``image_id``, or None if it cannot be loaded.

    Memoized, so a None result is cached too.
    """
    return _load_pickle("iris_filter_responses", image_id)

@memoize
def load_template(image_id):
    """Return the IrisTemplate stored for ``image_id``, or None if it cannot be loaded.

    Memoized, so a None result is cached too.
    """
    return _load_pickle("iris_templates", image_id)

# Per-image metadata; the subject_id, biological_side and ir_image_id columns are used below.
meta = pd.read_csv(read_s3_file(bucket_name, file_key=metadata_path))

## Synthetic Irises

In [None]:
def load_synthetic_iris(method, num_samples, path='compressed_iris_matrices', rng=None):
    """Return ``num_samples`` randomly chosen synthetic iris codes produced by ``method``.

    Reads the array stored under key ``'data'`` in ``{path}_{method}.npz`` and draws
    rows without replacement.

    Args:
        method: file suffix naming the generator, e.g. ``'gaussian'``.
        num_samples: number of rows to return.
        path: file-name prefix.
        rng: optional ``np.random.Generator``; the global ``np.random`` state is used when None.

    Raises:
        AssertionError: if the file holds fewer than ``num_samples`` rows.
    """
    # Context manager closes the underlying zip file handle (np.load leaves it open otherwise).
    with np.load(f'{path}_{method}.npz') as archive:
        loaded_data = archive['data']
    available = loaded_data.shape[0]
    assert available >= num_samples, f"Requested {num_samples} samples, but only {available} available."
    chooser = np.random if rng is None else rng
    indices = chooser.choice(available, num_samples, replace=False)
    return loaded_data[indices]

In [None]:
# 1000 random codes from each pre-generated synthetic source.
data_dict = {
    method: load_synthetic_iris(method, 1000)
    for method in ('gaussian', 'voter', 'voter_gaussian')
}

In [None]:
def import_voter_model_rust_implementation(path, total_num_samples, num_samples=None, matrix_shape=(32, 200)):
    """Load ``num_samples`` random bit matrices from a packed binary dump.

    The file is a flat byte array holding ``total_num_samples`` matrices of
    ``matrix_shape`` bits each, packed with little-endian bit order (the layout
    ``np.packbits(..., bitorder="little")`` produces).

    Args:
        path: binary file to read.
        total_num_samples: number of matrices stored in the file.
        num_samples: how many to draw without replacement (global ``np.random``);
            ``None`` means all of them. ``0`` now returns an empty array.
        matrix_shape: shape of one matrix.

    Returns:
        uint8 array of 0/1 values with shape ``(num_samples, *matrix_shape)``.
    """
    num_samples = total_num_samples if num_samples is None else num_samples
    assert 0 <= num_samples <= total_num_samples
    bits_per_sample = int(np.prod(matrix_shape))
    data = np.fromfile(path, dtype=np.uint8)
    indices = np.random.choice(total_num_samples, size=num_samples, replace=False)
    if bits_per_sample % 8 == 0:
        # Byte-aligned samples: select the rows first and unpack only those,
        # instead of unpacking the whole file (8x its size) into memory.
        rows = data.reshape(total_num_samples, bits_per_sample // 8)[indices]
        bits = np.unpackbits(rows, axis=1, bitorder="little")
        return bits.reshape(num_samples, *matrix_shape)
    return np.unpackbits(data, bitorder="little").reshape(total_num_samples, *matrix_shape)[indices]

In [None]:
# Two voter-model runs stacked along the row axis: (1000, 32, 200) + (1000, 32, 200) -> (1000, 64, 200).
bryan_low_data, bryan_high_data = (
    import_voter_model_rust_implementation(dat_file, 1000000, 1000)
    for dat_file in ('2M_voter_arrays_80k_b45.dat', '2M_voter_arrays_7k_b13.dat')
)
data_dict['voter_bryan'] = np.concatenate((bryan_low_data, bryan_high_data), axis=1).astype(bool)

In [None]:
# One row per synthetic code, tagged with the generator ("source") that produced it.
synthetic_df = pd.concat(
    [pd.DataFrame({'iris_matrices': list(data), 'source': source}) for source, data in data_dict.items()],
    ignore_index=True,
)
# Synthetic codes have no occlusion, so every bit is valid. Build one array per row:
# `[arr] * n` would make every row share (and co-mutate) a single array object.
synthetic_df['mask_matrices'] = [
    np.ones((reduce(mul, DIM[:2]), DIM[-1]), dtype=bool) for _ in range(len(synthetic_df))
]

# Data Processing

## Real Irises

### Functions

In [None]:
# Helpers for iterators.
def take(count, it):
    """Yield at most ``count`` items from the iterable ``it``.

    ``count=None`` yields everything and ``count <= 0`` yields nothing. Once the
    limit is reached no further item is pulled from ``it``. This matters for
    iterators such as ``iter_related_pairs``, where each item triggers ``load_response`` calls.
    """
    if count is None:
        yield from it
        return
    if count <= 0:
        return
    for x in it:
        yield x
        count -= 1
        if count <= 0:
            return

In [None]:
# Load matching pairs.
def iter_matching_image_ids(meta, unique_subjects):
    """Yield same-eye image pairs as ``(f"{subject}_side{side}", ir_image_id_a, ir_image_id_b)``.

    Side 0 is visited before side 1; subjects in order of first appearance in ``meta``.
    Every pair ``(i, j)``, ``i < j``, of a subject's images on one side is produced, or
    only the first pair when ``unique_subjects`` is true. Subjects with fewer than two
    images on a side are skipped.
    """
    all_subjects = meta["subject_id"].unique()
    for side in (0, 1):
        side_rows = meta[meta["biological_side"] == side]
        for subject in all_subjects:
            image_ids = side_rows.loc[side_rows["subject_id"] == subject, "ir_image_id"]
            n_images = len(image_ids)
            if n_images < 2:
                continue
            limit = 2 if unique_subjects else n_images
            for i, j in combinations(range(limit), 2):
                yield (f"{subject}_side{side}", image_ids.iloc[i], image_ids.iloc[j])

def load_matching_image_ids(meta, unique_subjects):
    """Return all same-eye pairs ``(subject_id, ir_image_id_0, ir_image_id_1)`` in a fixed shuffled order (seed 12345)."""
    shuffler = np.random.default_rng(seed=12345)
    pairs = [*iter_matching_image_ids(meta, unique_subjects)]
    shuffler.shuffle(pairs)
    return pairs

def iter_related_pairs(meta, unique_subjects):
    """Yield ``(subject_id, response_0, response_1)``, skipping pairs where either response is falsy (e.g. failed to load)."""
    for subject, first_id, second_id in load_matching_image_ids(meta, unique_subjects):
        first = load_response(first_id)
        second = load_response(second_id)
        if not (first and second):
            continue
        yield (subject, first, second)

def load_related_pairs(meta, count=None, unique_subjects=False):
    """Return up to ``count`` (all when None) pairs ``(subject_id, response_0, response_1)`` as a list."""
    pair_iter = iter_related_pairs(meta, unique_subjects)
    return [pair for pair in take(count, pair_iter)]

In [None]:
# Masking methodologies
def fill_masked_with_random(bits, mask):
    """In place, replace ``bits`` with random values wherever ``mask`` is False.

    ``mask`` True marks a valid bit. XOR-ing a bit with a uniform random bit gives a
    uniform random bit, so masked positions become random and valid ones are untouched.
    """
    filler = np.random.randint(0, 2, size=bits.shape, dtype=bool)
    # Was `not_(mask)`: `not_` is never imported (NameError), and operator.not_ on an
    # array would raise anyway; element-wise negation is what is meant.
    filler &= ~np.asarray(mask, dtype=bool)
    bits ^= filler

def fill_masked_with_zeros(bits, mask):
    """In place, set ``bits`` to False wherever ``mask`` is False."""
    bits &= mask

# Techniques that do not support masking will work, although with a modified scale of distances.
# The change in distance can be calculated from the size of the overlap of masks. Alternatively,
# it can be estimated with the expected average of that.

In [None]:
# Make encoders from parameters.
def make_encoder(v_subsample=1, h_subsample=1, top=True, bottom=True, real=True, imag=True, mask_threshold=0.9, static_mask=None, mask_with_random=False):
    """Build ``encode(response) -> (bits, mask)`` for an iris filter response.

    For each selected response index (``top`` -> 0, ``bottom`` -> 1) and, within it,
    each selected quantizer (real part, then imaginary part), the response is subsampled
    by ``[::v_subsample, ::h_subsample]`` and binarised by sign (``> 0``). The matching
    mask response is thresholded with ``>= mask_threshold``. Pieces are stacked along axis 0.

    Optional post-processing per piece:
      * ``mask_with_random``: masked-out bits are replaced with random bits.
      * ``static_mask``: bits outside it are zeroed and also marked invalid in the mask.
    """
    res_indexes = [idx for idx, wanted in ((0, top), (1, bottom)) if wanted]
    assert res_indexes, "require top, bottom, or both"

    quantizers = [fn for fn, wanted in ((np.real, real), (np.imag, imag)) if wanted]
    assert quantizers, "require real, imag, or both"

    window = (slice(None, None, v_subsample), slice(None, None, h_subsample))

    def encode(response):
        bit_parts, mask_parts = [], []
        for res_index, quantizer in product(res_indexes, quantizers):
            sub_response = response.iris_responses[res_index][window]
            bits = quantizer(sub_response) > 0
            mask = response.mask_responses[res_index][window] >= mask_threshold

            if mask_with_random:
                fill_masked_with_random(bits, mask)

            if static_mask is not None:
                selected = static_mask[window]
                fill_masked_with_zeros(bits, selected)
                mask &= selected  # bits outside the static mask count as masked (False)

            bit_parts.append(bits)
            mask_parts.append(mask)
            assert mask.shape == bits.shape

        return np.concatenate(bit_parts), np.concatenate(mask_parts)

    return encode

def encode_pairs(pairs, encode_fn):
    """Apply ``encode_fn`` to both responses of every ``(subject_id, a, b)`` triple; return a list."""
    encoded = []
    for subject_id, response_a, response_b in pairs:
        encoded.append((subject_id, encode_fn(response_a), encode_fn(response_b)))
    return encoded

In [None]:
# Distances
def masked_distance(x, x_mask, y, y_mask):
    """Fraction of differing bits between ``x`` and ``y`` among positions valid in both masks.

    NumPy yields NaN (with a RuntimeWarning) when the masks have no overlap.
    """
    shared = x_mask & y_mask
    n_valid = shared.sum()
    n_diff = ((x ^ y) & shared).sum()
    return n_diff / n_valid

def masked_rotate(x, rotation):
    """Circularly shift both the code ``x[0]`` and the mask ``x[1]`` by ``rotation`` columns."""
    bits, mask = x[0], x[1]
    return tuple(np.roll(part, rotation, axis=1) for part in (bits, mask))

def distance(x, y):
    """Masked Hamming distance between two ``(bits, mask)`` pairs."""
    x_bits, x_mask = x[0], x[1]
    y_bits, y_mask = y[0], y[1]
    return masked_distance(x_bits, x_mask, y_bits, y_mask)

def distance_raw(raw_x, raw_y):
    """Masked Hamming distance between two raw filter responses after encoding each with ``encode_high``."""
    encoded_x = encode_high(raw_x)
    encoded_y = encode_high(raw_y)
    return distance(encoded_x, encoded_y)

def rotate_raw(raw_x, rotation):
    """Return a new IrisFilterResponse with every iris and mask response rolled by ``rotation`` columns."""
    def roll_all(arrays):
        return [np.roll(arr, rotation, axis=1) for arr in arrays]

    return iris.IrisFilterResponse(
        iris_responses=roll_all(raw_x.iris_responses),
        mask_responses=roll_all(raw_x.mask_responses),
    )

In [None]:
# Rotations.
def without_rotation(pairs, distance_fn, rotate_fn, max_rotation):
    """Yield ``(subject_id, x, y_aligned)``, with ``y`` rotated to best match ``x``.

    Every rotation in ``[-max_rotation, max_rotation]`` is tried. On ties the most
    negative rotation wins (first minimum, as ``np.argmin`` picks).
    """
    candidate_rotations = list(range(-max_rotation, max_rotation + 1))
    for subject_id, x, y in pairs:
        scores = [distance_fn(x, rotate_fn(y, r)) for r in candidate_rotations]
        best = candidate_rotations[int(np.argmin(scores))]
        yield (subject_id, x, rotate_fn(y, best))

def remove_rotation(pairs, distance_fn=distance, rotate_fn=masked_rotate, max_rotation=15):
    """Return the pairs from ``without_rotation`` as a list: each second element is aligned to the first."""
    aligned = without_rotation(pairs, distance_fn, rotate_fn, max_rotation)
    return [*aligned]

In [None]:
def plot_boolean_iris(matrix, title=''):
    """Display ``matrix`` as a grayscale image on the current pyplot axes with ``title``, then show it."""
    plt.imshow(matrix, cmap='gray')
    plt.title(title)
    plt.show()

### Loading

In [None]:
# ~6 min: load every same-eye pair, align rotations on the raw responses, then encode.
encode_high = make_encoder()
related_pairs = load_related_pairs(meta, count=None, unique_subjects=False)
# Pass the notebook-wide MAX_ROT rather than relying on remove_rotation's hard-coded default,
# so alignment and the later rotation analyses use one setting.
related_pairs_norot = remove_rotation(
    related_pairs, distance_fn=distance_raw, rotate_fn=rotate_raw, max_rotation=MAX_ROT
)
related_pairs_high = encode_pairs(related_pairs_norot, encode_high)
assert related_pairs_high, "no pairs could be loaded; check S3 access and metadata"
shape_high = related_pairs_high[0][1][0].shape
print(f"Finished loading {len(related_pairs_high)} pairs,", "High-res", shape_high, np.prod(shape_high), "bits")

In [None]:
# One record per image: each pair (subject_id, (bits, mask), (bits, mask)) yields two rows that
# share the pair's subject_id. Plain-Python flattening avoids depending on how
# np.array(..., dtype=object) infers a shape for these ragged nested tuples.
assert related_pairs_high, "no encoded pairs to flatten"
subject_ids = np.array([sid for sid, *_ in related_pairs_high for _ in (0, 1)], dtype=object)
flattened_result = [encoded for _, first, second in related_pairs_high for encoded in (first, second)]
iris_matrices, mask_matrices = zip(*flattened_result)

In [None]:
true_iris_df = pd.DataFrame({
    'subject_id': subject_ids,
    'iris_matrices': iris_matrices,
    'mask_matrices': mask_matrices
})
# subject_id was built as f"{subject}_side{side}" (see iter_matching_image_ids): split on that
# suffix from the right, so subject ids that themselves contain "_" are kept intact.
true_iris_df['side'] = true_iris_df['subject_id'].str[-1]
true_iris_df['subject_id'] = true_iris_df['subject_id'].str.rsplit('_side', n=1).str[0]

In [None]:
# Drop exact duplicate images per subject. Arrays are unhashable, so compare their raw bytes.
dedup_keys = {
    'iris_matrices_bytes': true_iris_df['iris_matrices'].map(lambda m: m.tobytes()),
    'mask_matrices_bytes': true_iris_df['mask_matrices'].map(lambda m: m.tobytes()),
}
true_iris_df = (
    true_iris_df
    .assign(**dedup_keys)
    .drop_duplicates(subset=['subject_id', *dedup_keys])
    .drop(columns=list(dedup_keys))
    .reset_index(drop=True)
)
print(f'Final iris DataFrame contains {len(true_iris_df)} unique samples')

# Noise Analysis

In [None]:
def stack_rotated_matrices(matrices, max_rotation, axis=1):
    """Stack every circular shift of every matrix, flattened, as the rows of a 2-D array.

    Row ``k * (2 * max_rotation + 1) + r`` holds matrix ``k`` rolled by
    ``r - max_rotation`` along ``axis``. ``axis`` defaults to 1 (columns), the axis
    used by masked_rotate, rotate_raw and the rotation test. The former ``axis=0``
    shifted rows between the stacked code parts instead.
    """
    return np.vstack([
        np.roll(matrix, shift, axis=axis).flatten()
        for matrix, shift in product(matrices, range(-max_rotation, max_rotation + 1))
    ])

def get_pairwise_min_dist_across_rotations(iris_matrices, mask_matrices, max_rotation, lim_group_size=50):
    """Best-rotation masked Hamming distance for every pair of codes in one group.

    Args:
        iris_matrices: pandas Series of 2-D boolean code matrices.
        mask_matrices: pandas Series of matching masks (True = valid bit), same index.
        max_rotation: largest relative column shift tried, in each direction.
        lim_group_size: larger groups are randomly subsampled to this many codes.

    Returns:
        1-D float array with one distance per unordered pair, in lower-triangle order
        ((1,0), (2,0), (2,1), ...). NaN for a pair whose masks never overlap.

    Only the second code of a pair is rotated. Rotating both by up to ``max_rotation``
    would search relative shifts up to ``2 * max_rotation``. Pairs are processed one at a
    time instead of materialising an (n*R) x (n*R) x bits boolean tensor.
    """
    if len(iris_matrices) > lim_group_size:
        iris_matrices = iris_matrices.sample(lim_group_size)
        mask_matrices = mask_matrices[iris_matrices.index]

    num_matrices = len(iris_matrices)
    if num_matrices < 2:
        return np.empty(0)

    shifts = range(-max_rotation, max_rotation + 1)
    codes = np.stack(list(iris_matrices))    # (n, rows, cols)
    masks = np.stack(list(mask_matrices))
    flat_codes = codes.reshape(num_matrices, -1)
    flat_masks = masks.reshape(num_matrices, -1)
    # (n, num_rotations, bits): every column shift of every code / mask.
    rotated_codes = np.stack([np.roll(codes, s, axis=-1) for s in shifts], axis=1).reshape(num_matrices, len(shifts), -1)
    rotated_masks = np.stack([np.roll(masks, s, axis=-1) for s in shifts], axis=1).reshape(num_matrices, len(shifts), -1)

    rows, cols = np.tril_indices(num_matrices, k=-1)
    min_distances = np.empty(len(rows))
    for k, (i, j) in enumerate(zip(rows, cols)):
        valid = flat_masks[i] & rotated_masks[j]                  # (num_rotations, bits)
        differing = (flat_codes[i] != rotated_codes[j]) & valid
        valid_counts = valid.sum(axis=-1)
        with np.errstate(invalid='ignore', divide='ignore'):
            per_rotation = differing.sum(axis=-1) / valid_counts
        per_rotation = per_rotation[valid_counts > 0]             # ignore shifts with no overlap
        min_distances[k] = per_rotation.min() if per_rotation.size else np.nan
    return min_distances

In [None]:
# Within-eye distances: every pair of images of the same subject and side, at its best rotation.
results = [
    get_pairwise_min_dist_across_rotations(group['iris_matrices'], group['mask_matrices'], max_rotation=MAX_ROT)
    for _, group in true_iris_df.groupby(['subject_id', 'side'])
]
nearest_pairwise_dist_w_rotations = np.concatenate(results)

In [None]:
# Distribution of same-eye (same subject, same side) distances: the capture-to-capture noise.
plt.figure(figsize=(13, 6))
sns.histplot(nearest_pairwise_dist_w_rotations, bins=100, stat='probability', color='#BD2A2E')
plt.title('Same Samples Distance Distribution (Noise)', y=1.08, fontsize=15)
plt.grid()
plt.show()

In [None]:
def calculate_distribution(data, num_bins=100):
    """Histogram ``data`` into ``num_bins`` equal-width bins.

    Returns ``(bin midpoints, bin probabilities)``; the probabilities sum to 1.
    """
    density, edges = np.histogram(data, bins=num_bins, density=True)
    centres = 0.5 * (edges[1:] + edges[:-1])
    return centres, density / density.sum()
    
def sample_from_distribution(midpoints, probabilities, sample_size=10):
    """Draw ``sample_size`` midpoints i.i.d. with the given probabilities, using the global ``np.random`` state."""
    draws = np.random.choice(midpoints, size=sample_size, p=probabilities)
    return draws

In [None]:
# Discretised noise distribution (100 bins), reused when sampling synthetic noise.
midpoints, probabilities = calculate_distribution(nearest_pairwise_dist_w_rotations, num_bins=100)

In [None]:
sample_from_distribution(midpoints, probabilities, sample_size=10)

In [None]:
# np.savez_compressed('noise_distribution_bin_midpoints.npz', data=midpoints)
# np.savez_compressed('noise_distribution_probability_distribution.npz', data=probabilities)

# Short Mask Analysis

In [None]:
def process_masks(mask_matrix):
    """Summarise the masked-out (False) entries of one iris mask.

    The mask is reshaped into 4 bands of 16 rows, each split into a first ("left")
    and second ("right") half of 100 columns. It therefore must hold
    4 * 16 * 2 * 100 = 12800 entries (presumably a (64, 200) boolean mask; TODO
    confirm for all callers). ``~`` is used as logical NOT, which assumes a bool dtype.

    Returns:
        A 6-tuple, each value averaged over the 4 bands:
        (left, right) number of False entries in each band's last row;
        (left, right) midpoint between the first and last False column of that row
        (bands with none are ignored; NaN if no band has any);
        (left, right) largest per-column count of False entries within a band.
    """
    # Adjust matrix to be built by 8 different masks
    # Axes after the transpose: (half, band, row within band, column within half).
    separated_matrix = mask_matrix.reshape(4, 16, 2, 100).transpose(2, 0, 1, 3)

    # Calculate bottom row stats
    # True where the band's last row is masked; shape (half, band, column).
    inverted_last_rows = ~separated_matrix[:, :, -1, :]
    mean_true_counts = inverted_last_rows.sum(axis=2).mean(axis=1) # Mean length
    # First / last masked column per (half, band); NaN when that row has no masked entry.
    first_true_indices = np.where(
        inverted_last_rows.any(axis=-1), np.argmax(inverted_last_rows, axis=-1), np.nan
    )
    last_true_indices = np.where(
        inverted_last_rows.any(axis=-1), inverted_last_rows.shape[-1] - 1 - np.argmax(inverted_last_rows[:, :, ::-1], axis=-1), np.nan
    )
    # nanmean skips bands without masked entries (warns and gives NaN if all are NaN).
    mean_middle_indices = np.nanmean((first_true_indices + last_true_indices) / 2, axis=1) # Mean middle index

    # Calculate longest column stats
    # Masked entries per column, counted over the band's 16 rows; shape (half, band, column).
    true_counts = (~separated_matrix).sum(axis=2)
    max_true_counts = np.max(true_counts, axis=2)
    mean_max_true_counts = max_true_counts.mean(axis=1)
    return (*mean_true_counts, *mean_middle_indices, *mean_max_true_counts)

In [None]:
# Per-image mask statistics (see process_masks), in the order process_masks returns them.
processed_mask_col_names = ['left_h_len', 'right_h_len', 'left_mid_ind', 'right_mid_ind', 'left_v_len', 'right_v_len']
mask_stats = true_iris_df['mask_matrices'].apply(process_masks)
processed_mask_df = pd.DataFrame(mask_stats.tolist(), columns=processed_mask_col_names, index=true_iris_df.index)
true_iris_df[processed_mask_col_names] = processed_mask_df

In [None]:
# Long form of the mask statistics: one row per (image, metric, mask half).
# "side" here is the left / right half of the mask, not the biological eye side.
# Built from true_iris_df rather than from processed_mask_df itself, so re-running
# this cell does not try to reshape an already-long frame.
processed_mask_df = pd.wide_to_long(
    true_iris_df[processed_mask_col_names].reset_index(),
    stubnames=['left', 'right'],
    i='index',
    j='metric',
    suffix='(h_len|mid_ind|v_len)',
    sep='_'
)
processed_mask_df = (
    pd.melt(processed_mask_df.reset_index(), id_vars=processed_mask_df.index.names, var_name='side')
    .drop(columns='index')
)

In [None]:
# One histogram per (mask half, metric).
facetgrid = sns.FacetGrid(processed_mask_df, row='side', col='metric', sharex=False, sharey=False, height=4, aspect=1.4)
facetgrid.map_dataframe(sns.histplot, x='value', stat='probability', color='#019587', kde=True)
for ax in facetgrid.axes.flat:
    ax.grid(True)
facetgrid.fig.suptitle("Iris masks derived distributions", fontsize=20, y=1.03)
plt.show()

In [None]:
# Share of length metrics (h_len / v_len) equal to zero, i.e. no visible mask, per mask half.
mask_filtered = processed_mask_df.query("metric != 'mid_ind'")
zero_mask_perc = (mask_filtered['value'] == 0).groupby(mask_filtered['side']).mean()
zero_mask_perc.rename('Percentage of no apparent mask').to_frame()

In [None]:
# Mean / std of each mask metric over images where it is non-zero. A list (not a set)
# keeps the output column order deterministic.
processed_mask_df[processed_mask_df['value'] > 0].groupby(['side', 'metric']).agg(['mean', 'std'])

# Synthetic Data Quality Tests

## Constants

In [None]:
alpha = 5e-2  # significance level for the KS and t tests below

## Data Merging

In [None]:
# Real and synthetic codes in one frame, distinguished by the "source" column.
all_iris_df = pd.concat((true_iris_df.assign(source='real'), synthetic_df), ignore_index=True)

## Rotation Test

In [None]:
# Rotation test: for every non-zero shift, the fraction of bits that change when each half of a
# code (the rows split in two, named "low" and "high") is rolled by `rot` columns; the sources
# are then compared pairwise with KS and t tests.
# `results` must start empty: it still holds the noise-analysis arrays from an earlier cell.
results = []
for rot in range(-MAX_ROT, MAX_ROT + 1):
    if rot == 0:  # No rotation - distance is 0
        continue

    all_iris_df[f'low_{rot}'], all_iris_df[f'high_{rot}'] = zip(*all_iris_df['iris_matrices'].apply(
        lambda matrix: [np.sum(part != np.roll(part, shift=rot, axis=1)) / part.size for part in np.split(matrix, 2, axis=0)]
    ))

    # Statistical tests for every pair of sources, separately per half.
    for (source1, group1), (source2, group2) in combinations(all_iris_df.groupby('source'), 2):
        for wavelength in ['low', 'high']:
            group1_wavelength = group1[f'{wavelength}_{rot}']
            group2_wavelength = group2[f'{wavelength}_{rot}']

            ks_stat, ks_p_value = ks_2samp(group1_wavelength, group2_wavelength)  # Kolmogorov-Smirnov test
            t_stat, t_p_value = ttest_ind(group1_wavelength, group2_wavelength)  # Student's t test

            results.append({
                'Rotation': rot,
                'wavelength': wavelength,
                'first_source': source1,
                'second_source': source2,
                'ks_stat': ks_stat,
                't_stat': t_stat,
                'ks_p_value': ks_p_value,
                't_p_value': t_p_value,
            })

results_df = pd.DataFrame(results)
# "Passed" = the test cannot reject equal distributions at level alpha.
results_df['passed_KS_test'] = results_df['ks_p_value'] >= alpha
results_df['passed_t_test'] = results_df['t_p_value'] >= alpha

In [None]:
# Average p-values and pass rates over rotations, keep only the comparisons that involve the
# real data, and label each row with the synthetic method it was compared against.
summary_cols = ['ks_p_value', 'passed_KS_test', 't_p_value', 'passed_t_test']
results_df = (
    results_df
    .groupby(['first_source', 'second_source', 'wavelength'])[summary_cols]
    .mean()
    .reset_index()
)
involves_real = (results_df['first_source'] == 'real') | (results_df['second_source'] == 'real')
results_df = results_df.loc[involves_real]
results_df = (
    results_df
    .assign(compared_method=np.where(
        results_df['first_source'] == 'real', results_df['second_source'], results_df['first_source']
    ))
    .drop(columns=['first_source', 'second_source'])
    .reset_index(drop=True)
)

In [None]:
# One row per (synthetic method, half): mean p-values and pass rates against the real codes.
results_df.groupby(by=['compared_method', 'wavelength']).first()

In [None]:
# Long form of the rotation-test distances: one row per (code, half, rotation).
# Select exactly the `low_<rot>` / `high_<rot>` columns. A substring test would also match any
# other column containing "low_" / "high_", which would then break the 2-way split below.
rotation_cols = all_iris_df.columns[
    all_iris_df.columns.astype(str).str.fullmatch(r'(?:low|high)_-?\d+')
].tolist()
plot_df = pd.melt(
    all_iris_df,
    id_vars='source',
    value_vars=rotation_cols,
    var_name='wavelength_rotation',
    value_name='Hamming Distance',
)
plot_df[['Wavelength', 'Rotation']] = plot_df['wavelength_rotation'].str.split('_', expand=True)
plot_df = plot_df.drop(columns=['wavelength_rotation'])
plot_df['Rotation'] = plot_df['Rotation'].astype(int)
plot_df['Source'] = plot_df['source'] + ', ' + plot_df['Wavelength'] + ' wavelength'

In [None]:
# Mean and std of the rotation distance per (rotation, half, source), then in long form.
# A list (not a set) keeps the aggregated column order deterministic.
dist_linear_analysis = (
    plot_df
    .groupby(['Rotation', 'Wavelength', 'source'])['Hamming Distance']
    .agg(['mean', 'std'])
    .reset_index()
    .rename(columns={'source': 'Source'})
)
dist_linear_analysis = pd.melt(
    dist_linear_analysis,
    id_vars=['Rotation', 'Wavelength', 'Source'],
    value_vars=['std', 'mean'],
    var_name='Metric'
)

In [None]:
for source in ['gaussian', 'voter', 'voter_bryan', 'voter_gaussian']:
    # Mean / std of the rotation distance versus rotation: real codes against one synthetic source.
    subset = dist_linear_analysis[dist_linear_analysis['Source'].isin(['real', source])]
    facetgrid = sns.FacetGrid(subset, col='Metric', height=5, aspect=2, sharex=False, sharey=False)
    facetgrid.map_dataframe(sns.lineplot, x='Rotation', y='value', hue='Wavelength', palette='husl', style='Source')
    for ax in facetgrid.axes.flat:
        ax.grid(True)
        ax.legend()
    pretty_source = source.replace('_', ' ')
    facetgrid.fig.suptitle(f"Mean and Std of real and {pretty_source} iris samples, in relation to rotation", fontsize=15, y=1.05)
    plt.show()

In [None]:
for source in ['gaussian', 'voter', 'voter_gaussian', 'voter_bryan']:
    # One histogram panel per rotation: real codes versus one synthetic source.
    subset = plot_df[plot_df['source'].isin(['real', source])]
    facetgrid = sns.FacetGrid(subset, col='Rotation', hue='Source', palette='husl', col_wrap=5, sharex=False)
    facetgrid.map_dataframe(sns.histplot, x='Hamming Distance', stat='probability')
    for ax in facetgrid.axes.flat:
        ax.grid(True)
    facetgrid.add_legend()
    pretty_source = source.replace('_', ' ')
    facetgrid.fig.suptitle(f"Distance distribution upon rotation\nReal irises to {pretty_source} distributions", fontsize=20, y=1.03)
    plt.show()

In [None]:
# One real code per (subject, side). .copy() makes an independent frame, so the column
# assignments below do not raise SettingWithCopyWarning on a slice of all_iris_df.
real_iris_samples_df = (
    all_iris_df[all_iris_df['source'] == 'real']
    .drop_duplicates(subset=['subject_id', 'side'])
    .copy()
)
# Per-column fraction of True bits in each half (rows split in two) of every code.
real_iris_samples_df['mean_bw_ratio_low'], real_iris_samples_df['mean_bw_ratio_high'] = zip(*real_iris_samples_df['iris_matrices'].apply(
    lambda matrix: [part.mean(axis=0) for part in np.split(matrix, 2, axis=0)]
))

In [None]:
for wavelength in ('low', 'high'):
    # Rows = column position along the iris, columns = samples.
    bw_ratio_matrix = np.column_stack(real_iris_samples_df[f'mean_bw_ratio_{wavelength}'].values)
    x_positions = range(bw_ratio_matrix.shape[0])
    n_samples = bw_ratio_matrix.shape[1]
    mean_values = bw_ratio_matrix.mean(axis=1)
    ci = 1.96 * bw_ratio_matrix.std(axis=1) / np.sqrt(n_samples)  # 95% CI (normal approximation)

    plt.figure(figsize=(10, 6))
    sns.lineplot(x=x_positions, y=mean_values, errorbar=None, label='Mean True / False Ratio')
    plt.fill_between(x_positions, mean_values - ci, mean_values + ci, color='b', alpha=0.3, label='95% CI')
    plt.title(f"Mean True / False Ratio, {wavelength} wavelength with 95% Confidence Interval")
    plt.xlabel("Iris matrix x-axis")
    plt.ylabel("True / False Ratio")
    plt.legend()
    plt.show()

## Boolean Ratio Test

In [None]:
# Fraction of True bits per code (masks ignored), summarised per source.
all_iris_df['Boolean Ratio'] = all_iris_df['iris_matrices'].apply(lambda matrix: matrix.mean())
# A list (not a set) keeps the aggregated column order deterministic.
stats_df = all_iris_df.groupby('source')['Boolean Ratio'].agg(['mean', 'std'])

In [None]:
# One histogram per source, titled with that source's mean / std of the True-bit fraction.
facetgrid = sns.FacetGrid(all_iris_df, col='source', sharex=False, sharey=False)
facetgrid.map_dataframe(sns.histplot, x='Boolean Ratio', stat='probability', kde=True, color='#FF7A48')
title_template = "Source: {col_name}\nMean: {mean:.2f}, Std: {std:.2f}"
facetgrid.set_titles(col_template="{col_name}")
for ax, col_value in zip(facetgrid.axes.flat, facetgrid.col_names):
    source_stats = stats_df.loc[col_value]
    ax.set_title(title_template.format(col_name=col_value, mean=source_stats['mean'], std=source_stats['std']))
    ax.grid(True)
facetgrid.fig.suptitle("True / False Ratio Validation", fontsize=20, y=1.2)
plt.show()

## Nearest to Random Dist Test

In [None]:
def stack_rotated_matrices(matrices, max_rotation, axis=1):
    """Stack every circular shift of every matrix, flattened, as the rows of a 2-D array.

    Row ``k * (2 * max_rotation + 1) + r`` holds matrix ``k`` rolled by
    ``r - max_rotation`` along ``axis``. ``axis`` defaults to 1 (columns), the
    rotation axis used elsewhere in the notebook. The former ``axis=0`` shifted rows.
    """
    return np.vstack([
        np.roll(matrix, shift, axis=axis).flatten()
        for matrix in matrices
        for shift in range(-max_rotation, max_rotation + 1)
    ])

def get_min_and_random_dist_across_rotations(iris_matrices, mask_matrices, max_rotation):
    """For each code: masked distance to its nearest other code and to a random other code.

    Args:
        iris_matrices: pandas Series of 2-D boolean code matrices.
        mask_matrices: pandas Series of matching masks (True = valid bit), same order.
        max_rotation: largest relative column shift tried, in each direction.

    Returns:
        DataFrame indexed like ``iris_matrices`` with columns ``closest_dist`` (minimum over
        all other codes and relative shifts) and ``random_dist`` (one of those distances
        drawn with the global ``np.random`` state). Shifts with no mask overlap are ignored.
        Both columns are NaN when fewer than two codes are given.

    Only the other codes are rotated: rotating both sides would search relative shifts up to
    ``2 * max_rotation``. Peak memory per code is about n x (2*max_rotation+1) x bits booleans.
    """
    num_matrices = len(iris_matrices)
    if num_matrices < 2:
        empty = np.full(num_matrices, np.nan)
        return pd.DataFrame({"closest_dist": empty, "random_dist": empty.copy()}, index=iris_matrices.index)

    shifts = range(-max_rotation, max_rotation + 1)
    codes = np.stack(list(iris_matrices))    # (n, rows, cols)
    masks = np.stack(list(mask_matrices))
    flat_codes = codes.reshape(num_matrices, -1)
    flat_masks = masks.reshape(num_matrices, -1)
    # (n, num_rotations, bits): every column shift of every code / mask.
    rotated_codes = np.stack([np.roll(codes, s, axis=-1) for s in shifts], axis=1).reshape(num_matrices, len(shifts), -1)
    rotated_masks = np.stack([np.roll(masks, s, axis=-1) for s in shifts], axis=1).reshape(num_matrices, len(shifts), -1)

    closest_distances, random_distances = [], []
    for i in range(num_matrices):
        valid = flat_masks[i] & rotated_masks                     # (n, num_rotations, bits)
        differing = (flat_codes[i] != rotated_codes) & valid
        with np.errstate(invalid='ignore', divide='ignore'):
            hamming = differing.sum(axis=-1) / valid.sum(axis=-1)
        hamming = np.delete(hamming, i, axis=0).ravel()          # drop comparisons with itself
        hamming = hamming[np.isfinite(hamming)]                   # drop shifts without overlap
        if hamming.size:
            closest_distances.append(np.min(hamming))
            random_distances.append(np.random.choice(hamming))
        else:
            closest_distances.append(np.nan)
            random_distances.append(np.nan)

    return pd.DataFrame({"closest_dist": closest_distances, "random_dist": random_distances}, index=iris_matrices.index)

In [None]:
# One real code per (subject, side), plus an equally sized random subset of each synthetic source.
# The subset size is capped at the source's size, so .sample() cannot fail when there are more
# real eyes than synthetic codes. The explicit loop keeps every column (including "source")
# regardless of pandas' groupby.apply handling of grouping columns.
real_iris_samples_df = all_iris_df[all_iris_df['source'] == 'real'].drop_duplicates(subset=['subject_id', 'side'])
n_real = len(real_iris_samples_df)
balanced_non_real_samples = pd.concat([
    group.sample(min(n_real, len(group)))
    for _, group in all_iris_df[all_iris_df['source'] != 'real'].groupby('source')
])
sub_iris_df = pd.concat([real_iris_samples_df, balanced_non_real_samples], ignore_index=True)
sub_iris_df.groupby('source').size()

In [None]:
# Nearest / random distances within each source, tagged with the source name.
results = pd.concat([
    get_min_and_random_dist_across_rotations(
        group['iris_matrices'], group['mask_matrices'], max_rotation=MAX_ROT
    ).assign(source=source)
    for source, group in sub_iris_df.groupby('source')
])
results.head()

In [None]:
# Long form for plotting: 'random_dist' -> 'Random', 'closest_dist' -> 'Closest'.
plot_df = (
    results
    .rename(columns={'source': 'Source'})
    .melt(
        id_vars='Source',
        value_vars=['random_dist', 'closest_dist'],
        var_name='Distance From',
        value_name='Hamming Distance',
    )
)
plot_df['Distance From'] = plot_df['Distance From'].str.split('_').str[0].str.capitalize()

In [None]:
# Random-other vs nearest-other distance distributions, coloured by source.
facetgrid = sns.FacetGrid(plot_df, col='Distance From', hue='Source', palette='husl', height=4, aspect=2, sharex=False, sharey=False)
facetgrid.map_dataframe(sns.histplot, x='Hamming Distance', stat='probability', kde=True)
for ax in facetgrid.axes.flat:
    ax.grid(True)
facetgrid.add_legend()
facetgrid.fig.suptitle("Distance from Random / Nearest iris, by data source", fontsize=20, y=1.2)
plt.show()

## Comparing pair-wise distance distributions (to Daugman survey)

In [None]:
# All synthetic codes plus one real code per (subject, side).
is_real = all_iris_df['source'] == 'real'
non_real_iris_samples = all_iris_df[~is_real]
real_iris_samples_df = all_iris_df[is_real].drop_duplicates(subset=['subject_id', 'side'])
sub_iris_df = pd.concat([non_real_iris_samples, real_iris_samples_df], ignore_index=True)
sub_iris_df.groupby('source').size()

In [None]:
def get_pairwise_dist_vector(iris_matrices):
    """Return the fractional Hamming distance of every unordered pair of codes (masks ignored).

    Each pair is counted once (upper triangle, i < j, row-major order). Identical codes
    give 0 and are kept; the former ``> 0`` filter also silently dropped such pairs.
    Distances come from two matrix products, so memory is O(n^2 + n * bits) rather
    than the O(n^2 * bits) of a broadcast comparison.

    Args:
        iris_matrices: pandas Series of equally shaped 0/1 (boolean) code matrices.

    Returns:
        1-D float64 array of length n * (n - 1) / 2 (empty for fewer than two codes).
    """
    n_codes = len(iris_matrices)
    if n_codes < 2:
        return np.empty(0)
    codes = np.stack(iris_matrices.values).reshape(n_codes, -1).astype(np.float32)
    n_bits = codes.shape[1]
    # For 0/1 vectors a, b: #(a != b) = a.(1 - b) + (1 - a).b. The counts are integers
    # <= n_bits, so float32 holds them exactly.
    disagreements = codes @ (1.0 - codes).T + (1.0 - codes) @ codes.T
    rows, cols = np.triu_indices(n_codes, k=1)
    return disagreements[rows, cols].astype(np.float64) / n_bits

In [None]:
# Pairwise distances within each source, flattened to one row per pair.
distances_series = sub_iris_df.groupby('source').apply(lambda group: get_pairwise_dist_vector(group['iris_matrices']))
sources = np.repeat(distances_series.index.values, distances_series.map(len).to_numpy())
distances = np.concatenate(distances_series.values)
plot_df = pd.DataFrame({'Source': sources, 'Distances': distances})
# Same number of pairs per source (the real count), capped at each source's own size so
# .sample() cannot fail for a source with fewer pairs than the real data.
n_real_pairs = int((plot_df['Source'] == 'real').sum())
plot_df = pd.concat(
    [group.sample(min(n_real_pairs, len(group))) for _, group in plot_df.groupby('Source')],
    ignore_index=True,
)

In [None]:
# Overlaid pairwise-distance distributions of all sources.
ax = sns.histplot(data=plot_df, x='Distances', hue='Source', stat='probability', palette='husl', kde=True)
ax.grid(True)
ax.set_title('Pairwise Distance Distribution', fontsize=15, y=1.07)
plt.show()

In [None]:
# Per-source mean / std of the pairwise distances. A list (not a set) keeps column order deterministic.
stats_df = plot_df.groupby('Source')['Distances'].agg(['mean', 'std'])
# Effective number of independent bits N = p(1 - p) / sigma^2 (binomial fit), the quantity quoted
# from the Daugman survey in the plot below (0.499 * 0.501 / 0.0317^2 ~ 249).
stats_df['N'] = (stats_df['mean'] * (1 - stats_df['mean'])) / stats_df['std']**2

In [None]:
# One panel per source, titled with its mean / std / effective N, next to Daugman's reference values.
daugman_title = "Pairwise Distance Distribution\nStats from Daugman survey - Mean: 0.499, Std: 0.0317, N=249"
facetgrid = sns.FacetGrid(plot_df, col='Source', sharex=False, sharey=False)
facetgrid.map_dataframe(sns.histplot, x='Distances', stat='probability', kde=True, color='#FF7A48')
title_template = "Source: {col_name}\nMean: {mean:.2f}, Std: {std:.2f}, N: {N:.0f}"
facetgrid.set_titles(col_template="{col_name}")
for ax, col_value in zip(facetgrid.axes.flat, facetgrid.col_names):
    source_stats = stats_df.loc[col_value]
    ax.set_title(title_template.format(
        col_name=col_value, mean=source_stats['mean'], std=source_stats['std'], N=source_stats['N']
    ))
    ax.grid(True)
facetgrid.fig.suptitle(daugman_title, fontsize=20, y=1.4)
plt.show()