In [167]:
from sqlalchemy import create_engine
import pandas as pd
from umap.distances import matching

from data_generator.main import DiscriminationData
from path import HERE

DB_PATH = HERE.joinpath('experiments/discrimination_detection_results5.db')
conn = create_engine(f'sqlite:///{DB_PATH}')

In [168]:
import json

random_experiment_id = '67705f0b-3035-41d6-a111-8289fc8b22f1'
df_synth1 = pd.read_sql_query(
    f"SELECT experiment_id, full_data, attr_possible_values FROM synthetic_data where experiment_id='{random_experiment_id}'",
    conn)
df_synth1 = df_synth1.copy()
df_synth = pd.DataFrame(json.loads(df_synth1['full_data'].iloc[0]))
df_synth['experiment_id'] = random_experiment_id
df_synth['attr_possible_values'] = df_synth1['attr_possible_values'].iloc[0]

In [169]:
df_synth['group_key'].unique()

array(['0|*|2-1|*|2', '1|0|*-0|0|*'], dtype=object)

In [170]:
df_result = pd.read_sql_query(
    f"SELECT * FROM augmented_results ar left join main.analysis_metadata am on am.analysis_id=ar.analysis_id where experiment_id='{random_experiment_id}'",
    conn)
df_result_data = pd.DataFrame(list(df_result['data'].apply(lambda x: json.loads(x))))
df_result = pd.concat([df_result.reset_index(drop=True), df_result_data.reset_index(drop=True)], axis=1)
df_result.head()

Unnamed: 0,analysis_id,indv_key,couple_key,is_original_data,is_couple_part_of_a_group,matching_groups,data,analysis_id.1,experiment_id,method_name,created_at,Attr1_T,Attr1_X,Attr2_X,outcome
0,dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,2|2|2,2|2|2-1|2|2,1,0,[],"{""Attr1_T"": 2, ""Attr1_X"": 2, ""Attr2_X"": 2, ""ou...",dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,67705f0b-3035-41d6-a111-8289fc8b22f1,aequitas,2024-11-28 20:25:51.000623,2,2,2,2
1,dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,1|2|2,1|2|2-2|2|2,1,0,[],"{""Attr1_T"": 1, ""Attr1_X"": 2, ""Attr2_X"": 2, ""ou...",dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,67705f0b-3035-41d6-a111-8289fc8b22f1,aequitas,2024-11-28 20:25:51.000623,1,2,2,1
2,dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,1|1|2,1|1|2-0|1|2,1,0,[],"{""Attr1_T"": 1, ""Attr1_X"": 1, ""Attr2_X"": 2, ""ou...",dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,67705f0b-3035-41d6-a111-8289fc8b22f1,aequitas,2024-11-28 20:25:51.000623,1,1,2,3
3,dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,0|1|2,0|1|2-1|1|2,1,1,"[""0|*|2-1|*|2""]","{""Attr1_T"": 0, ""Attr1_X"": 1, ""Attr2_X"": 2, ""ou...",dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,67705f0b-3035-41d6-a111-8289fc8b22f1,aequitas,2024-11-28 20:25:51.000623,0,1,2,2
4,dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,1|1|3,1|1|3-0|1|3,0,0,[],"{""Attr1_T"": 1, ""Attr1_X"": 1, ""Attr2_X"": 3, ""ou...",dd5bc81f-c6b3-48f9-8c9d-3b2c4b8b3cf9,67705f0b-3035-41d6-a111-8289fc8b22f1,aequitas,2024-11-28 20:25:51.000623,1,1,3,3


In [171]:
feature_cols = list(filter(lambda x: 'Attr' in x, df_synth.columns))
feature_cols

['Attr1_X', 'Attr2_X', 'Attr1_T']

In [172]:
from itertools import combinations
import pandas as pd
from typing import List, Dict, Any


def sort_two_strings(str1: str, str2: str) -> tuple[str, str]:
    if str1 <= str2:
        return (str1, str2)
    return (str2, str1)


