In [None]:
import unicodedata
import re

def remove_accents(input_str):
    """Strip diacritics from ``input_str``.

    The text is decomposed with NFKD normalization and every character that
    ``unicodedata.combining`` reports as a combining mark is discarded.
    """
    decomposed = unicodedata.normalize('NFKD', input_str)
    kept_chars = []
    for char in decomposed:
        if unicodedata.combining(char):
            continue
        kept_chars.append(char)
    return u"".join(kept_chars)

def tokenize_file(filename):
    """Read a UTF-8 text file and return its comma/newline separated tokens.

    The content is lower-cased, stripped of accents via ``remove_accents`` and
    split on commas and newlines; surrounding whitespace is trimmed and empty
    pieces are dropped. Order and duplicates are preserved.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        content = file.read().lower()

    content = remove_accents(content)

    tokens = []
    for piece in re.split(r'[,\n]', content):
        cleaned = piece.strip()
        if cleaned:
            tokens.append(cleaned)
    return tokens

# Read and tokenize both files
names = tokenize_file('name_segregated_tokens.txt')
places = tokenize_file('place_segregated_tokens.txt')

# Remove duplicates
unique_names = set(names)
unique_places = set(places)

# Tokens found in both files are treated as names. They are already part of
# the name set, so they only have to be dropped from the places; appending
# them to `names` again would store (and count) them twice.
common_tokens = unique_names & unique_places

# Sort the lists
names = sorted(unique_names)
places = sorted(unique_places - common_tokens)

print(f"Number of unique names: {len(names)}")
print(f"Number of unique places: {len(places)}")

# Optional: Print the first few elements of each list to verify
print("\nFirst 10 names:")
print(names[:10])
print("\nFirst 10 places:")
print(places[:10])

In [None]:
import numpy as np
import pandas as pd
import re
from collections import Counter
import unicodedata
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

# Load the dataset; its 'nomeUnidade' column is tokenized further down.
df = pd.read_csv('../unidades.csv')

# Preview the first rows to verify the data.
# NOTE(review): this is not the final expression of the cell, so a notebook
# would presumably not render it - wrap it in print()/display() if the
# preview is actually wanted.
df.head()

# Define a function to remove accents
def remove_accents(input_str):
    """Return ``input_str`` without diacritics.

    Applies NFKD normalization and keeps only the characters for which
    ``unicodedata.combining`` is falsy (i.e. drops combining marks).
    """
    def is_base_character(char):
        return not unicodedata.combining(char)

    decomposed = unicodedata.normalize('NFKD', input_str)
    return "".join(filter(is_base_character, decomposed))

# Synonym table: maps an abbreviation, misspelling or inflected form to the
# canonical token used for counting. Lookups are made with tokens that were
# already accent-stripped and lower-cased, so every key must be lower-case
# ASCII ('DEECRIM' could never match and is now 'deecrim').
# NOTE(review): keys containing '.' or a space ('gab.', 'alta vara') and keys
# that are also stopwords ('v', 'c', 'des') look unreachable given how
# tokenize_name splits and filters - confirm before relying on them.
synonyms = {
    'gab': 'gabinete',
    'gab.': 'gabinete',
    'presidencia': 'presidencia',
    'v': 'vara',
    'var': 'vara',
    'vio': 'violencia',
    'c': 'circunscricao',
    'juiza': 'juiz',
    'substituta': 'substituto',
    'dra': 'dr',
    'faz': 'fazenda',
    'fam': 'familia',
    'exma': 'exmo',
    'reg': 'registros',
    'pub': 'publico',
    'publ': 'publico',
    'publica': 'publico',
    'juv': 'juventude',
    'inf': 'infancia',
    'crim': 'criminal',
    'deecrim': 'criminal',
    'adj': 'adjunto',
    'cons': 'consumo',
    'jef': 'federal',
    'jud': 'judiciario',
    'desembargadora': 'desembargador',
    'desebargadora': 'desembargador',
    'desembargado': 'desembargador',
    'desembargador': 'desembargador',
    'desembargadores': 'desembargador',
    'des': 'desembargador',
    'desa': 'desembargador',
    'desemb': 'desembargador',
    'j': 'juizado',
    'jui': 'juizado',
    'civ': 'civel',
    'civeis': 'civel',
    'civil': 'civel',
    'crminal': 'criminal',
    'esp': 'especial',
    'especiais': 'especial',
    'educativa': 'educacional',
    'contadoria/tesouraria': 'contadoria',
    'c/mulher': 'mulher',
    'calculos': 'calculo',
    'calc': 'calculo',
    'mulh': 'mulher',
    'adm': 'administracao',
    'amb': 'ambiental',
    'acomp': 'acompanhamento',
    'aten': 'atencao',
    'atend': 'atendimento',
    'aux': 'auxiliar',
    'aval': 'avaliacao',
    'compet': 'competencia',
    'conf': 'conflito',
    'confl': 'conflito',
    'coord': 'coordenacao',
    'cump': 'cumprimento',
    'def': 'defensoria',
    'dep': 'departamento',
    'dist': 'distribuicao',
    'distr': 'distribuicao',
    'exec': 'execucao',
    'fisc': 'fiscal',
    'gest': 'gestao',
    'inform': 'informacao',
    'inq': 'inquerito',
    'jurid': 'juridico',
    'med': 'mediacao',
    'mun': 'municipal',
    'munic': 'municipal',
    'org': 'organizacao',
    'pres': 'presidencia',
    'proc': 'processo',
    'prog': 'programa',
    'proj': 'projeto',
    'prot': 'protocolo',
    'rec': 'recurso',
    'rel': 'relator',
    'sec': 'secretaria',
    'serv': 'servico',
    'sist': 'sistema',
    'tec': 'tecnico',
    'trib': 'tribunal',
    'unid': 'unidade',
    'fg': 'fig',
    'gm': 'gmf',
    '[microrregiao': 'microrregiao',
    'centr': 'centro',
    'turmas': 'turma',
    'tributaria': 'tributarios',
    'trt9': 'trt',
    'acidentes': 'acidente',
    'administracao': 'administrativo',
    'administrativa': 'administrativo',
    'zonas': 'zona',
    '*decima': 'decima',
    'acervos': 'acervo',
    'acordaos': 'acordao',
    'adjunta': 'adjunto',
    'administrativas': 'administrativo',
    'administrativos': 'administrativo',
    'adolescentes': 'adolescente',
    'alta vara': 'alto vara',
    'anexos': 'anexo',
    '\\gabinete': 'gabinete',
    'detrabalho': 'trabalho',
    'conflitos': 'conflito',
    'cidanania': 'cidadania',
    'arquivos': 'arquivo',

    # Add more synonyms as needed (lower-case keys only)
}

# Abbreviations that expand into several canonical tokens at once.
# Key: the abbreviation as it appears after tokenization; value: the ordered
# list of tokens that replaces it.
multi_token_replacements = {
    'vt': ['vara', 'trabalho'],
    'cejusc': ['centro', 'judicial', 'solucao', 'conflitos', 'cidadania'],
    'tr': ['turma', 'recursal'],
    'jit': ['juizado', 'especial', 'civel'],
    'gades': ['gabinete', 'desembargador'],
    'jesp': ['juizado', 'especial', 'criminal'],
    'jec': ['juizado', 'especial', 'civel'],
    'saf': ['servico', 'anexo', 'fazendas'],
    'aos, autos': ['aos', 'autos'],
    # Extend with further multi-token replacements when required
}

# Function to replace synonyms and multi-token replacements
def replace_synonyms_and_multi_tokens(token):
    """Expand or canonicalise a single token.

    Returns the list stored in ``multi_token_replacements`` when the token is
    a known multi-token abbreviation; otherwise a one-element list holding the
    token's synonym, or the token itself when it has none.
    """
    expansion = multi_token_replacements.get(token)
    if expansion is not None:
        return expansion
    return [synonyms.get(token, token)]


# Tokenization function
# Connectives, stray punctuation, Roman numerals and dataset-specific noise
# tokens that tokenize_name always discards. Kept as a set so membership
# checks are constant time.
_BASE_STOPWORDS = frozenset([
    'de', 'da', 'do', 'das', 'dos',
    'e', 'a', 'o', 'i', 'u', 'b', 'as', 'ao',
    '"', 'em', 'des', 'com', 'n', 'g', 'ap', 'sr', 'sra', '/', '\\', '?', '\'', '\\\'', '\"gabinete', 'ou',
    'hora', 'solteira', 'villa', 'zz', '°', '¿',
    'i', 'ii', 'iii', 'iv', 'v', 'vi', 'vii', 'viii', 'ix', 'x',
    'xi', 'xii', 'xiii', 'xiv', 'xv', 'xvi', 'xvii', 'xviii', 'xix', 'xx',
    'xxi', 'xxii', 'xxiii', 'xxiv', 'xxv', 'xxvi', 'xxvii', 'xxviii', 'xxix', 'xxx',
    'xxxi', 'xxxii', 'xxxiii', 'xxxiv', 'xxxv', 'xxxvi', 'xxxvii', 'xxxviii', 'xxxix', 'xl',
    'xli', 'xlii', 'xliii', 'xliv', 'xlv', 'xlvi', 'xlvii', 'xlviii', 'xlix', 'l',
    'li', 'lii', 'liii', 'liv', 'lv', 'lvi', 'lvii', 'lviii', 'lix', 'lx',
    'lxi', 'lxii', 'lxiii', 'lxiv', 'lxv', 'lxvi', 'lxvii', 'lxviii', 'lxix', 'lxx',
    'lxxi', 'lxxii', 'lxxiii', 'lxxiv', 'lxxv', 'lxxvi', 'lxxvii', 'lxxviii', 'lxxix', 'lxxx',
    'lxxxi', 'lxxxii', 'lxxxiii', 'lxxxiv', 'lxxxv', 'lxxxvi', 'lxxxvii', 'lxxxviii', 'lxxxix', 'xc',
    'xci', 'xcii', 'xciii', 'xciv', 'xcv', 'xcvi', 'xcvii', 'xcviii', 'xcix', 'c',
    'sanclerlandia', 'goianapolis', '?', ':', 'ci', 'cii', 'varao',
])

# Prefix words that are glued to the word following them, so that e.g.
# "sao paulo" survives tokenization as one token.
_WORDS_TO_COMBINE = (
    'sao', 'santa', 'santo', 'nova', 'novo', 'bom', 'boa',
    'alto', 'alta', 'baixo', 'baixa', 'porto', 'campos',
    'rio', 'foz', 'barra', 'passa', 'entre',
)
_COMBINE_PATTERN = re.compile(r'\b(' + '|'.join(_WORDS_TO_COMBINE) + r')[\s-]+(\w+)')


def tokenize_name(name, additional_stopwords=None):
    """Turn a unit name into a list of canonical tokens.

    Steps, in order:
      1. strip accents and lower-case the name;
      2. replace 'cj' / 'c j' with 'circunscricao';
      3. delete numbers and number-led runs (e.g. '2', '2a');
      4. glue 'vice' and the prefixes in ``_WORDS_TO_COMBINE`` to the word
         that follows them (unless that word is a stopword);
      5. split on whitespace, ',', '.', '(', ')' and '-';
      6. drop stopwords and every entry of the module-level ``names`` list;
      7. apply ``replace_synonyms_and_multi_tokens`` to each remaining token;
      8. turn the gluing underscores back into spaces.

    Args:
        name: raw unit name (a string).
        additional_stopwords: optional iterable of extra tokens to discard.

    Returns:
        List of processed tokens; glued expressions appear as a single
        space-separated token such as 'sao paulo'.
    """
    name = remove_accents(name).lower()

    # 'cj' / 'c j' stands for "circunscricao judicial"; only the first word is
    # emitted as the canonical token.
    name = re.sub(r'\bc\s*j\b', 'circunscricao', name)

    # Remove numbers and number-letter combinations such as ordinals.
    name = re.sub(r'\b(\d+\w*)\b', '', name)

    # Combine 'vice' with the following word (space or hyphen) into one token.
    name = re.sub(r'\b(vice)[-\s]+(\w+)', r'\1_\2', name)

    base_stopwords = set(_BASE_STOPWORDS)
    if additional_stopwords:
        base_stopwords.update(additional_stopwords)

    def join_with_next(match):
        # Only the base stopwords (not the person names) block the gluing,
        # which is why this runs before the names are added below.
        first, second = match.group(1), match.group(2)
        if second.lower() in base_stopwords:
            return f'{first} {second}'
        return f'{first}_{second}'

    name = _COMBINE_PATTERN.sub(join_with_next, name)

    # Person names are discarded as well, but only as whole tokens.
    stopwords = base_stopwords.union(known_name.lower() for known_name in names)

    tokens = [
        token.strip()
        for token in re.split(r'\s|,|\.|\(|\)|-', name)
        if token.strip() and token not in stopwords
    ]

    # Apply synonyms and multi-token replacements to every surviving token.
    processed_tokens = []
    for token in tokens:
        processed_tokens.extend(replace_synonyms_and_multi_tokens(token))

    # Replace underscores with spaces in the preserved conjoined expressions.
    return [token.replace('_', ' ') for token in processed_tokens]

# Tokenize every unit name and keep the untouched name next to its tokens
df['tokens'] = df['nomeUnidade'].apply(tokenize_name)
df['original_name'] = df['nomeUnidade']

# Gather the tokens of all rows into one flat list
all_tokens = []
for row_tokens in df['tokens']:
    all_tokens.extend(row_tokens)

# Count the occurrences of each token, most frequent first
token_counts = Counter(all_tokens)
common_tokens = token_counts.most_common()

# One row per distinct token, for inspection
token_df = pd.DataFrame(common_tokens, columns=['token', 'count'])

print(f"Number of disparate tokens: {len(token_counts)}")

# Show the token table
token_df.head(2000)

In [None]:
import Levenshtein
from collections import defaultdict, Counter

def levenshtein_distance(s1, s2):
    """Return ``Levenshtein.distance(s1, s2)``, the edit distance of the two strings."""
    distance = Levenshtein.distance(s1, s2)
    return distance

def find_similar_tokens(tokens, threshold=0.8):
    """Map each token to the set of tokens that are similar to it.

    Every unordered pair of distinct tokens is compared once. Similarity is
    ``1 - distance / longest_length`` using ``levenshtein_distance``; pairs
    reaching ``threshold`` are recorded symmetrically.

    Returns:
        ``defaultdict(set)`` keyed by token; tokens without any similar
        partner do not appear as keys.
    """
    similar_tokens = defaultdict(set)
    for position, first in enumerate(tokens):
        for second in tokens[position + 1:]:
            if first == second:
                continue
            longest = max(len(first), len(second))
            similarity = 1 - (levenshtein_distance(first, second) / longest)
            if similarity >= threshold:
                similar_tokens[first].add(second)
                similar_tokens[second].add(first)
    return similar_tokens

def group_similar_tokens(similar_tokens):
    """Collect similarity groups from a token -> similar-tokens mapping.

    Walks the mapping in its iteration order; each token not yet covered by an
    earlier group starts a new group made of itself and its direct neighbours.
    Neighbours-of-neighbours are not pulled in, so a token may end up in more
    than one group.

    Returns:
        List of sets, one per group.
    """
    groups = []
    already_grouped = set()

    for token, neighbours in similar_tokens.items():
        if token in already_grouped:
            continue
        group = {token} | neighbours
        groups.append(group)
        already_grouped.update(group)

    return groups

def consolidate_tokens(groups, token_counts):
    """Map every grouped token to its group's highest-count member.

    Args:
        groups: iterable of token collections.
        token_counts: mapping indexed as ``token_counts[token]``.

    Returns:
        Dict token -> representative. When a token belongs to several groups
        the last group processed wins.
    """
    token_mapping = {}
    for group in groups:
        representative = max(group, key=lambda member: token_counts[member])
        token_mapping.update(dict.fromkeys(group, representative))
    return token_mapping

def apply_token_mapping(tokens, token_mapping):
    """Return a new list with each token replaced by its mapped value, if any."""
    remapped = []
    for token in tokens:
        remapped.append(token_mapping.get(token, token))
    return remapped

# Use the real occurrence counts stored in token_df. token_df holds one row
# per distinct token, so value_counts() on its 'token' column would report 1
# for every token and the "most common" member of a group would be arbitrary.
token_counts = Counter(dict(zip(token_df['token'], token_df['count'])))

# Get unique tokens for similarity comparison
unique_tokens = token_df['token'].unique().tolist()

# Find similar tokens among all unique tokens
similar_tokens = find_similar_tokens(unique_tokens)

# Group similar tokens
grouped_tokens = group_similar_tokens(similar_tokens)

# Consolidate tokens
token_mapping = consolidate_tokens(grouped_tokens, token_counts)

# Store original tokens
df['original_tokens'] = df['tokens']

# Replace original tokens with consolidated tokens
df['tokens'] = df['tokens'].apply(lambda tokens: apply_token_mapping(tokens, token_mapping))

# Update token_df with consolidated tokens
consolidated_token_counts = Counter([token for tokens in df['tokens'] for token in tokens])
token_df = pd.DataFrame(list(consolidated_token_counts.items()), columns=['token', 'count'])
token_df = token_df.sort_values('count', ascending=False).reset_index(drop=True)

# Print the results
print("Token Consolidation Results:")
for group in grouped_tokens:
    most_common = max(group, key=lambda t: token_counts[t])
    print(f"\nGroup consolidated to '{most_common}' ({consolidated_token_counts[most_common]}):")
    for token in sorted(group):
        if token != most_common:
            print(f"  - {token} ({token_counts[token]}) -> {most_common}")

# Print statistics. The consolidated total counts every distinct token left in
# the data (not only group representatives), so it is comparable with the
# original total; "consolidated" counts the tokens that were actually renamed.
renamed_tokens = sum(1 for token, target in token_mapping.items() if token != target)
print(f"\nTotal number of unique original tokens: {len(unique_tokens)}")
print(f"Total number of unique consolidated tokens: {len(consolidated_token_counts)}")
print(f"Number of tokens consolidated: {renamed_tokens}")

# Sample of original vs consolidated tokens
print("\nSample of original vs consolidated tokens:")
for _, row in df.head().iterrows():
    print(f"\nOriginal: {row['original_tokens']}")
    print(f"Consolidated: {row['tokens']}")

print("\nNote: The original tokens are now stored in the 'original_tokens' column.")
print("The 'tokens' column and 'token_df' now contain the consolidated tokens.")
print("All downstream processes will now use the consolidated tokens automatically.")

In [None]:
import numpy as np
import pandas as pd
from typing import List, Tuple
from collections import Counter

def get_frequent_tokens(df: pd.DataFrame, min_frequency: int = 2) -> set:
    """Return the tokens that occur at least ``min_frequency`` times in ``df['tokens']``."""
    token_counts = Counter(token for sublist in df['tokens'] for token in sublist)
    return {token for token, count in token_counts.items() if count >= min_frequency}


def initialize_entries(df: pd.DataFrame, places: List[str], min_frequency: int = 2) -> List[np.ndarray]:
    """Convert DataFrame rows into fixed-width entries.

    For every row of ``df`` the tokens in ``df['tokens']`` are filtered:
    place tokens are removed and summarised by a single trailing '[cidade]'
    marker, and of the remaining tokens only those occurring at least
    ``min_frequency`` times in the whole frame are kept. All rows are then
    padded with 'null' to a common width; a row with nothing left gets
    'alien' in its first slot.

    Args:
        df: frame with a 'tokens' column holding a list of tokens per row.
        places: tokens that count as place names.
        min_frequency: minimum number of occurrences for a token to be kept.

    Returns:
        One 2-element object array per row: ``[row index, token array]``.
    """
    frequent_tokens = get_frequent_tokens(df, min_frequency)
    place_tokens = set(places)  # constant-time membership instead of list scans

    filtered_rows = []
    for index, tokens in zip(df.index, df['tokens']):
        filtered_tokens = []
        has_place = False
        for token in tokens:
            if token in place_tokens:
                has_place = True
            elif token in frequent_tokens:
                filtered_tokens.append(token)
        # '[cidade]' always goes last, once, however many place tokens exist
        if has_place:
            filtered_tokens.append('[cidade]')
        filtered_rows.append((index, filtered_tokens))

    # The width is taken from the final token lists (including '[cidade]') so
    # every entry has the same length, and is at least 1 so 'alien' has a slot.
    width = max((len(tokens) for _, tokens in filtered_rows), default=0)
    width = max(width, 1)

    entries = []
    for index, filtered_tokens in filtered_rows:
        normalized_tokens = filtered_tokens + ['null'] * (width - len(filtered_tokens))
        if not filtered_tokens:
            normalized_tokens[0] = 'alien'
        entry = np.array([index, np.array(normalized_tokens, dtype=object)], dtype=object)
        entries.append(entry)

    return entries

class ClassificationStructure:
    """Bidirectional index between entries and their classifications.

    Entries are addressed by the position at which they were added. Removing
    an entry leaves a ``None`` placeholder so that positions stay stable.
    """

    def __init__(self):
        self.classifications = {}  # classification -> set of entry indices
        self.entries = []  # all entries, by insertion position (None once removed)
        self.entry_classifications = {}  # entry index -> set of classifications
        self.weights = {}  # classification -> number of entries carrying it

    def add_entry(self, entry: np.ndarray, classifications: List[str]):
        """Store ``entry`` and register it under each given classification.

        Empty values and the 'null' placeholder are ignored. A classification
        repeated in ``classifications`` is counted once, so its weight always
        equals the number of entries that carry it.
        """
        entry_index = len(self.entries)
        self.entries.append(entry)
        own_classifications = self.entry_classifications[entry_index] = set()

        for classification in classifications:
            if not classification or classification == 'null':
                continue
            if classification in own_classifications:
                continue  # already registered for this entry
            members = self.classifications.setdefault(classification, set())
            members.add(entry_index)
            own_classifications.add(classification)
            self.weights[classification] = self.weights.get(classification, 0) + 1

    def remove_entry(self, entry_index: int):
        """Unregister the entry at ``entry_index``; unknown indices are ignored.

        Classifications left without entries are deleted together with their
        weight.
        """
        if entry_index not in self.entry_classifications:
            return
        for classification in self.entry_classifications[entry_index]:
            members = self.classifications[classification]
            members.remove(entry_index)
            self.weights[classification] -= 1
            if not members:
                del self.classifications[classification]
                del self.weights[classification]
        del self.entry_classifications[entry_index]
        self.entries[entry_index] = None

    def get_entries_with_classifications(self, classifications: List[str]) -> List[np.ndarray]:
        """Return the entries that carry ALL of the given classifications.

        An empty request, or a request naming a classification that no entry
        has, yields an empty list. Results are ordered by entry index.
        """
        if not classifications:
            return []
        member_sets = []
        for classification in classifications:
            members = self.classifications.get(classification)
            if members is None:
                # No entry has this classification, so none can have all of them.
                return []
            member_sets.append(members)
        entry_indices = set.intersection(*member_sets)
        return [self.entries[i] for i in sorted(entry_indices) if self.entries[i] is not None]

    def get_classifications_for_entry(self, entry_index: int) -> set:
        """Return the classifications of an entry (empty set if unknown)."""
        return self.entry_classifications.get(entry_index, set())

    def get_weight(self, classification: str) -> int:
        """Return how many entries carry ``classification`` (0 if unknown)."""
        return self.weights.get(classification, 0)

    def get_overlaps(self) -> dict:
        """Group entries that have more than one classification.

        Returns:
            Dict mapping each frozenset of classifications to the set of entry
            indices having exactly that combination.
        """
        overlaps = {}
        for entry_index, classifications in self.entry_classifications.items():
            if len(classifications) > 1:
                overlaps.setdefault(frozenset(classifications), set()).add(entry_index)
        return overlaps

In [None]:
# Example usage (assuming df is your original DataFrame)
cs = ClassificationStructure()

# Create Entry representations
entries = initialize_entries(df, places)

print(f"Total number of entries: {len(entries)}")
print("Sample entry:", entries[0])

# Populate ClassificationStructure: each entry is classified by its own tokens
for entry in entries:
    index, tokens = entry
    cs.add_entry(entry, tokens)

# Example queries (the first label now names the classification actually queried)
print("\nEntries with 'vara' (if exists):", cs.get_entries_with_classifications(['vara']))
print("\nClassifications for entry 0:", cs.get_classifications_for_entry(0))
print("\nWeight of 'juizado_beta' (if exists):", cs.get_weight('juizado_beta'))
print("\nEntries with 'sao paulo' (if exists):", cs.get_entries_with_classifications(['sao paulo']))

In [None]:
import numpy as np
from collections import defaultdict
from typing import List, Tuple

def hierarchical_sort(entries: List[np.ndarray]) -> List[np.ndarray]:
    """Order entries column by column, largest groups first.

    At each token position the entries are grouped by the token found there.
    Groups are ordered by descending size, ties alphabetically by token with
    'null' ranked as 'zzz'; each group is then ordered the same way on the
    next position. The depth is taken from the token count of the first entry
    of each group.
    """
    def sort_level(level_entries: List[np.ndarray], token_index: int) -> List[np.ndarray]:
        if not level_entries:
            return level_entries
        if token_index >= len(level_entries[0][1]):
            return level_entries

        # Bucket the entries by the token at this position (insertion order kept)
        buckets = {}
        for entry in level_entries:
            buckets.setdefault(entry[1][token_index], []).append(entry)

        def bucket_rank(item):
            token, members = item
            label = token if token != 'null' else 'zzz'
            return (-len(members), label)

        # Descend into every bucket, in ranked order
        ordered = []
        for _, members in sorted(buckets.items(), key=bucket_rank):
            ordered.extend(sort_level(members, token_index + 1))
        return ordered

    return sort_level(entries, 0)

# 'entries' is expected to be the list of [index, tokens] entry arrays
sorted_entries = hierarchical_sort(entries)

# Print every sorted entry to verify the ordering
for entry in sorted_entries:
    entry_index, entry_tokens = entry[0], entry[1]
    print(f"Index: {entry_index}, Tokens: {entry_tokens}")

In [None]:
def create_token_relationship_map(entries: List[np.ndarray]) -> pd.DataFrame:
    """Build a row-normalised token co-occurrence matrix.

    Two tokens co-occur when they appear in the same entry ('null' padding is
    ignored). Cell ``[t1, t2]`` holds the share of ``t1``'s co-occurrences
    that involve ``t2``; the diagonal is fixed at 1.0 and tokens that never
    co-occur get 0.0.

    Args:
        entries: ``[index, tokens]`` pairs; only ``entry[1]`` is read.

    Returns:
        Square float64 DataFrame indexed and labelled by the sorted tokens.
    """
    all_tokens = sorted({token for entry in entries for token in entry[1] if token != 'null'})
    position = {token: i for i, token in enumerate(all_tokens)}
    size = len(all_tokens)

    # Accumulate pair counts in a plain float matrix instead of writing an
    # object-dtype DataFrame one cell at a time.
    counts = np.zeros((size, size), dtype=np.float64)
    for entry in entries:
        present = [position[token] for token in entry[1] if token != 'null']
        for offset, first in enumerate(present):
            for second in present[offset + 1:]:
                counts[first, second] += 1
                counts[second, first] += 1

    # Divide every row by its total; rows without co-occurrences stay at 0.
    totals = counts.sum(axis=1, keepdims=True)
    shares = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    np.fill_diagonal(shares, 1.0)

    return pd.DataFrame(shares, index=all_tokens, columns=all_tokens)

In [None]:
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from typing import List, Dict, Tuple
from collections import defaultdict
import colorspacious
import matplotlib.pyplot as plt
import warnings
from scipy.spatial.distance import cdist
from scipy.spatial import distance_matrix
from scipy.spatial import ConvexHull
import colorsys

def create_token_relationship_map(entries: List[np.ndarray]) -> pd.DataFrame:
    """Build a row-normalised token co-occurrence matrix.

    Two tokens co-occur when they appear in the same entry ('null' padding is
    ignored). Cell ``[t1, t2]`` holds the share of ``t1``'s co-occurrences
    that involve ``t2``; the diagonal is fixed at 1.0 and tokens that never
    co-occur get 0.0.

    Args:
        entries: ``[index, tokens]`` pairs; only ``entry[1]`` is read.

    Returns:
        Square float64 DataFrame indexed and labelled by the sorted tokens.
    """
    all_tokens = sorted({token for entry in entries for token in entry[1] if token != 'null'})
    position = {token: i for i, token in enumerate(all_tokens)}
    size = len(all_tokens)

    # Accumulate pair counts in a plain float matrix instead of writing an
    # object-dtype DataFrame one cell at a time.
    counts = np.zeros((size, size), dtype=np.float64)
    for entry in entries:
        present = [position[token] for token in entry[1] if token != 'null']
        for offset, first in enumerate(present):
            for second in present[offset + 1:]:
                counts[first, second] += 1
                counts[second, first] += 1

    # Divide every row by its total; rows without co-occurrences stay at 0.
    totals = counts.sum(axis=1, keepdims=True)
    shares = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    np.fill_diagonal(shares, 1.0)

    return pd.DataFrame(shares, index=all_tokens, columns=all_tokens)

# Color conversion functions
def lab_to_rgb(lab_color):
    """Convert a colour to sRGB with every channel clipped to [0, 1].

    NOTE(review): despite the function name, the source space handed to
    colorspacious is "CAM02-UCS" rather than CIELAB - confirm which one the
    callers actually intend.
    """
    converted = colorspacious.cspace_convert(lab_color, "CAM02-UCS", "sRGB1")
    lowest, highest = 0, 1
    return np.clip(converted, lowest, highest)


def rgb_to_hex(rgb_color):
    """Format an RGB triple with channels in [0, 1] as a '#rrggbb' string.

    Each channel is scaled by 255 and truncated with ``int``.
    """
    red = int(rgb_color[0] * 255)
    green = int(rgb_color[1] * 255)
    blue = int(rgb_color[2] * 255)
    return f'#{red:02x}{green:02x}{blue:02x}'

def rgb_to_hsv(rgb_color):
    """Convert an (r, g, b) triple to (h, s, v) using ``colorsys.rgb_to_hsv``."""
    hsv_color = colorsys.rgb_to_hsv(*rgb_color)
    return hsv_color

def hsv_to_rgb(hsv_color):
    """Convert an (h, s, v) triple to (r, g, b) using ``colorsys.hsv_to_rgb``."""
    rgb_color = colorsys.hsv_to_rgb(*hsv_color)
    return rgb_color

def normalize_hsv_colors(hsv_colors: List[Tuple[float, float, float]]) -> List[Tuple[float, float, float]]:
    """Stretch saturation and value of a colour list to the range [0.2, 1.0].

    Hue is left untouched. S and V are each rescaled linearly so the smallest
    value maps to 0.2 and the largest to 1.0; a channel whose values are all
    equal is returned unchanged.

    Args:
        hsv_colors: sequence of (h, s, v) triples.

    Returns:
        List of rescaled (h, s, v) triples; empty for empty input.
    """
    if len(hsv_colors) == 0:
        # zip(*[]) yields nothing to unpack, so handle this case up front.
        return []

    _, s_values, v_values = zip(*hsv_colors)
    s_min, s_max = min(s_values), max(s_values)
    v_min, v_max = min(v_values), max(v_values)

    def rescale(value, low, high):
        # Keep some spread: map [low, high] onto [0.2, 1.0]
        if high > low:
            return 0.2 + 0.8 * (value - low) / (high - low)
        return value

    return [
        (h, rescale(s, s_min, s_max), rescale(v, v_min, v_max))
        for h, s, v in hsv_colors
    ]

def reduce_dimensions(relationship_matrix: np.ndarray) -> np.ndarray:
    """Embed the rows of ``relationship_matrix`` in 2-D with t-SNE (random_state=42)."""
    projector = TSNE(n_components=2, random_state=42)
    embedding = projector.fit_transform(relationship_matrix)
    return embedding

def cluster_tokens(reduced_data: np.ndarray, n_clusters: int = 4) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster the points with K-Means (random_state=42).

    Returns:
        ``(labels, centers)``: the cluster label of every row and the fitted
        cluster centres.
    """
    model = KMeans(n_clusters=n_clusters, random_state=42)
    labels = model.fit_predict(reduced_data)
    return labels, model.cluster_centers_