def prepare_result_combinations(df: pd.DataFrame, feature_cols: List[str]) -> pd.DataFrame:
    all_combinations = []

    for couple_key in df['couple_key'].unique():
        # Sort the individual keys from the couple_key
        indivs = couple_key.split('-')
        sorted_indivs = sort_two_strings(indivs[0], indivs[1])

        # Create a new sorted couple_key
        sorted_couple_key = f"{sorted_indivs[0]}-{sorted_indivs[1]}"

        # Fetch data for each individual separately in sorted order
        indiv1_data = df[df['indv_key'] == sorted_indivs[0]]
        indiv2_data = df[df['indv_key'] == sorted_indivs[1]]

        # Combine the data in sorted order
        couple_data = pd.concat([indiv1_data, indiv2_data])

        if couple_data.shape[0] < 2:
            continue

        is_part_of_group = couple_data['is_couple_part_of_a_group'].iloc[0] != '0'
        matching_groups = couple_data['matching_groups'].iloc[0]
        unique_individuals = couple_data[feature_cols].drop_duplicates().values
        pairs = list(combinations(range(len(unique_individuals)), 2))

        for i, j in pairs:
            combination = {
                'couple_key': sorted_couple_key,  # Use the sorted couple_key
                'is_part_of_group': is_part_of_group,
                'matching_groups': matching_groups,
            }

            # Store individual feature values
            indiv1_features = unique_individuals[i]
            indiv2_features = unique_individuals[j]

            # Add the features in sorted order
            for idx, feat in enumerate(feature_cols):
                combination[f'{feat}_1'] = indiv1_features[idx]
                combination[f'{feat}_2'] = indiv2_features[idx]

            all_combinations.append(combination)
    res = pd.DataFrame(all_combinations)
    res['group_key'] = res['matching_groups'].apply(json.loads)
    res = res.explode('group_key').reset_index(drop=True)
    return res

In [173]:
synth_combinations_df = DiscriminationData.generate_individual_synth_combinations(df_synth)
synth_combinations_df

Unnamed: 0,group_key,subgroup1_key,subgroup2_key,indv_key_1,indv_key_2,couple_key,Attr1_X_1,Attr2_X_1,Attr1_T_1,Attr1_X_2,Attr2_X_2,Attr1_T_2
0,0|*|2-1|*|2,0|*|2,1|*|2,0|1|2,1|0|2,0|1|2-1|0|2,0,1,2,1,0,2
11,0|*|2-1|*|2,0|*|2,1|*|2,0|1|2,1|1|2,0|1|2-1|1|2,0,1,2,1,1,2
17,0|*|2-1|*|2,0|*|2,1|*|2,0|1|2,1|2|2,0|1|2-1|2|2,0,1,2,1,2,2
108,0|*|2-1|*|2,0|*|2,1|*|2,0|1|2,1|3|2,0|1|2-1|3|2,0,1,2,1,3,2
117,0|*|2-1|*|2,0|*|2,1|*|2,0|2|2,1|0|2,0|2|2-1|0|2,0,2,2,1,0,2
...,...,...,...,...,...,...,...,...,...,...,...,...
35372,1|0|*-0|0|*,0|0|*,1|0|*,1|2|3,2|3|2,1|2|3-2|3|2,1,2,3,2,3,2
35376,1|0|*-0|0|*,0|0|*,1|0|*,1|3|0,2|3|2,1|3|0-2|3|2,1,3,0,2,3,2
35386,1|0|*-0|0|*,0|0|*,1|0|*,1|3|1,2|3|2,1|3|1-2|3|2,1,3,1,2,3,2
35392,1|0|*-0|0|*,0|0|*,1|0|*,1|3|2,2|3|2,1|3|2-2|3|2,1,3,2,2,3,2


In [174]:
# Create combinations
result_combinations_df = prepare_result_combinations(df_result, feature_cols)
result_combinations_df

Unnamed: 0,couple_key,is_part_of_group,matching_groups,Attr1_X_1,Attr1_X_2,Attr2_X_1,Attr2_X_2,Attr1_T_1,Attr1_T_2,group_key
0,1|2|2-2|2|2,False,[],2,2,2,2,1,2,
1,1|2|2-2|2|2,False,[],2,2,2,2,1,2,
2,0|1|2-1|1|2,True,"[""0|*|2-1|*|2""]",1,1,2,2,0,1,0|*|2-1|*|2
3,0|1|2-1|1|2,True,"[""0|*|2-1|*|2""]",1,1,2,2,0,1,0|*|2-1|*|2
4,0|1|3-1|1|3,False,[],1,1,3,3,0,1,
5,0|1|3-1|1|3,False,[],1,1,3,3,0,1,
6,0|0|1-2|0|1,False,[],0,0,1,1,0,2,
7,0|0|1-2|0|1,False,[],0,0,1,1,0,2,
8,0|0|0-1|0|0,False,[],0,0,0,0,0,1,
9,0|0|0-1|0|0,False,[],0,0,0,0,0,1,


In [175]:
synth_combinations_df['group_key'].unique()

array(['0|*|2-1|*|2', '1|0|*-0|0|*'], dtype=object)

In [176]:
import pandas as pd
import itertools
import random


def clean_pattern_value(value):
    if not value or value == '*':
        return '*'
    try:
        return int(value)
    except ValueError:
        return '*'


def parse_group_key(group_key):
    if not isinstance(group_key, str):
        raise ValueError("Group key must be a string")

    patterns = group_key.split('-')
    if len(patterns) != 2:
        raise ValueError("Group key must contain exactly one hyphen separator")

    left_pattern = [clean_pattern_value(x) for x in patterns[0].split('|')]
    right_pattern = [clean_pattern_value(x) for x in patterns[1].split('|')]

    return left_pattern, right_pattern


def calculate_pattern_flexibility(pattern):
    """Calculate how flexible a pattern is based on number of wildcards"""
    return sum(1 for x in pattern if x == '*')


def generate_possible_values(pattern, position_values):
    if len(pattern) != len(position_values):
        raise ValueError(f"Pattern length ({len(pattern)}) must match position_values length ({len(position_values)})")

    values_per_position = []
    for pos, possible_vals in zip(pattern, position_values):
        if pos == '*':
            values_per_position.append(possible_vals)
        else:
            if pos not in possible_vals:
                raise ValueError(f"Fixed value {pos} not in possible values {possible_vals}")
            values_per_position.append([pos])

    return list(itertools.product(*values_per_position))


def sample_combinations(combinations, target_size):
    """Sample combinations to match target size"""
    if len(combinations) <= target_size:
        return combinations
    return random.sample(combinations, target_size)


def create_couple_key(left_values, right_values):
    left_str = '|'.join(str(x) for x in left_values)
    right_str = '|'.join(str(x) for x in right_values)
    return f"{left_str}-{right_str}"


def generate_individuals_from_group_key(group_key, position_values, column_names, target_size=100):
    try:
        left_pattern, right_pattern = parse_group_key(group_key)

        pattern_length = len(position_values)
        if len(left_pattern) != pattern_length:
            raise ValueError(
                f"Left pattern length ({len(left_pattern)}) does not match position values length ({pattern_length})")
        if len(right_pattern) != pattern_length:
            raise ValueError(
                f"Right pattern length ({len(right_pattern)}) does not match position values length ({pattern_length})")
        if len(column_names) != pattern_length:
            raise ValueError(
                f"Column names length ({len(column_names)}) does not match position values length ({pattern_length})")

        left_combinations = generate_possible_values(left_pattern, position_values)
        right_combinations = generate_possible_values(right_pattern, position_values)

        # Calculate how many combinations we need from each side to achieve target size
        total_flexibility = (calculate_pattern_flexibility(left_pattern) +
                             calculate_pattern_flexibility(right_pattern))

        if total_flexibility == 0:
            # If no wildcards, use all combinations
            left_sampled = left_combinations
            right_sampled = right_combinations
        else:
            # Sample square root of target size from each side to achieve approximately target_size total combinations
            side_target = int(pow(target_size, 0.5)) + 1
            left_sampled = sample_combinations(left_combinations, side_target)
            right_sampled = sample_combinations(right_combinations, side_target)

        all_combinations = list(itertools.product(left_sampled, right_sampled))

        # Final sampling to exactly match target size if we have more combinations than needed
        if len(all_combinations) > target_size:
            all_combinations = random.sample(all_combinations, target_size)

        result_data = []
        for left_vals, right_vals in all_combinations:
            row_data = {
                **{f'{col}_1': val for col, val in zip(column_names, left_vals)},
                **{f'{col}_2': val for col, val in zip(column_names, right_vals)},
                'couple_key': create_couple_key(left_vals, right_vals),
                'group_key': group_key
            }
            result_data.append(row_data)

        result_df = pd.DataFrame(result_data)

        # Reorder columns
        left_cols = [f'{col}_1' for col in column_names]
        right_cols = [f'{col}_2' for col in column_names]
        result_df = result_df[left_cols + right_cols + ['couple_key', 'group_key']]

        return result_df

    except Exception as e:
        raise ValueError(f"Error processing group key '{group_key}': {str(e)}")