def calculate_cluster_proximities(reduced_data: np.ndarray, cluster_centers: np.ndarray) -> np.ndarray:
    """Return point-to-centre proximities scaled by the largest distance.

    The result has one row per point and one column per centre; a value of 1
    means the point sits on the centre and 0 marks the largest distance found
    in the whole matrix.
    """
    point_to_center = cdist(reduced_data, cluster_centers)
    scaled = point_to_center / point_to_center.max()
    return 1 - scaled


def calculate_saturation_factor(distances, cluster_labels, i):
    """Return the saturation multiplier for point ``i``.

    "Decisiveness" measures how much closer the point is to its own cluster
    centre than to the next closest one (1 = on its centre, 0 = tie).

    * decisiveness > 0.5: the factor falls from 1.0 (at the centre) to 0.5
      (at the point farthest from that centre).
    * otherwise (boundary points): the factor rises from 1.0 up to 1.25.

    Args:
        distances: 2-D array, ``distances[point, cluster]``.
        cluster_labels: assigned cluster index per point.
        i: row of the point to evaluate.

    Returns:
        Saturation factor as a float.
    """
    assigned_cluster = cluster_labels[i]
    assigned_distance = distances[i, assigned_cluster]

    # Distances to every cluster except the assigned one
    other_distances = np.delete(distances[i], assigned_cluster)

    if other_distances.size == 0:
        # A single cluster leaves no alternative: the assignment is certain.
        decisiveness = 1.0
    else:
        next_closest_distance = other_distances.min()
        if next_closest_distance > 0:
            decisiveness = (next_closest_distance - assigned_distance) / next_closest_distance
        else:
            # The point lies on another centre as well: treat as a tie.
            decisiveness = 0.0

    if decisiveness > 0.5:
        # Firmly inside a cluster: fade with the distance from its centre.
        farthest = np.max(distances[:, assigned_cluster])
        if farthest == 0:
            return 1.0
        return 1.0 - (assigned_distance / farthest) * 0.5

    # Boundary point: boost saturation.
    return 1.0 + (0.5 - decisiveness) * 0.5