def extract_position_values(df, column_names):
    position_values = []
    for col in column_names:
        unique_values = sorted(df[col].unique().tolist())
        position_values.append(unique_values)
    return position_values


def prepare_group_key_generation(df):
    column_names = list(df.columns)
    position_values = extract_position_values(df, column_names)
    return position_values, df.columns


def process_multiple_group_keys(df, group_keys, target_size_per_group=100):
    position_values, column_names = prepare_group_key_generation(df)

    all_results = []
    for group_key in group_keys:
        try:
            result_df = generate_individuals_from_group_key(
                group_key,
                position_values,
                column_names,
                target_size_per_group
            )
            all_results.append(result_df)
        except Exception as e:
            print(f"Error processing group key '{group_key}': {str(e)}")
            continue

    if all_results:
        final_df = pd.concat(all_results, ignore_index=True)
        return final_df
    else:
        raise ValueError("No valid results generated from any group keys")

In [177]:
position_values, column_names = prepare_group_key_generation(df_synth[feature_cols])
new_synth_df = process_multiple_group_keys(df_synth[feature_cols], list(df_synth['group_key'].unique()),
                                           target_size_per_group=500)
new_synth_df

Unnamed: 0,Attr1_X_1,Attr2_X_1,Attr1_T_1,Attr1_X_2,Attr2_X_2,Attr1_T_2,couple_key,group_key
0,0,0,2,1,0,2,0|0|2-1|0|2,0|*|2-1|*|2
1,0,0,2,1,1,2,0|0|2-1|1|2,0|*|2-1|*|2
2,0,0,2,1,2,2,0|0|2-1|2|2,0|*|2-1|*|2
3,0,0,2,1,3,2,0|0|2-1|3|2,0|*|2-1|*|2
4,0,1,2,1,0,2,0|1|2-1|0|2,0|*|2-1|*|2
5,0,1,2,1,1,2,0|1|2-1|1|2,0|*|2-1|*|2
6,0,1,2,1,2,2,0|1|2-1|2|2,0|*|2-1|*|2
7,0,1,2,1,3,2,0|1|2-1|3|2,0|*|2-1|*|2
8,0,2,2,1,0,2,0|2|2-1|0|2,0|*|2-1|*|2
9,0,2,2,1,1,2,0|2|2-1|1|2,0|*|2-1|*|2


In [178]:
new_synth_df['group_key'].drop_duplicates().shape

(2,)

In [179]:
import pandas as pd
from itertools import product
import math


def generate_all_combinations(column_values, column_names=None, percentage=100):
    # Validate percentage
    if not 0 < percentage <= 100:
        raise ValueError("Percentage must be between 0 and 100")

    # Calculate total possible combinations
    total_combinations = math.prod(len(values) for values in column_values)

    # Calculate how many combinations to generate
    num_combinations = int(total_combinations * (percentage / 100))
    num_combinations = max(1, num_combinations)  # Ensure at least one combination

    # Create column names if not provided
    if column_names is None:
        column_names = [f'column_{i}' for i in range(len(column_values))]

    # Generate all combinations and randomly sample
    all_combinations = list(product(*column_values))
    selected_combinations = random.sample(all_combinations, num_combinations)

    # Convert to DataFrame
    df = pd.DataFrame(selected_combinations, columns=column_names)

    return df

In [180]:
import re
from functools import lru_cache


@lru_cache(maxsize=1024)
def compile_pattern(pattern_string: str, data_schema: str) -> re.Pattern:
    """Compile and cache regex pattern."""

    def _compile_pattern(pattern_string: str, data_schema: str):
        res_pattern = []
        for k, v in zip(pattern_string.split('|'), data_schema.split('|')):
            if k == '*':
                res_pattern.append(f"[{v}]")
            else:
                res_pattern.append(k)
        res_pattern = "\|".join(res_pattern)
        return res_pattern

    if '-' in pattern_string:
        result_pat = []
        for pat1, pat2 in zip(pattern_string.split('-'), data_schema.split('-')):
            result_pat.append(_compile_pattern(pat1, pat2))

        return re.compile('-'.join(result_pat))

    else:
        return re.compile(_compile_pattern(pattern_string, data_schema))


@lru_cache(maxsize=4096)
def matches_pattern(pattern_string: str, test_string: str, data_schema: str) -> bool:
    """Cached pattern matching."""
    pattern = compile_pattern(pattern_string, data_schema)
    return bool(pattern.match(test_string))


def is_couple_part_of_a_group(couple_key, group_key_list, res_pattern):
    res = []

    couple_key_elems = couple_key.split('-')
    if len(couple_key_elems) != 2:
        print(f"Warning: Unexpected couple key format: {couple_key}")
        return res

    opt1 = f"{couple_key_elems[0]}-{couple_key_elems[1]}"
    opt2 = f"{couple_key_elems[1]}-{couple_key_elems[0]}"

    grp_res_pattern = f"{res_pattern}-{res_pattern}"

    for grp_key in group_key_list:
        if matches_pattern(grp_key, opt1, grp_res_pattern) or matches_pattern(grp_key, opt2, grp_res_pattern):
            res.append(grp_key)
    return res


@lru_cache(maxsize=1024)
def cached_matches_pattern(pattern: str, key: str, data_schema: str) -> bool:
    return matches_pattern(pattern, key, data_schema)


def process_row(row):
    matching_groups = [
        group_key for group_key in df_synth['group_key'].unique() if
        cached_matches_pattern(group_key, row['couple_key'], df_synth['attr_possible_values'].iloc[0])
    ]

    return pd.Series({
        'matching_groups': matching_groups,
    })

In [181]:
column_values = [list(map(int, list(e))) for e in df_synth['attr_possible_values'].iloc[0].split('|')]
column_values = column_values + column_values
column_names = list(filter(lambda x: 'Attr' in x, new_synth_df.columns))
new_all_synth_df = generate_all_combinations(column_values, column_names=column_names)
mk_coupl_key = lambda \
        x: f"{'|'.join(list(map(str, x[:len(column_names) // 2])))}-{'|'.join(list(map(str, x[len(column_names) // 2:])))}"
new_all_synth_df['couple_key'] = new_all_synth_df.apply(mk_coupl_key, axis=1)
new_all_synth_df['group_key'] = new_all_synth_df.apply(process_row, axis=1)
new_all_synth_df = new_all_synth_df.explode('group_key').reset_index(drop=True)
new_all_synth_df