def assign_token_colors(
    entries: List[np.ndarray], n_clusters: int = 4
) -> Tuple[Dict[str, Tuple[float, float, float]], np.ndarray, np.ndarray]:
    """Assign every token a colour blended from its cluster surroundings.

    Tokens are embedded in 2-D (``create_token_relationship_map`` followed by
    ``reduce_dimensions``) and clustered with K-Means. Each cluster owns a
    base colour; a token's colour is the average of the base colours weighted
    by the inverse squared distance to each cluster centre, with its two
    chromatic components scaled by ``calculate_saturation_factor``.

    Args:
        entries: ``[index, tokens]`` entries.
        n_clusters: number of K-Means clusters; at most 4, the number of base
            colours defined below.

    Returns:
        ``(token_colors, reduced_data, cluster_labels)`` where
        ``token_colors`` maps each token (plus 'null' -> white) to a
        3-component colour tuple.

    Raises:
        ValueError: if ``n_clusters`` exceeds the number of base colours.
    """
    # Base colours (lightness, green-red, blue-yellow components)
    base_colors = np.array([
        [90.0, -150.0, 0.0],    # Green
        [90.0, 150.0, 0.0],     # Red
        [90.0, 0.0, -150.0],    # Blue
        [90.0, 0.0, 150.0]      # Yellow
    ], dtype=np.float64)
    if n_clusters > len(base_colors):
        raise ValueError(
            f"n_clusters={n_clusters} exceeds the {len(base_colors)} available base colors"
        )

    relationship_matrix = create_token_relationship_map(entries)
    reduced_data = reduce_dimensions(relationship_matrix)

    # Perform clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(reduced_data)
    cluster_centers = kmeans.cluster_centers_

    tokens = relationship_matrix.index.tolist()

    # Distances to cluster centres, floored so a token sitting exactly on a
    # centre does not divide by zero.
    distances = distance_matrix(reduced_data, cluster_centers)
    influences = 1.0 / np.maximum(distances, 1e-12) ** 2

    # Weighted average of the base colours. Dividing by the summed influence
    # (rather than by the vector norm) keeps the result on the scale of the
    # base colours; unit-norm vectors would have a lightness of at most 1.
    weights = influences / influences.sum(axis=1, keepdims=True)
    blended = weights @ base_colors[:n_clusters]

    token_colors = {}
    for i, token in enumerate(tokens):
        color = blended[i].copy()
        saturation_factor = calculate_saturation_factor(distances, cluster_labels, i)
        color[1:] *= saturation_factor  # adjust the two chromatic components
        token_colors[token] = tuple(color)

    # Assign white color to 'null' token
    token_colors['null'] = (100.0, 0.0, 0.0)

    return token_colors, reduced_data, cluster_labels