Unnamed: 0,Attr1_X_1,Attr2_X_1,Attr1_T_1,Attr1_X_2,Attr2_X_2,Attr1_T_2,couple_key,group_key
0,1,1,2,0,2,2,1|1|2-0|2|2,
1,0,1,0,0,0,0,0|1|0-0|0|0,
2,0,1,0,0,2,2,0|1|0-0|2|2,
3,0,0,2,1,1,0,0|0|2-1|1|0,0|*|2-1|*|2
4,1,1,0,1,1,2,1|1|0-1|1|2,
...,...,...,...,...,...,...,...,...
319,1,2,1,0,2,1,1|2|1-0|2|1,
320,1,1,1,0,1,2,1|1|1-0|1|2,
321,0,2,0,0,2,1,0|2|0-0|2|1,
322,0,2,2,1,1,1,0|2|2-1|1|1,0|*|2-1|*|2


In [182]:
synth_cols = list(filter(lambda x: 'Attr' in x, synth_combinations_df.columns))
synth_combinations_df = synth_combinations_df[synth_cols + ['group_key']].copy().drop_duplicates()
new_synth_combinations_df = new_synth_df[synth_cols + ['group_key']].copy().drop_duplicates()
new_all_synth_df = new_all_synth_df[synth_cols + ['group_key']].copy().drop_duplicates()
result_combinations_df = result_combinations_df[
    synth_cols + ['group_key', 'couple_key', 'is_part_of_group', 'matching_groups']].copy().drop_duplicates()

synth_combinations_df['type'] = 'original_synth'
synth_combinations_df['is_part_of_group'] = False
synth_combinations_df['matching_groups'] = ''
new_synth_combinations_df['type'] = 'new_synth'
new_synth_combinations_df['is_part_of_group'] = False
new_synth_combinations_df['matching_groups'] = ''
new_all_synth_df['type'] = 'new_synth'
new_all_synth_df['is_part_of_group'] = False
new_all_synth_df['matching_groups'] = ''
new_all_synth_df['couple_key'] = ''
result_combinations_df['type'] = 'search'


In [183]:
full_synth_df = pd.concat([synth_combinations_df, new_synth_combinations_df, new_all_synth_df], axis=0)

In [184]:
# full_group_keys = (new_synth_df['group_key'].value_counts() <= 100).to_frame()
# full_synth_df = full_synth_df[full_synth_df['group_key'].isin(full_group_keys[full_group_keys['count']].index.tolist())]

In [185]:
final_df = pd.concat([full_synth_df, result_combinations_df], axis=0)
final_df.drop_duplicates()

Unnamed: 0,Attr1_X_1,Attr2_X_1,Attr1_T_1,Attr1_X_2,Attr2_X_2,Attr1_T_2,group_key,type,is_part_of_group,matching_groups,couple_key
0,0,1,2,1,0,2,0|*|2-1|*|2,original_synth,False,,
11,0,1,2,1,1,2,0|*|2-1|*|2,original_synth,False,,
17,0,1,2,1,2,2,0|*|2-1|*|2,original_synth,False,,
108,0,1,2,1,3,2,0|*|2-1|*|2,original_synth,False,,
117,0,2,2,1,0,2,0|*|2-1|*|2,original_synth,False,,
...,...,...,...,...,...,...,...,...,...,...,...
45,0,0,1,1,1,1,1|0|*-0|0|*,search,True,"[""1|0|*-0|0|*""]",1|0|0-1|1|1
46,0,0,1,1,0,3,1|0|*-0|0|*,search,True,"[""1|0|*-0|0|*""]",1|0|0-3|1|0
47,0,2,0,1,1,0,,search,False,[],0|0|2-0|1|1
48,0,2,3,0,3,3,,search,False,[],3|0|2-3|0|3


In [186]:
ll = final_df['group_key'].value_counts()
print(ll)

group_key
1|0|*-0|0|*    291
0|*|2-1|*|2    105
Name: count, dtype: int64


In [187]:
# final_df[final_df['group_key'] == '0|1|0-1|0|0']

In [188]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.spatial import ConvexHull


def create_hull_boundary(points):
    """Create a convex hull boundary around points."""
    if len(points) < 3:  # Need at least 3 points for a convex hull
        return None

    try:
        hull = ConvexHull(points)
        # Get the vertices of the hull in order
        vertices = points[hull.vertices]
        # Close the polygon by adding the first point at the end
        vertices = np.vstack((vertices, vertices[0]))
        return vertices
    except Exception:
        return None


def plot_embedding_interactive(embedding, groups, types, is_part_of_group, couple_keys, matching_groups,
                               draw_boundaries=False):
    unique_groups = sorted(set(groups))
    n_colors = len(unique_groups)

    # Generate muted colors using HSL with lower saturation and lightness
    colors = [f'hsl({h},50%,70%)' for h in np.linspace(0, 360, n_colors)]
    color_dict = dict(zip(unique_groups, colors))

    # Create figure
    fig = go.Figure()

    # Draw group boundaries first (if enabled) so they appear behind points
    if draw_boundaries:
        for group in unique_groups:
            if group.lower() != 'unknown':  # Skip drawing boundary for unknown group
                # Include only new_synth and original_synth points
                group_mask = (groups == group) & ((types == 'new_synth') | (types == 'original_synth'))

                if np.sum(group_mask) >= 3:  # Need at least 3 points for a boundary
                    group_points = embedding[group_mask]
                    hull_vertices = create_hull_boundary(group_points)

                    if hull_vertices is not None:
                        # Draw the boundary
                        fig.add_trace(go.Scatter(
                            x=hull_vertices[:, 0],
                            y=hull_vertices[:, 1],
                            mode='lines',
                            line=dict(
                                color=color_dict[group],
                                width=2,
                                dash='dot'
                            ),
                            fill='toself',
                            fillcolor=color_dict[group].replace('70%', '95%'),  # Make fill very transparent
                            opacity=0.2,
                            name=f'Group {group} Boundary',
                            showlegend=False,
                            hoverinfo='skip'
                        ))

    # Add traces for new_synth points
    for group in unique_groups:
        mask = (types == 'new_synth') & (groups == group)
        if np.any(mask):
            fig.add_trace(go.Scatter(
                x=embedding[mask, 0],
                y=embedding[mask, 1],
                mode='markers',
                marker=dict(
                    color=color_dict[group],
                    symbol='triangle-up',
                    size=8,
                    opacity=0.6,
                    line=dict(color='rgba(50,50,50,0.2)', width=1)
                ),
                name=f'New Synth - Group {group[:20]}...' if len(group) > 20 else f'New Synth - Group {group}',
                legendgroup=group,
                hovertemplate='Group: %{text}<br>Type: New Synth<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>',
                text=[group] * np.sum(mask)
            ))

    # Add traces for original_synth points
    for group in unique_groups:
        mask = (types == 'original_synth') & (groups == group)
        if np.any(mask):
            fig.add_trace(go.Scatter(
                x=embedding[mask, 0],
                y=embedding[mask, 1],
                mode='markers',
                marker=dict(
                    color=color_dict[group],
                    symbol='circle',
                    size=8,
                    opacity=0.6,
                    line=dict(color='rgba(50,50,50,0.2)', width=1)
                ),
                name=f'Original Synth - Group {group[:20]}...' if len(
                    group) > 20 else f'Original Synth - Group {group}',
                legendgroup=group,
                hovertemplate='Group: %{text}<br>Type: Original Synth<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>',
                text=[group] * np.sum(mask)
            ))

    # Add search points not in group - red
    mask_search_false = (types == 'search') & (~is_part_of_group)
    if np.any(mask_search_false):
        hover_text = [
            f'Couple Key: {ck}<br>Status: Not In Group'
            for ck in couple_keys[mask_search_false]
        ]
        fig.add_trace(go.Scatter(
            x=embedding[mask_search_false, 0],
            y=embedding[mask_search_false, 1],
            mode='markers',
            marker=dict(
                color='rgba(255,0,0,0.8)',
                symbol='x',
                size=12,
                line=dict(width=2)
            ),
            name='Search (Not In Group)',
            hovertemplate='Type: Search<br>%{text}<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>',
            text=hover_text
        ))

    # Add search points in group - green
    mask_search_true = (types == 'search') & (is_part_of_group)
    if np.any(mask_search_true):
        hover_text = [
            f'Couple Key: {ck}<br>Matching Group: {mg}<br>Status: In Group'
            for ck, mg in zip(couple_keys[mask_search_true], matching_groups[mask_search_true])
        ]
        fig.add_trace(go.Scatter(
            x=embedding[mask_search_true, 0],
            y=embedding[mask_search_true, 1],
            mode='markers',
            marker=dict(
                color='rgba(0,255,0,0.8)',
                symbol='x',
                size=12,
                line=dict(width=2)
            ),
            name='Search (In Group)',
            hovertemplate='Type: Search<br>%{text}<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>',
            text=hover_text
        ))

    # Update layout
    fig.update_layout(
        title='UMAP Projection of Attributes',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        hovermode='closest',
        plot_bgcolor='white',
        legend=dict(
            title="Groups and Types",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.05,
            itemsizing='constant'
        ),
        margin=dict(r=200),
        showlegend=True
    )

    # Add grid for better readability
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.1)')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.1)')

    return fig