def blend_entry_colors(entry: np.ndarray, token_colors: Dict[str, Tuple[float, float, float]]) -> Tuple[float, float, float]:
    """Blend the colours of an entry's tokens into one colour.

    Tokens without a colour, and tokens coloured like 'null', are ignored.
    The remaining colours are averaged with weights falling linearly from 1
    to 0.5 by position. The first component is then darkened by 5% per token
    (capped at 30%) and the other two are scaled up by 5% per token.

    Returns:
        ``(100, 0, 0)`` when nothing is left to blend, otherwise the blended
        3-component colour.
    """
    null_color = token_colors['null']
    non_null_colors = [
        color
        for color in (token_colors.get(token, null_color) for token in entry[1])
        if color != null_color
    ]
    if not non_null_colors:
        return (100, 0, 0)  # White for empty entries or entries with only null tokens

    # Position-based weights, normalised to sum to 1
    num_tokens = len(non_null_colors)
    weights = np.linspace(1, 0.5, num_tokens)
    weights = weights / np.sum(weights)

    def weighted_component(position):
        return np.average([color[position] for color in non_null_colors], weights=weights)

    l_blend = weighted_component(0)
    a_blend = weighted_component(1)
    b_blend = weighted_component(2)

    # More tokens -> darker (max 30%) and more saturated
    darkness_factor = min(num_tokens * 0.05, 0.3)
    saturation_factor = 1 + (num_tokens * 0.05)

    return (
        l_blend * (1 - darkness_factor),
        a_blend * saturation_factor,
        b_blend * saturation_factor,
    )