In [189]:
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from umap.umap_ import UMAP


def create_knn_weighted_metric(X, groups, k=5, group_weight=0.5):
    # Fit KNN on the feature space
    knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
    knn.fit(X)

    # Pre-compute KNN indices for each point
    distances, indices = knn.kneighbors(X)

    def custom_metric(a, b, a_idx=None, b_idx=None):
        # Euclidean distance component
        euclidean_dist = np.sqrt(np.sum((a - b) ** 2))

        if a_idx is None or b_idx is None:
            return euclidean_dist

        # Group similarity component
        a_neighbors = indices[a_idx]
        b_neighbors = indices[b_idx]

        # Count shared group members in KNN neighborhoods
        a_neighbor_groups = groups[a_neighbors]
        b_neighbor_groups = groups[b_neighbors]

        a_group = groups[a_idx]
        b_group = groups[b_idx]

        # Calculate group similarity
        if a_group == b_group:
            group_sim = 0.0  # Minimum distance for same group
        else:
            # Count shared groups in neighborhoods
            shared_groups = len(set(a_neighbor_groups) & set(b_neighbor_groups))
            group_sim = 1.0 - (shared_groups / k)  # Normalize to 0-1

        # Combine distances with weighting
        combined_dist = ((1 - group_weight) * euclidean_dist +
                         group_weight * group_sim)

        return combined_dist

    return custom_metric


def prepare_data(df):
    # Select columns that start with 'Attr'
    attr_columns = [col for col in df.columns if col.startswith('Attr')]

    # Extract features
    X = df[attr_columns].values

    # Extract metadata and ensure groups are strings
    groups = df['group_key'].fillna('unknown').astype(str).values
    types = df['type'].values
    is_part_of_group = df['is_part_of_group'].values
    couple_keys = df['couple_key'].values
    matching_groups = df['matching_groups'].values

    return X, groups, types, is_part_of_group, attr_columns, couple_keys, matching_groups


X, groups, types, is_part_of_group, attr_columns, couple_keys, matching_groups = prepare_data(final_df)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

custom_metric = create_knn_weighted_metric(
    X_scaled,
    groups,
    k=10,  # Number of neighbors to consider
    group_weight=0.7  # Weight given to group similarity
)

# Create and fit UMAP
reducer = UMAP(
    n_neighbors=5,
    n_components=2,
    metric=custom_metric,
    output_metric='euclidean',
    n_epochs=1000,
    learning_rate=1.0,
    init='spectral',
    min_dist=0.01,
    spread=0.5,
    target_n_neighbors=10,
    target_weight=0.7,
    random_state=42
)

embedding = reducer.fit_transform(X_scaled)

ValueError: could not convert string to float: '0|*|2-1|*|2'

In [166]:
# Create and show the plot
fig = plot_embedding_interactive(
    embedding=embedding,
    groups=groups,
    types=types,
    is_part_of_group=is_part_of_group,
    couple_keys=couple_keys,
    matching_groups=matching_groups,
    draw_boundaries=True
)
fig.show()