def visualize_token_clusters(reduced_data: np.ndarray, clusters: np.ndarray, tokens: List[str]):
    """Scatter the 2-D token embedding coloured by cluster and label a sample.

    Roughly every tenth part of the points is annotated with its token so the
    plot stays readable.
    """
    plt.figure(figsize=(12, 8))
    xs = reduced_data[:, 0]
    ys = reduced_data[:, 1]
    scatter = plt.scatter(xs, ys, c=clusters, cmap='viridis')
    plt.colorbar(scatter)

    # Never annotate beyond the shorter of the two inputs
    num_points = min(len(reduced_data), len(tokens))
    step = max(1, num_points // 10)  # about 10 annotations
    for position in range(0, num_points, step):
        plt.annotate(tokens[position], (xs[position], ys[position]))

    plt.title('Token Clusters in 2D Space')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.tight_layout()
    plt.show()

def visualize_token_colors(reduced_data: np.ndarray, token_colors: Dict[str, Tuple[float, float, float]], cluster_labels: np.ndarray):
    """Scatter the token embedding using each token's own colour.

    Row ``k`` of ``reduced_data`` / ``cluster_labels`` is matched with the
    ``k``-th key of ``token_colors``; the 'null' entry is skipped. Convex hull
    outlines are drawn for clusters 0-3 that have at least three points.
    """
    plt.figure(figsize=(12, 10))

    # Positions of the real tokens (everything except 'null')
    token_positions = [pos for pos, token in enumerate(token_colors) if token != 'null']

    # Convert the stored colours to RGB for plotting
    rgb_colors = [
        lab_to_rgb(np.array(color, dtype=np.float64))
        for token, color in token_colors.items()
        if token != 'null'
    ]
    valid_data = reduced_data[token_positions]
    valid_labels = cluster_labels[token_positions]

    scatter = plt.scatter(valid_data[:, 0], valid_data[:, 1], c=rgb_colors, s=50, alpha=0.7)

    # Titles and axis labels
    plt.title('Token Color Distribution in Clusters', fontsize=16)
    plt.xlabel('Dimension 1', fontsize=14)
    plt.ylabel('Dimension 2', fontsize=14)

    # Hide the top and right spines
    for side in ('top', 'right'):
        plt.gca().spines[side].set_visible(False)

    # Light grid and tick label size
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tick_params(axis='both', which='major', labelsize=12)

    # Colorbar
    cbar = plt.colorbar(scatter, label='Color', aspect=30)
    cbar.ax.tick_params(labelsize=10)

    # Cluster outlines
    for cluster in range(4):
        cluster_points = valid_data[valid_labels == cluster]
        if len(cluster_points) <= 2:
            continue  # a convex hull needs at least 3 points
        hull = ConvexHull(cluster_points)
        for simplex in hull.simplices:
            plt.plot(cluster_points[simplex, 0], cluster_points[simplex, 1], 'k-', alpha=0.3)

    plt.tight_layout()
    plt.show()

def visualize_color_distribution(token_colors: Dict[str, Tuple[float, float, float]]):
    """Show all token colors on two side-by-side scatter plots.

    The color values are stacked into an array whose columns 0, 1 and 2 are
    plotted as L*, a* and b*. The left panel plots a* against b*, the right
    panel plots a* against L*. Each point is drawn in the color returned by
    ``lab_to_rgb`` for that token's value.

    Args:
        token_colors: Mapping of token to its three-component color value.
    """
    lab_values = np.array(list(token_colors.values()))
    point_colors = np.array([lab_to_rgb(lab) for lab in lab_values])

    fig = plt.figure(figsize=(12, 6))
    left_ax = fig.add_subplot(121)
    right_ax = fig.add_subplot(122)

    # (axes, y column, y label, title, y limits, colorbar label)
    # NOTE(review): points are colored with explicit RGB values, so the
    # colorbars do not appear to encode the L*/b* values they are labeled
    # with -- confirm.
    panels = (
        (left_ax, 2, 'b* (Blue-Yellow)', 'Token Color Distribution in a*b* plane', (-128, 128), 'L* value'),
        (right_ax, 0, 'L* (Lightness)', 'Token Color Distribution in L*a* plane', (0, 100), 'b* value'),
    )
    for ax, y_column, y_label, title, y_limits, bar_label in panels:
        points = ax.scatter(lab_values[:, 1], lab_values[:, y_column], c=point_colors)
        ax.set_xlabel('a* (Green-Red)')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_xlim(-128, 128)
        ax.set_ylim(*y_limits)
        fig.colorbar(points, ax=ax, label=bar_label)

    plt.tight_layout()
    plt.show()

def visualize_entry_colors(entries: List[np.ndarray], token_colors: Dict[str, Tuple[float, float, float]]):
    """Plot one point per entry in a 2D t-SNE embedding, colored by its blended color.

    Args:
        entries: Entry arrays; element ``[1]`` of each is compared against
            ``'null'`` to build the feature vector.
        token_colors: Mapping of token to color, forwarded to
            ``blend_entry_colors``; its size sets the minimum feature length.
    """
    # One blended color per entry, then converted for plotting.
    blended = [blend_entry_colors(item, token_colors) for item in entries]
    point_colors = list(map(lab_to_rgb, blended))

    # Feature vector per entry.
    # NOTE(review): bincount over a boolean mask only counts the False/True
    # occurrences (null vs. non-null), padded to len(token_colors). This
    # presumably was meant to be a per-token count -- confirm intent.
    feature_width = len(token_colors)
    feature_rows = []
    for item in entries:
        not_null = item[1] != 'null'
        feature_rows.append(np.bincount(not_null, minlength=feature_width))
    features = np.array(feature_rows)

    # Reduce the features to two dimensions for plotting.
    embedding = TSNE(n_components=2, random_state=42).fit_transform(features)

    plt.figure(figsize=(12, 8))
    points = plt.scatter(embedding[:, 0], embedding[:, 1], c=point_colors)

    plt.title('Entry Colors in 2D Space')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.colorbar(points, label='Color')
    plt.tight_layout()
    plt.show()


def visualize_color_mosaic(sorted_entries: List[Tuple[np.ndarray, Tuple[float, float, float], str]], output_file: str = 'color_mosaic.png'):
    """Draw the entries as a square grid of colored unit cells and save the image.

    Cells are laid out row by row, left to right, starting at the top row.
    Only the third element of each tuple (the cell's face color) is used.

    Args:
        sorted_entries: Tuples whose third element is the color for that cell.
        output_file: Path the figure is written to.
    """
    # Smallest square grid that holds every entry.
    side = int(np.ceil(np.sqrt(len(sorted_entries))))

    fig, ax = plt.subplots(figsize=(20, 20))
    for position, (_, _, cell_color) in enumerate(sorted_entries):
        row, col = divmod(position, side)
        # Flip the row so that position 0 ends up in the top-left corner.
        cell = plt.Rectangle((col, side - row - 1), 1, 1, facecolor=cell_color)
        ax.add_patch(cell)

    ax.set_xlim(0, side)
    ax.set_ylim(0, side)
    ax.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Color mosaic saved as {output_file}")

def sort_entries_by_color(entries: List[np.ndarray], token_colors: Dict[str, Tuple[float, float, float]]) -> List[Tuple[np.ndarray, Tuple[float, float, float], str]]:
    """Order entries by their normalized HSV color.

    Each entry's color comes from ``blend_entry_colors``, is passed through
    ``lab_to_rgb`` and ``rgb_to_hsv``, and the HSV values of all entries are
    then normalized together by ``normalize_hsv_colors``. Entries whose
    normalized HSV contains NaN are reported and dropped.

    Args:
        entries: Entry arrays to sort.
        token_colors: Mapping of token to color, forwarded to ``blend_entry_colors``.

    Returns:
        ``(entry, rgb, hex_color)`` tuples, where ``rgb`` is the result of
        ``hsv_to_rgb`` on the normalized HSV and ``hex_color`` the result of
        ``rgb_to_hex`` on that. They are sorted by the normalized HSV
        components in order. An empty list is returned for empty input.
    """
    hsv_per_entry = []
    for entry in entries:
        lab_color = blend_entry_colors(entry, token_colors)
        rgb_color = lab_to_rgb(lab_color)
        hsv_per_entry.append((entry, rgb_to_hsv(rgb_color)))

    # Nothing to normalize or sort; also avoids handing an empty batch to
    # the normalizer.
    if not hsv_per_entry:
        return []

    # Normalize HSV colors across all entries at once
    normalized_hsv_colors = normalize_hsv_colors(tuple(hsv for _, hsv in hsv_per_entry))

    keyed_entries = []
    for (entry, _), normalized_hsv in zip(hsv_per_entry, normalized_hsv_colors):
        if any(np.isnan(normalized_hsv)):
            print(f"Found NaN in HSV values for entry: {entry}")
            continue  # Skip this entry if NaN is found

        normalized_rgb = hsv_to_rgb(normalized_hsv)
        hex_color = rgb_to_hex(normalized_rgb)
        # A plain tuple key compares lexicographically whether the HSV value
        # is a tuple, a list or a numpy array (arrays are not orderable).
        keyed_entries.append((tuple(normalized_hsv), entry, normalized_rgb, hex_color))

    # Sort by hue, then saturation, then value; only the key is compared, so
    # ties keep their input order.
    keyed_entries.sort(key=lambda item: item[0])

    # Remove the HSV key from the returned list
    return [(entry, rgb, hex_color) for _, entry, rgb, hex_color in keyed_entries]


# Driver: assign token colors, sort entries by color, then report and plot.
# Expects 'entries' (the list of processed entry arrays) to already be defined.
token_colors, reduced_data, cluster_labels = assign_token_colors(entries)

# Sort entries by color (the mosaic itself is rendered further down)
sorted_entries = sort_entries_by_color(entries, token_colors)

# Color distribution analysis: count entries per hue bucket
print("\nColor Distribution Analysis:")
color_counts = defaultdict(int)
for _, rgb_color, _ in sorted_entries:
    hsv_color = rgb_to_hsv(rgb_color)
    # Four buckets when hue is in [0, 1); a hue of exactly 1.0 would land in
    # a fifth bucket (4) -- TODO confirm the range returned by rgb_to_hsv.
    hue_category = int(hsv_color[0] * 4)  # Divide hue into 4 categories
    color_counts[hue_category] += 1

for category, count in color_counts.items():
    print(f"Color category {category}: {count} entries")

# Consistency check before plotting: report tokens that have no assigned
# color. 'tokens' is expected to be defined earlier in the file.
mismatched_tokens = set(tokens) - set(token_colors.keys())
if mismatched_tokens:
    print(f"Tokens in 'tokens' but not in 'token_colors': {mismatched_tokens}")

# NOTE: the visualizations run unconditionally; the mismatch check above
# only prints a report and does not gate them.
visualize_token_colors(reduced_data, token_colors, cluster_labels)
visualize_entry_colors(entries, token_colors)
visualize_color_mosaic(sorted_entries, 'normalized_entry_color_mosaic_prototype.png')

# Print example results (with HSV recomputed from the stored RGB)
print("\nExample Sorted Entries (first 10):")
for entry, rgb_color, hex_color in sorted_entries[:10]:
    hsv_color = rgb_to_hsv(rgb_color)
    print(f"Entry: {entry[1]}, RGB: {rgb_color}, Hex: {hex_color}, HSV: {hsv_color}")

# Print example Token Colors
print("\nExample Token Colors (in CIELAB space):")
for token, color in list(token_colors.items())[:10]:
    print(f"{token}: LAB{color}")

print("\nExample Token Colors (in sRGB space):")
for token, lab_color in list(token_colors.items())[:10]:
    rgb_color = lab_to_rgb(lab_color)
    print(f"{token}: RGB{rgb_color}")

# Same first-10 listing as printed above, this time without the HSV column.
print("\nExample Sorted Entries (first 10):")
for entry, rgb_color, hex_color in sorted_entries[:10]:
    print(f"Entry: {entry[1]}, RGB: {rgb_color}, Hex: {hex_color}")