In [9]:
import pandas as pd
import os
from datetime import datetime, timedelta
import glob
import pandas as pd
import random
import numpy as np
import seaborn as sns
import ast
import json
import os
import re
import itertools
import folium
import hdbscan
import h3

from scipy.spatial.transform import Rotation as R
from scipy.spatial.distance import euclidean
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import matplotlib.patches as patches
from itertools import combinations
from collections import defaultdict
from collections import Counter
from typing import Dict, Set, List, Optional, Tuple

# Data Preparation

In [2]:
df_gps = pd.read_csv("data/processed_trips.csv")

# Keep only individuals flagged with GPS_RECORD == True
# (explicit comparison kept on purpose; presumably a boolean column — verify dtype).
df_people = pd.read_csv("data/individuals_dataset.csv")
df_people = df_people[df_people['GPS_RECORD'] == True]

# Attach each individual's survey weight to their trips (inner join: trips whose
# ID has no matching GPS individual are dropped), then give the coordinate
# columns start/end names.
df_merged = (
    df_gps
    .merge(
        df_people[['ID', 'WEIGHT_INDIV']],
        how='inner',
        left_on='ID',
        right_on='ID',
    )
    .rename(columns={
        "ori_lat": "start_lat",
        "ori_lon": "start_lon",
        "dst_lat": 'end_lat',
        "dst_lon": 'end_lon',
    })
)

In [3]:
h3_resolution = 10
# Index both trip endpoints on the H3 grid at h3_resolution. Comprehensions over
# the zipped coordinate columns produce the same cell ids as a row-wise apply.
df_merged['start_h3'] = [
    h3.latlng_to_cell(lat, lon, h3_resolution)
    for lat, lon in zip(df_merged['start_lat'], df_merged['start_lon'])
]
df_merged['end_h3'] = [
    h3.latlng_to_cell(lat, lon, h3_resolution)
    for lat, lon in zip(df_merged['end_lat'], df_merged['end_lon'])
]

In [4]:
# One row per (origin cell, destination cell): summed survey weight and number of trips.
od_matrix_first = (
    df_merged
    .groupby(['start_h3', 'end_h3'])
    .agg(
        total_weight=('WEIGHT_INDIV', 'sum'),
        count=('WEIGHT_INDIV', 'count'),
    )
    .reset_index()
)

In [5]:
# Keep only OD pairs whose origin AND destination both lie in central Paris.
#
# Candidate parent cells considered so far:
#   single hexagon for Paris (res 5): "851fb467fffffff"
#   4 res-6 children covering central Paris: "861fb4667ffffff", "861fb4677ffffff", "861fb466fffffff", "861fb4647ffffff"
#   res 7: "871fb4674ffffff", "871fb475bffffff", "871fb4675ffffff", "871fb4666ffffff", "871fb4662ffffff", "871fb4660ffffff"
#   res 6 (wider): "861fb4667ffffff", "861fb4677ffffff", "861fb466fffffff", "861fb4647ffffff", "861fb4297ffffff", "861fb474fffffff", "861fb475fffffff", "861fb462fffffff"
parent_hexes = ["861fb4667ffffff", "861fb4677ffffff", "861fb466fffffff", "861fb4647ffffff", "861fb475fffffff"]

# Expand every parent into all of its descendants at the analysis resolution
# (plain children are enough; no compaction is needed).
target_resolution = 10
start_valid_h3 = {
    child
    for parent in parent_hexes
    for child in h3.cell_to_children(parent, target_resolution)
}
# Origins and destinations are restricted to the same area.
end_valid_h3 = set(start_valid_h3)

mask = od_matrix_first["start_h3"].isin(start_valid_h3) & od_matrix_first["end_h3"].isin(end_valid_h3)
od_matrix_first = od_matrix_first[mask].copy()
print(f"Number of filtered rows: {len(od_matrix_first):,}")

Number of filtered rows: 23,264


In [6]:
# Independent (deep) copy so later stages cannot mutate od_matrix_first.
od_matrix = od_matrix_first.copy(deep=True)

In [7]:
# Trip-level rows restricted to the OD pairs retained in od_matrix. The pairs are
# unique there (they come from a groupby on the same keys), so the inner join
# acts as a pure filter.
filtered_df = pd.merge(
    df_merged,
    od_matrix[['start_h3', 'end_h3']],
    how='inner',
    on=['start_h3', 'end_h3'],
)

In [10]:
class H3TreeNode:
    """
    A single H3 cell inside an H3 hierarchy tree.

    Holds running totals (``total_weight`` and ``count``). Incrementing them on a
    node also increments them on every ancestor up to the root.
    """
    def __init__(self, h3_id: str, resolution: int, total_weight: int = 0, count: int = 0):
        self.h3_id = h3_id
        self.resolution = resolution
        self.total_weight = total_weight
        self.count = count
        # Children keyed by their H3 id; parent stays None for the root.
        self.children: Dict[str, 'H3TreeNode'] = {}
        self.parent: Optional['H3TreeNode'] = None

    def add_child(self, child: 'H3TreeNode'):
        """Register ``child`` under this node and point its parent link back here."""
        child.parent = self
        self.children[child.h3_id] = child

    def _lineage(self):
        """Yield this node, then each ancestor in turn up to the root."""
        current = self
        while current is not None:
            yield current
            current = current.parent

    def add_weight(self, weight: int):
        """Add ``weight`` to this node's total_weight and to all ancestors'."""
        for node in self._lineage():
            node.total_weight += weight

    def add_count(self, count: int):
        """Add ``count`` to this node's count and to all ancestors'."""
        for node in self._lineage():
            node.count += count

    def __repr__(self):
        return (f"H3Node(id={self.h3_id}, res={self.resolution}, "
                f"total_weight={self.total_weight}, count={self.count}, children={len(self.children)})")

In [None]:
class H3HierarchicalTree:
    """
    Tree of H3 cells built over one column ('start_h3' or 'end_h3') of an OD matrix.

    Leaves are at ``target_resolution``. The root is at the finest resolution
    that still has a single common ancestor for all leaves. After
    ``populate_counts`` every node carries the summed 'total_weight' and 'count'
    of the OD rows mapped to it or to any of its descendants.
    """
    def __init__(self, od_matrix: pd.DataFrame, target_resolution: int = 11, hex_column: str = 'start_h3'):
        self.od_matrix = od_matrix
        self.target_resolution = target_resolution
        self.hex_column = hex_column  # 'start_h3' or 'end_h3'
        self.nodes: Dict[str, H3TreeNode] = {}
        self.root = None
        self.min_resolution = None  # computed by create_tree_structure()

    def get_all_hexagons(self) -> Set[str]:
        """Extracts all unique hexagons from the specified column"""
        return set(self.od_matrix[self.hex_column].unique())

    def get_resolution_coverage(self, hexagons: Set[str], target_res: int) -> Set[str]:
        """
        Obtains all target resolution hexagons covering the area defined by the input hexagons.

        Coarser cells are expanded to all their descendants at ``target_res``;
        finer cells are replaced by their ancestor at ``target_res``.
        """
        coverage_hexagons = set()

        for hex_id in hexagons:
            current_res = h3.get_resolution(hex_id)

            if current_res == target_res:
                coverage_hexagons.add(hex_id)
            elif current_res < target_res:
                # Coarser than the target: expand to the finer resolution
                coverage_hexagons.update(self._get_all_children_at_resolution(hex_id, target_res))
            else:
                # Finer than the target: climb to the coarser resolution
                coverage_hexagons.add(h3.cell_to_parent(hex_id, target_res))

        return coverage_hexagons

    def find_optimal_min_resolution(self, hexagons: Set[str]) -> int:
        """
        Find the highest resolution where there is still only one node covering all hexagons.

        Prints the number of distinct ancestors needed at each resolution
        from 0 to ``target_resolution``. Returns 0 if no resolution has a single
        covering node.
        """
        resolution_stats = {}

        for resolution in range(0, self.target_resolution + 1):
            ancestors = set()
            for hex_id in hexagons:
                current_res = h3.get_resolution(hex_id)
                if current_res >= resolution:
                    ancestors.add(h3.cell_to_parent(hex_id, resolution))
                else:
                    # Already coarser than this resolution: counts as its own ancestor
                    ancestors.add(hex_id)

            resolution_stats[resolution] = len(ancestors)
            print(f"Resolution {resolution}: {len(ancestors)} nodes")

        # Highest resolution that still needs exactly one node
        for resolution in range(self.target_resolution, -1, -1):
            if resolution_stats[resolution] == 1:
                return resolution
        return 0

    def get_siblings(self, node_id: str) -> List[str]:
        """Returns the h3_ids of the sibling nodes of node_id (other hexagons within the same parent)"""
        node = self.nodes.get(node_id)
        if node is None or node.parent is None:
            return []
        return [child.h3_id for child in node.parent.children.values() if child.h3_id != node_id]

    def get_parent(self, node_id: str) -> Optional[str]:
        """Returns the h3_id of the parent node, or None for root node (or unknown node_id)."""
        node = self.nodes.get(node_id)
        if node is None or node.parent is None:
            return None
        return node.parent.h3_id

    def _get_all_children_at_resolution(self, hex_id: str, target_res: int) -> Set[str]:
        """
        All descendants of ``hex_id`` at ``target_res``.

        Returns {hex_id} when it is already at ``target_res`` and an empty set
        when it is finer than ``target_res``.
        """
        current_res = h3.get_resolution(hex_id)
        if current_res == target_res:
            return {hex_id}
        if current_res > target_res:
            return set()
        # h3 enumerates descendants at any finer resolution directly;
        # no need to recurse one level at a time.
        return set(h3.cell_to_children(hex_id, target_res))

    def build_hierarchy_path(self, hex_id: str, min_resolution: int) -> List[str]:
        """Builds the hierarchical path from a hexagon up to ``min_resolution`` (leaf first, ancestor last)"""
        path = [hex_id]
        current = hex_id
        current_res = h3.get_resolution(current)

        while current_res > min_resolution:
            current = h3.cell_to_parent(current, current_res - 1)
            path.append(current)
            current_res -= 1

        return path

    def create_tree_structure(self):
        """
        Create the tree: nodes for every leaf and all of its ancestors down to
        the optimal minimum resolution, linked parent -> child.

        Raises:
            ValueError: if the hex column contains no hexagons.
        """
        target_hexagons = self.get_all_hexagons()
        if not target_hexagons:
            raise ValueError(f"No hexagons found in column '{self.hex_column}'; cannot build a tree")

        # Full coverage at the target resolution
        coverage_hexagons = self.get_resolution_coverage(target_hexagons, self.target_resolution)

        # Highest resolution with a single covering node becomes the root level
        self.min_resolution = self.find_optimal_min_resolution(coverage_hexagons)

        for hex_id in coverage_hexagons:
            path = self.build_hierarchy_path(hex_id, self.min_resolution)
            for cell in path:
                if cell not in self.nodes:
                    self.nodes[cell] = H3TreeNode(cell, h3.get_resolution(cell))
            # Consecutive path entries are (child, parent) pairs
            for child_id, parent_id in zip(path, path[1:]):
                self.nodes[parent_id].add_child(self.nodes[child_id])

        self.root = self.nodes[self._find_root_hexagon(coverage_hexagons, self.min_resolution)]

        return self

    def _find_root_hexagon(self, hexagons: Set[str], min_resolution: int) -> str:
        """Ancestor at ``min_resolution`` of any hexagon (all share it by construction)."""
        sample_hex = next(iter(hexagons))
        return h3.cell_to_parent(sample_hex, min_resolution)

    def populate_counts(self):
        """
        Add each hexagon's summed 'total_weight' and 'count' to its leaf node
        (and, through propagation, to all ancestors). Prints a message for
        hexagons whose leaf is not in the tree.
        """
        agg_df = self.od_matrix.groupby(self.hex_column).agg({'total_weight': 'sum', 'count': 'sum'}).reset_index()

        for hex_id, weight, count in zip(agg_df[self.hex_column], agg_df['total_weight'], agg_df['count']):
            # Weights may be fractional (summed survey weights); int() would truncate them.
            weight = float(weight)
            count = int(count)

            target_hex = self._map_to_target_resolution(hex_id)

            if target_hex in self.nodes:
                self.nodes[target_hex].add_weight(weight)
                self.nodes[target_hex].add_count(count)
            else:
                print(f"hexagon {target_hex} not found")

        return self

    def _map_to_target_resolution(self, hex_id: str) -> str:
        """Map a hexagon to the target resolution"""
        current_res = h3.get_resolution(hex_id)

        if current_res == self.target_resolution:
            return hex_id
        elif current_res < self.target_resolution:
            # A coarse cell has no single leaf: attribute it to a deterministic
            # representative (smallest child id) instead of an arbitrary set element.
            children = self._get_all_children_at_resolution(hex_id, self.target_resolution)
            return min(children) if children else hex_id
        else:
            return h3.cell_to_parent(hex_id, self.target_resolution)

    def get_tree_statistics(self) -> Dict:
        """Summary of the built tree ({} if the tree has not been built)."""
        if not self.root:
            return {}

        stats = {
            'total_nodes': len(self.nodes),
            'root_resolution': self.root.resolution,
            'min_resolution': self.min_resolution,
            'target_resolution': self.target_resolution,
            'total_weight': self.root.total_weight,
            'nodes_by_resolution': defaultdict(int),
            'resolution_range': f"{self.min_resolution} → {self.target_resolution}"
        }

        for node in self.nodes.values():
            stats['nodes_by_resolution'][node.resolution] += 1

        return stats

    def print_tree(self, max_children_per_level: int = 10):
        """Print the tree, showing at most ``max_children_per_level`` children per node."""
        if not self.root:
            print("Tree not built")
            return

        print(f"OPTIMIZED TREE STRUCTURE: resolution range {self.min_resolution} → {self.target_resolution}")

        def _print_node(node: H3TreeNode, depth: int = 0, is_last: bool = True, prefix: str = ""):
            connector = "└─ " if is_last else "├─ "
            print(f"{prefix}{connector}{node.h3_id} (res:{node.resolution}, total_weight:{node.total_weight}, count:{node.count}, children:{len(node.children)})")

            child_prefix = prefix + ("   " if is_last else "│  ")
            children_list = list(node.children.values())

            if len(children_list) <= max_children_per_level:
                for i, child in enumerate(children_list):
                    _print_node(child, depth + 1, i == len(children_list) - 1, child_prefix)
            else:
                # Truncated listing: none of the shown children is last, a summary line follows
                for child in children_list[:max_children_per_level]:
                    _print_node(child, depth + 1, False, child_prefix)

                remaining = len(children_list) - max_children_per_level
                print(f"{child_prefix}└─ ... and other {remaining} children with the same pattern")

        _print_node(self.root, 0, True, "")

def create_h3_hierarchical_tree(od_matrix_df: pd.DataFrame, target_resolution: int = 10, hex_column: str = 'start_h3'):
    """
    Build, populate and report on an H3 hierarchy for one side of an OD matrix.

    Args:
        od_matrix_df: OD DataFrame with columns 'start_h3', 'end_h3', 'count'
            and 'total_weight'.
        target_resolution: H3 resolution of the tree leaves.
        hex_column: column to build the tree over ('start_h3' or 'end_h3').

    Returns:
        H3HierarchicalTree: the constructed and populated tree. Statistics and
        the tree layout are printed as a side effect.
    """
    tree = (
        H3HierarchicalTree(od_matrix_df, target_resolution, hex_column)
        .create_tree_structure()
        .populate_counts()
    )

    stats = tree.get_tree_statistics()
    print(f"OPTIMIZED TREE STATISTICS ({hex_column.upper()})")
    print(stats)

    # Resolutions 0..target_resolution could appear; count how many the tree actually uses.
    possible_levels = target_resolution + 1
    used_levels = len(stats['nodes_by_resolution'])

    print(f"OPTIMIZATIONS:")
    print(f"Saved resolutions: {possible_levels - used_levels}")
    print(f"Tree efficiency: {used_levels}/{possible_levels} livelli utilizzati")

    tree.print_tree()
    return tree

In [None]:
# Hierarchy over origin cells, leaves at H3 resolution 10
tree_start = create_h3_hierarchical_tree(od_matrix, hex_column='start_h3', target_resolution=10)

In [10]:
# Hierarchy over destination cells, leaves at H3 resolution 10
tree_end = create_h3_hierarchical_tree(od_matrix, hex_column='end_h3', target_resolution=10)

In [31]:
class Lattice2DCount:
    """
    OIGH-style generalization lattice adapted to H3HierarchicalTree.

    k-anonymity is judged on the 'count' column; 'total_weight' is summed
    alongside it whenever OD cells are merged. Lattice nodes are keyed by
    (origin resolution, destination resolution); the top node uses the finest
    resolution found in each tree.
    """
    def __init__(self, od_matrix, tree_start, tree_end, k, S):
        self.od_matrix = od_matrix.copy()
        self.tree_start = tree_start
        self.tree_end = tree_end
        self.k = k
        self.S = S
        self.total_vol = self.od_matrix['count'].sum()

        # Finest resolution present in each tree = lattice top
        self.L_start = max(tree_node.resolution for tree_node in tree_start.nodes.values())
        self.L_end = max(tree_node.resolution for tree_node in tree_end.nodes.values())

        self.nodes = {}
        self.add_node(self.L_start, self.L_end, parents=[])
        self.max_level_found = np.inf
        self.od_matrix_agg = None
        self.best_avg_class_size = np.inf

    def add_node(self, lvlo, lvld, parents):
        """
        Create node (lvlo, lvld) and, recursively, every coarser node below it.

        If the node already exists, ``parents`` are appended to its parent list
        instead. For a new node, the origin-coarsening child (if any) is
        appended before the destination-coarsening child.
        """
        key = (lvlo, lvld)
        existing = self.nodes.get(key)
        if existing is not None:
            existing.parents += parents
            return

        created = LatticeNodeCount(lvlo, lvld, parents=parents, lattice=self)
        self.nodes[key] = created

        if lvlo > self.tree_start.min_resolution:
            coarser = (lvlo - 1, lvld)
            self.add_node(*coarser, [created])
            created.children.append(self.nodes[coarser])
        if lvld > self.tree_end.min_resolution:
            coarser = (lvlo, lvld - 1)
            self.add_node(*coarser, [created])
            created.children.append(self.nodes[coarser])


class LatticeNodeCount:
    """
    Node (lvlo, lvld) of the generalization lattice built by Lattice2DCount.

    Origins are generalized to H3 resolution ``lvlo`` and destinations to
    ``lvld``. A *higher* resolution means *finer* cells, so the amount of
    generalization at this node is its height below the lattice top:
    ``(L_start - lvlo) + (L_end - lvld)``. The search stores in
    ``lattice.max_level_found`` the smallest height of any anonymous node found
    so far, and in ``lattice.od_matrix_agg`` that node's aggregation.
    """
    def __init__(self, lvlo, lvld, parents, lattice):
        self.lattice = lattice
        self.lvlo = lvlo
        self.lvld = lvld
        self.parents = parents   # nodes one resolution finer on one axis
        self.children = []       # nodes one resolution coarser on one axis
        self.anonymous = None    # None = unknown; True/False once evaluated or inferred
        self.visited = False

    def _generalization_height(self):
        """Resolution steps this node is coarser than (L_start, L_end); 0 at the lattice top."""
        return (self.lattice.L_start - self.lvlo) + (self.lattice.L_end - self.lvld)

    def evaluate(self):
        """
        Depth-first search for the least generalized anonymous node.

        A node is anonymous when the volume in OD cells with count < k is at
        most S. Anonymity is treated as monotone: children of an anonymous node
        are tagged anonymous and parents of a non-anonymous node are tagged
        non-anonymous, so neither is re-aggregated.

        The bound compares generalization heights, not raw resolution sums.
        With resolutions, a larger sum is *less* generalized. Bounding by the
        sum therefore pruned exactly the finer nodes the search should prefer,
        and could return an over-generalized matrix.
        """
        self.visited = True
        height = self._generalization_height()
        # Skip nodes already coarser than the best anonymous node found so far
        if height <= self.lattice.max_level_found:

            if self.children and not self.children[0].visited:
                self.children[0].evaluate()

            if self.anonymous is None:
                od_matrix_agg = self.get_aggregation()

                self.avg_class_size = self.get_mean_agg_level(od_matrix_agg)
                self.suppr_vol = od_matrix_agg[od_matrix_agg['count'] < self.lattice.k]['count'].sum()

                if self.suppr_vol > self.lattice.S:
                    self.tag_unanonymous()
                else:
                    best_height = self.lattice.max_level_found
                    if height == best_height:
                        # Tie on generalization: keep the candidate with the smaller metric.
                        # NOTE(review): get_mean_agg_level returns a mean resolution
                        # (higher = finer); confirm '<' is the intended tie-break direction.
                        if (self.lattice.od_matrix_agg is None) or (self.avg_class_size < self.lattice.best_avg_class_size):
                            self.lattice.od_matrix_agg = od_matrix_agg
                            self.lattice.best_avg_class_size = self.avg_class_size
                    elif height < best_height:
                        # Strictly less generalized than anything found so far
                        self.lattice.max_level_found = height
                        self.lattice.od_matrix_agg = od_matrix_agg
                        self.lattice.best_avg_class_size = self.avg_class_size
                    self.tag_anonymous()

            if len(self.children) > 1 and not self.children[1].visited:
                self.children[1].evaluate()

    def tag_anonymous(self):
        """Mark this node and (recursively) all untagged coarser nodes as anonymous."""
        if self.anonymous is None:
            self.anonymous = True
            for c in self.children:
                c.tag_anonymous()

    def tag_unanonymous(self):
        """Mark this node and (recursively) all untagged finer nodes as not anonymous."""
        if self.anonymous is None:
            self.anonymous = False
            for p in self.parents:
                p.tag_unanonymous()

    def map_to_level(self, h, target_res, tree):
        """
        Walk up ``tree`` from cell ``h`` to its first ancestor at resolution <= target_res.

        Cells not present in the tree are returned unchanged.
        """
        node = tree.nodes.get(h)
        if node is None:
            return h
        while node and node.resolution > target_res:
            node = node.parent
        return node.h3_id if node else h

    def get_aggregation(self):
        """
        Generalize origins to ``lvlo`` and destinations to ``lvld``, then sum
        'count' and 'total_weight' per (start_gen, end_gen) pair.
        """
        df = self.lattice.od_matrix.copy()

        # Map each distinct cell once instead of once per row
        start_map = {h: self.map_to_level(h, self.lvlo, self.lattice.tree_start) for h in df['start_h3'].unique()}
        end_map = {h: self.map_to_level(h, self.lvld, self.lattice.tree_end) for h in df['end_h3'].unique()}
        df['start_gen'] = df['start_h3'].map(start_map)
        df['end_gen'] = df['end_h3'].map(end_map)

        return df.groupby(['start_gen', 'end_gen']).agg({
            'count': 'sum',          # k-anonymity is judged on count
            'total_weight': 'sum'    # carried along with the merge
        }).reset_index()

    def get_mean_agg_level(self, od_matrix_agg):
        """
        Count-weighted mean of (origin resolution + destination resolution)
        over cells with count >= k; np.inf if there are none.

        Adds helper columns res_o, res_d, w_res_o, w_res_d to ``od_matrix_agg`` in place.
        """
        def weighted_res(h3_col, tree):
            return [
                node.resolution if node else tree.min_resolution
                for node in (tree.nodes.get(h) for h in h3_col)
            ]

        od_matrix_agg['res_o'] = weighted_res(od_matrix_agg['start_gen'], self.lattice.tree_start)
        od_matrix_agg['res_d'] = weighted_res(od_matrix_agg['end_gen'], self.lattice.tree_end)

        od_matrix_agg['w_res_o'] = od_matrix_agg['res_o'] * od_matrix_agg['count']
        od_matrix_agg['w_res_d'] = od_matrix_agg['res_d'] * od_matrix_agg['count']

        mean_vals = od_matrix_agg[od_matrix_agg['count'] >= self.lattice.k].agg({
            'w_res_o': 'sum', 'w_res_d': 'sum', 'count': 'sum'
        })

        if mean_vals['count'] == 0:
            return np.inf

        return (mean_vals['w_res_o'] + mean_vals['w_res_d']) / mean_vals['count']


def oigh(od_matrix, tree_start, tree_end, k, S):
    """
    Generalize an OD matrix over the H3 lattice spanned by ``tree_start`` and ``tree_end``.

    Args:
        od_matrix: DataFrame with 'start_h3', 'end_h3', 'count', 'total_weight'.
        tree_start, tree_end: H3HierarchicalTree for origins / destinations.
        k: minimum 'count' for an OD cell to be kept as-is.
        S: maximum total 'count' that may sit in cells below k.

    Returns:
        The aggregated OD matrix of the selected lattice node (None if no node qualified).
    """
    lattice = Lattice2DCount(od_matrix, tree_start, tree_end, k, S)
    top_node = lattice.nodes[(lattice.L_start, lattice.L_end)]
    top_node.evaluate()
    return lattice.od_matrix_agg

In [None]:
# k = minimum count per published OD cell; S = total count allowed in cells below k
od_matrix_generalized = oigh(od_matrix, tree_start, tree_end, S=0, k=10)

In [None]:
# Inspect the generalized OD matrix returned by oigh
od_matrix_generalized

In [16]:
class GeneralizedH3Visualizer:
    def __init__(self, od_matrix, center_lat=48.8566, center_lon=2.3522):
        """
        Visualize a generalized H3 OD matrix on a Folium map.

        Args:
            od_matrix: DataFrame with columns ['start_gen', 'end_gen', 'count']
                (e.g. the output of ``oigh``); cells may be at any H3 resolution.
            center_lat, center_lon: map centre coordinates.
        """
        self.od_matrix = od_matrix
        self.center_lat = center_lat
        self.center_lon = center_lon

        # Total count per origin cell and per destination cell
        self.origin_flows = od_matrix.groupby('start_gen')['count'].sum().to_dict()
        self.dest_flows = od_matrix.groupby('end_gen')['count'].sum().to_dict()

    def _h3_to_geojson(self, h3_id):
        """Convert an H3 cell into a GeoJSON Feature whose polygon ring is closed."""
        boundary = h3.cell_to_boundary(h3_id)
        ring = [[lon, lat] for lat, lon in boundary]  # GeoJSON positions are (lon, lat)
        # RFC 7946: a linear ring's first and last positions must be identical
        if ring and ring[0] != ring[-1]:
            ring.append(list(ring[0]))
        return {
            "type": "Feature",
            "geometry": {"type": "Polygon", "coordinates": [ring]},
            "properties": {"h3_id": h3_id, "resolution": h3.get_resolution(h3_id)}
        }

    def _add_flow_layer(self, m, flows, layer_name, border_color, color_for_level, max_hexagons, alpha):
        """
        Add one FeatureGroup with the ``max_hexagons`` busiest cells of ``flows`` to ``m``.

        Fill intensity is scaled linearly between the smallest and largest
        count shown (0.3..1.0 of full channel); ``color_for_level`` turns the
        0-255 channel value into a hex colour.
        """
        busiest = sorted(flows.items(), key=lambda item: item[1], reverse=True)[:max_hexagons]
        layer = folium.FeatureGroup(name=layer_name, show=True)
        if busiest:
            min_flow = min(v for _, v in busiest)
            max_flow = max(v for _, v in busiest)
            for h3_id, count in busiest:
                intensity = (count - min_flow) / (max_flow - min_flow) if max_flow > min_flow else 1.0
                level = int(255 * (0.3 + 0.7 * intensity))
                fill_color = color_for_level(level)
                folium.GeoJson(
                    self._h3_to_geojson(h3_id),
                    # default argument binds this iteration's colour
                    style_function=lambda x, fill_color=fill_color: {
                        'fillColor': fill_color,
                        'color': border_color,
                        'weight': 1,
                        'fillOpacity': alpha,
                        'opacity': 0.8
                    },
                    tooltip=f"{count} viaggi"
                ).add_to(layer)
        layer.add_to(m)

    def create_map(self, max_hexagons=100, alpha=0.6, zoom_start=10):
        """
        Create the Folium map with origin (blue) and destination (red) hexagons.

        Args:
            max_hexagons: draw at most this many cells per layer, busiest first.
            alpha: fill opacity of the cells.
            zoom_start: initial zoom level.

        Returns:
            folium.Map with a layer control to toggle the two layers.
        """
        m = folium.Map(location=[self.center_lat, self.center_lon], zoom_start=zoom_start, tiles='OpenStreetMap')

        self._add_flow_layer(m, self.origin_flows, "Origin (blue)", 'darkblue',
                             lambda level: f"#{0:02x}{0:02x}{level:02x}", max_hexagons, alpha)
        self._add_flow_layer(m, self.dest_flows, "Destination (red)", 'darkred',
                             lambda level: f"#{level:02x}{0:02x}{0:02x}", max_hexagons, alpha)

        folium.LayerControl().add_to(m)
        return m

In [None]:
# Origins in blue, destinations in red; 2,000,000 is effectively no cap on drawn cells
visualizer = GeneralizedH3Visualizer(od_matrix=od_matrix_generalized)
mappa = visualizer.create_map(max_hexagons=2_000_000)
mappa

In [None]:
def compute_discernability_and_cavg(df: pd.DataFrame, k: int, suppressed_count: int = 0) -> dict:
    """
    Compute the discernibility metric C_DM and the normalized average
    equivalence-class size C_AVG of a generalized OD dataset.

    Each row of ``df`` is one equivalence class whose size |E| is its 'count'.
    With |D| = total_records (records in ``df`` plus ``suppressed_count``):

        C_DM  = sum(|E|^2 for classes with |E| >= k)
                + |D| * (records in classes with |E| < k + suppressed_count)
        C_AVG = (|D| / number_of_classes) / k

    Classes below k are treated as suppressed and charged |D| per record.
    Previously they were silently left out of C_DM.

    Args:
        df: DataFrame with a 'count' column (e.g. the output of ``oigh``).
        k: the k of k-anonymity.
        suppressed_count: items removed before ``df`` was built (optional).
            Each is counted as one record of |D|, as one singleton class for
            C_AVG, and is charged |D| in C_DM.

    Returns:
        dict with 'C_DM', 'C_AVG', 'total_records', 'total_equivalence_classes', 'k'.
        C_AVG is inf when there are no classes.
    """
    counts = df['count'].to_numpy()
    total_records = counts.sum() + suppressed_count
    total_equiv_classes = len(counts) + suppressed_count

    # Classes meeting k cost their size squared
    c_dm_gen = np.sum(counts[counts >= k] ** 2)

    # Every record in a class below k, and every suppressed item, costs |D|
    under_k_records = counts[counts < k].sum()
    suppression_penalty = (under_k_records + suppressed_count) * total_records
    c_dm = c_dm_gen + suppression_penalty

    c_avg = (total_records / total_equiv_classes) / k if total_equiv_classes > 0 else float('inf')

    return {
        'C_DM': c_dm,
        'C_AVG': c_avg,
        'total_records': total_records,
        'total_equivalence_classes': total_equiv_classes,
        'k': k
    }

In [None]:
# Same k as used for the oigh generalization
metrics = compute_discernability_and_cavg(od_matrix_generalized, suppressed_count=0, k=10)

In [None]:
from geopy.distance import geodesic

def calculate_generalization_distance_metric(df: pd.DataFrame, od_matrix_generalized: pd.DataFrame) -> Dict:
    """
    Measure how far each trip endpoint moves when snapped to its generalized cell.

    For every row of ``df`` whose origin (destination) cell lies in one of the
    generalized origin (destination) cells of ``od_matrix_generalized``, the
    geodesic distance in metres is recorded. It runs from the original point
    (start_lat/start_lon or end_lat/end_lon) to the centre of the generalized cell.

    Args:
        df: trip-level DataFrame with columns start_h3, end_h3, start_lat,
            start_lon, end_lat, end_lon.
        od_matrix_generalized: generalized OD matrix with columns start_gen, end_gen.

    Returns:
        dict with summary statistics ('start_distances', 'end_distances',
        'overall'), the original -> generalized cell 'mappings', and per-point
        'detailed_coords'. Statistics are 0 when no point could be mapped.
    """
    # Mapping from original hexagons to the generalized hexagons containing them
    start_original_to_generalized = _map_to_generalized_cells(df['start_h3'], od_matrix_generalized['start_gen'])
    end_original_to_generalized = _map_to_generalized_cells(df['end_h3'], od_matrix_generalized['end_gen'])

    start_distances, start_coords = _snap_distances(
        df['start_h3'], df['start_lat'], df['start_lon'], start_original_to_generalized
    )
    end_distances, end_coords = _snap_distances(
        df['end_h3'], df['end_lat'], df['end_lon'], end_original_to_generalized
    )

    all_distances = start_distances + end_distances
    overall = _distance_summary(all_distances)

    return {
        'start_distances': _distance_summary(start_distances),
        'end_distances': _distance_summary(end_distances),
        'overall': {
            'mean': overall['mean'],
            'median': overall['median'],
            'std': overall['std'],
            'total_points': len(all_distances)
        },
        'mappings': {
            'start_original_to_generalized': start_original_to_generalized,
            'end_original_to_generalized': end_original_to_generalized
        },
        'detailed_coords': {
            'start': start_coords,
            'end': end_coords
        }
    }


def _map_to_generalized_cells(original_cells: pd.Series, generalized_cells: pd.Series) -> Dict[str, str]:
    """Map each distinct original cell to the generalized cell containing it; unmatched cells are omitted."""
    targets = set(generalized_cells.unique())
    mapping = {}
    for original_h3 in original_cells.unique():
        generalized_h3 = find_generalized_hexagon(original_h3, targets)
        if generalized_h3:
            mapping[original_h3] = generalized_h3
    return mapping


def _snap_distances(cells: pd.Series, lats: pd.Series, lons: pd.Series, mapping: Dict[str, str]) -> Tuple[List[float], List[Dict]]:
    """Geodesic distance (m) from each mappable point to its generalized cell centre, plus per-point details."""
    distances = []
    details = []
    for original_h3, lat, lon in zip(cells, lats, lons):
        generalized_h3 = mapping.get(original_h3)
        if generalized_h3 is None:
            continue
        original_coords = (lat, lon)
        generalized_coords = h3.cell_to_latlng(generalized_h3)
        distance = geodesic(original_coords, generalized_coords).meters

        distances.append(distance)
        details.append({
            'original_h3': original_h3,
            'generalized_h3': generalized_h3,
            'original_coords': original_coords,
            'generalized_coords': generalized_coords,
            'distance': distance
        })
    return distances, details


def _distance_summary(values: List[float]) -> Dict:
    """mean/median/std/min/max/count of ``values``; every statistic is 0 for an empty list."""
    if not values:
        return {'mean': 0, 'median': 0, 'std': 0, 'min': 0, 'max': 0, 'count': 0}
    return {
        'mean': np.mean(values),
        'median': np.median(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'count': len(values)
    }

def find_generalized_hexagon(original_h3: str, generalized_hexagons: set) -> Optional[str]:
    """
    Return the generalized hexagon that contains ``original_h3``.

    The cell itself is returned if it is in ``generalized_hexagons``.
    Otherwise its ancestors are checked from the finest to the coarsest
    resolution and the first one present is returned. Returns None when no
    ancestor is present.

    Walking up the ancestors costs one set lookup per coarser resolution,
    instead of a descendant test against every generalized hexagon. It also
    makes the result deterministic if nested generalized cells could both match.
    """
    if original_h3 in generalized_hexagons:
        return original_h3

    for res in range(h3.get_resolution(original_h3) - 1, -1, -1):
        ancestor = h3.cell_to_parent(original_h3, res)
        if ancestor in generalized_hexagons:
            return ancestor

    return None

def is_descendant_of(child_h3: str, parent_h3: str) -> bool:
    """
    Check whether ``child_h3`` lies strictly below ``parent_h3`` in the H3 hierarchy.

    A cell is not its own descendant. A cell at the same or a coarser
    resolution than ``parent_h3`` is never a descendant of it.
    """
    ancestor_res = h3.get_resolution(parent_h3)
    if h3.get_resolution(child_h3) <= ancestor_res:
        return False
    # The ancestor at a given resolution is unique, so jump straight to it
    return h3.cell_to_parent(child_h3, ancestor_res) == parent_h3

def analyze_generalization_impact(results: Dict) -> None:
    """
    Print a summary of how far generalization moved the OD endpoints.

    The points are ``results['detailed_coords']['start']`` followed by
    ``results['detailed_coords']['end']``. Two reports are printed:

    1. The 25/50/75/90/95/99th percentiles of the per-point ``distance``
       values (only when there is at least one point).
    2. For every "original resolution→generalized resolution" transition, the
       mean ``distance`` and the number of points, in first-seen order.

    Args:
        results: Dict whose ``detailed_coords`` entry holds ``start`` and
            ``end`` lists of dicts with ``distance``, ``original_h3`` and
            ``generalized_h3`` keys.
    """
    points = results['detailed_coords']['start'] + results['detailed_coords']['end']
    distances = [point['distance'] for point in points]

    if distances:
        print("Distance distribution")
        for pct in (25, 50, 75, 90, 95, 99):
            print(f"{pct} percentile: {np.percentile(distances, pct):.2f} meters")

    # Group distances by the resolution jump each point underwent.
    by_transition = defaultdict(list)
    for point in points:
        transition = f"{h3.get_resolution(point['original_h3'])}→{h3.get_resolution(point['generalized_h3'])}"
        by_transition[transition].append(point['distance'])

    print("Resolution changes")
    for transition, values in by_transition.items():
        print(f"{transition}: {np.mean(values):.2f}m (n={len(values)})")

In [None]:
# Measure how far each OD endpoint moved when mapped to its generalized
# hexagon, then print the distance distribution and per-resolution breakdown.
distance_results = calculate_generalization_distance_metric(
    df=filtered_df,
    od_matrix_generalized=od_matrix_generalized,
)
analyze_generalization_impact(distance_results)

In [None]:
class GeneralizationMetric:
    """
    Average generalization error of a generalized (anonymized) OD matrix.

        Ḡ = (1/V+) × Σ(|o| + |d|) × v_{o→d}

    The sum runs over generalized OD pairs whose flow v_{o→d} is at least
    ``k_threshold``. V+ is the total flow of those pairs. |o| and |d| are the
    numbers of distinct original hexagons inside the generalized origin and
    destination cells.

    Args:
        k_threshold: Minimum flow for a generalized OD pair to be included.
        flow_column: Column of the generalized matrix holding the flow.
            Defaults to ``"count"`` (trip counts); pass ``"total_weight"`` for
            weighted flows.
    """

    def __init__(self, k_threshold: float = 10, flow_column: str = "count"):
        self.k_threshold = k_threshold
        self.flow_column = flow_column

    def calculate_generalization_error(self, od_matrix_generalized: pd.DataFrame, od_matrix: pd.DataFrame) -> float:
        """
        Compute Ḡ for ``od_matrix_generalized``.

        Args:
            od_matrix_generalized: Generalized OD matrix with 'start_gen',
                'end_gen' and ``self.flow_column`` columns.
            od_matrix: Original OD matrix with 'start_h3' and 'end_h3' columns.

        Returns:
            Ḡ as a float, or 0.0 when no pair reaches the threshold (or the
            retained flow does not sum to a positive value).
        """
        # generalized hexagon -> number of original cells it contains
        origin_counts = self._build_hexagon_counts(
            od_matrix_generalized, od_matrix, column_gen="start_gen", column_orig="start_h3"
        )
        destination_counts = self._build_hexagon_counts(
            od_matrix_generalized, od_matrix, column_gen="end_gen", column_orig="end_h3"
        )

        flows = od_matrix_generalized[self.flow_column]
        # Only pairs satisfying the k-anonymity threshold contribute.
        anonymous = od_matrix_generalized[flows >= self.k_threshold]
        anonymous_flows = anonymous[self.flow_column]

        total_volume_anonymous = anonymous_flows.sum()
        if not total_volume_anonymous > 0:
            return 0.0

        # Cells missing from the lookup count as a single original hexagon.
        origin_sizes = anonymous["start_gen"].map(lambda cell: origin_counts.get(cell, 1))
        dest_sizes = anonymous["end_gen"].map(lambda cell: destination_counts.get(cell, 1))

        weighted_count_sum = ((origin_sizes + dest_sizes) * anonymous_flows).sum()
        return float(weighted_count_sum / total_volume_anonymous)

    def _build_hexagon_counts(
        self, od_matrix_generalized: pd.DataFrame, od_matrix: pd.DataFrame,
        column_gen: str, column_orig: str
    ) -> dict:
        """
        Count how many distinct original hexagons belong to each generalized hexagon.

        Args:
            od_matrix_generalized: Frame whose ``column_gen`` holds generalized cells.
            od_matrix: Frame whose ``column_orig`` holds original cells.
            column_gen: Column name of the generalized cells.
            column_orig: Column name of the original cells.

        Returns:
            Mapping generalized hexagon -> number of distinct original hexagons
            whose ancestor at the generalized cell's resolution is that cell
            (at least 1).
        """
        generalized_hexagons = od_matrix_generalized[column_gen].unique()
        original_hexagons = od_matrix[column_orig].unique()

        # Ancestor tallies are built once per distinct target resolution
        # instead of once per generalized hexagon (previously O(G × O) lookups).
        parent_tallies = {}
        counts = {}
        for gen_hex in generalized_hexagons:
            target_res = h3.get_resolution(gen_hex)
            if target_res not in parent_tallies:
                parent_tallies[target_res] = Counter(
                    h3.cell_to_parent(orig_hex, target_res)
                    for orig_hex in original_hexagons
                    # originals coarser than the target cannot lie inside it
                    if h3.get_resolution(orig_hex) >= target_res
                )
            counts[gen_hex] = max(parent_tallies[target_res][gen_hex], 1)  # fallback to 1

        return counts

In [None]:
# Unweighted generalization error: flows are trip counts, threshold k = 10.
metric = GeneralizationMetric(k_threshold=10)
error = metric.calculate_generalization_error(
    od_matrix_generalized=od_matrix_generalized,
    od_matrix=od_matrix,
)
print(f"Average generalization error Ḡ: {error:.3f}")

In [None]:
def fast_reconstruction_loss(original_od_df: pd.DataFrame,
                             od_matrix_generalized: pd.DataFrame,
                             flow_column: str = 'count') -> float:
    """
    Reconstruction loss of a generalized OD matrix:

        E = (1/V) * Σ |ṽ_o→d - v_o→d|

    For every original OD pair (o, d) with true volume v_o→d, ṽ_o→d is the
    volume of the generalized pair that contains it (0 if that generalized
    pair is absent). V is the total true volume of the original pairs whose
    origin and destination both map to a generalized hexagon. Pairs that
    cannot be mapped are skipped.

    Args:
        original_od_df: Original OD matrix with 'start_h3', 'end_h3' and
            ``flow_column`` columns.
        od_matrix_generalized: Generalized OD matrix with 'start_gen',
            'end_gen' and ``flow_column`` columns.
        flow_column: Column holding the flow volume in both frames
            (default 'count').

    Returns:
        The normalized absolute error, or 0.0 when no volume was matched.
    """
    # (start_gen, end_gen) -> generalized volume, for O(1) lookups.
    # Later duplicate pairs overwrite earlier ones.
    generalized_flows = dict(zip(
        zip(od_matrix_generalized['start_gen'], od_matrix_generalized['end_gen']),
        od_matrix_generalized[flow_column],
    ))

    gen_start_hexes = od_matrix_generalized['start_gen'].unique()
    gen_end_hexes = od_matrix_generalized['end_gen'].unique()

    # Resolve each distinct original cell once. The linear scan in
    # _find_generalized_parent used to be repeated for every OD row.
    start_lookup = {
        cell: _find_generalized_parent(cell, gen_start_hexes)
        for cell in original_od_df['start_h3'].unique()
    }
    end_lookup = {
        cell: _find_generalized_parent(cell, gen_end_hexes)
        for cell in original_od_df['end_h3'].unique()
    }

    total_volume = 0
    total_abs_error = 0

    for start_h3, end_h3, true_count in zip(original_od_df['start_h3'],
                                            original_od_df['end_h3'],
                                            original_od_df[flow_column]):
        gen_start = start_lookup[start_h3]
        gen_end = end_lookup[end_h3]

        # Pairs outside the generalized area do not contribute.
        if gen_start is None or gen_end is None:
            continue

        gen_count = generalized_flows.get((gen_start, gen_end), 0)
        total_abs_error += abs(gen_count - true_count)
        total_volume += true_count

    return total_abs_error / total_volume if total_volume > 0 else 0.0


def _find_generalized_parent(original_h3: str, generalized_hexagons: list) -> str:
    """
    Return the first generalized hexagon, in iteration order, that contains
    ``original_h3``, or None when no candidate does.

    A candidate contains the original cell when its resolution is not finer
    than the original's and the original's ancestor at that resolution is the
    candidate itself.
    """
    own_res = h3.get_resolution(original_h3)

    def _contains(candidate: str) -> bool:
        cand_res = h3.get_resolution(candidate)
        return cand_res <= own_res and h3.cell_to_parent(original_h3, cand_res) == candidate

    return next((candidate for candidate in generalized_hexagons if _contains(candidate)), None)


In [None]:
# Unweighted reconstruction loss (flows measured by trip count).
loss = fast_reconstruction_loss(od_matrix, od_matrix_generalized)
print(f"Reconstruction Loss: {loss:.6f}")

### Weighted metrics (flows measured by `total_weight` instead of trip `count`)

In [11]:
def compute_discernability_and_cavg(df: pd.DataFrame, k: float, suppressed_count: int = 0,
                                    count_column: str = 'total_weight') -> dict:
    """
    Compute the discernibility metric (C_DM) and the normalized average
    equivalence-class size (C_AVG) of an OD matrix.

    Each row of ``df`` is treated as one equivalence class whose size is the
    value in ``count_column``.

    - C_DM = Σ_{|E| >= k} |E|² + suppressed_count × (sum of class sizes in df).
      Classes smaller than k are left out of the sum and are not penalized
      here; include them in ``suppressed_count`` if they should be.
    - C_AVG = (total_records / total_equivalence_classes) / k, or inf when
      there are no classes.

    ``suppressed_count`` is added to both the record total and the class
    total, so each suppressed pair counts as one record and one class.

    Args:
        df: OD matrix (e.g. 'start_h3'/'end_h3' or 'start_gen'/'end_gen' rows);
            only ``count_column`` is read.
        k: k-anonymity threshold, in the same units as ``count_column``. It may
            be fractional for weighted flows.
        suppressed_count: Number of OD pairs suppressed (optional).
        count_column: Column holding the class size. Defaults to
            'total_weight'; use 'count' for unweighted trip counts.

    Returns:
        dict with 'C_DM', 'C_AVG', 'total_records',
        'total_equivalence_classes' and 'k'.
    """
    counts = df[count_column].to_numpy()
    total_records = counts.sum() + suppressed_count
    total_equiv_classes = len(counts) + suppressed_count

    k_anonymous_counts = counts[counts >= k]
    c_dm_gen = np.sum(k_anonymous_counts ** 2)

    # Penalty for suppressed records. NOTE: this uses the total size of the
    # non-suppressed classes; some definitions use total_records instead.
    suppression_penalty = suppressed_count * counts.sum()
    c_dm = c_dm_gen + suppression_penalty

    # C_AVG: (total_records / total_equiv_classes) / k
    c_avg = (total_records / total_equiv_classes) / k if total_equiv_classes > 0 else float('inf')

    return {
        'C_DM': c_dm,
        'C_AVG': c_avg,
        'total_records': total_records,
        'total_equivalence_classes': total_equiv_classes,
        'k': k
    }

In [None]:
# Weighted C_DM / C_AVG. k is scaled by media_peso (presumably the mean
# individual weight; confirm where it is defined) to match 'total_weight'.
metrics = compute_discernability_and_cavg(
    df=od_matrix_generalized,
    k=10 * media_peso,
    suppressed_count=0,
)

In [None]:
class GeneralizationMetric:
    """
    Average generalization error of a generalized (anonymized) OD matrix,
    weighted-flow variant.

        Ḡ = (1/V+) × Σ(|o| + |d|) × v_{o→d}

    The sum runs over generalized OD pairs whose flow v_{o→d} is at least
    ``k_threshold``. V+ is the total flow of those pairs. |o| and |d| are the
    numbers of distinct original hexagons inside the generalized origin and
    destination cells.

    NOTE(review): this redefinition shadows the earlier class with the same
    name. The two differ only in the default ``flow_column``.

    Args:
        k_threshold: Minimum flow for a generalized OD pair to be included.
        flow_column: Column of the generalized matrix holding the flow.
            Defaults to ``"total_weight"``; pass ``"count"`` for trip counts.
    """

    def __init__(self, k_threshold: float = 10, flow_column: str = "total_weight"):
        self.k_threshold = k_threshold
        self.flow_column = flow_column

    def calculate_generalization_error(self, od_matrix_generalized: pd.DataFrame, od_matrix: pd.DataFrame) -> float:
        """
        Compute Ḡ for ``od_matrix_generalized``.

        Args:
            od_matrix_generalized: Generalized OD matrix with 'start_gen',
                'end_gen' and ``self.flow_column`` columns.
            od_matrix: Original OD matrix with 'start_h3' and 'end_h3' columns.

        Returns:
            Ḡ as a float, or 0.0 when no pair reaches the threshold (or the
            retained flow does not sum to a positive value).
        """
        # generalized hexagon -> number of original cells it contains
        origin_counts = self._build_hexagon_counts(
            od_matrix_generalized, od_matrix, column_gen="start_gen", column_orig="start_h3"
        )
        destination_counts = self._build_hexagon_counts(
            od_matrix_generalized, od_matrix, column_gen="end_gen", column_orig="end_h3"
        )

        flows = od_matrix_generalized[self.flow_column]
        # Only pairs satisfying the k-anonymity threshold contribute.
        anonymous = od_matrix_generalized[flows >= self.k_threshold]
        anonymous_flows = anonymous[self.flow_column]

        total_volume_anonymous = anonymous_flows.sum()
        if not total_volume_anonymous > 0:
            return 0.0

        # Cells missing from the lookup count as a single original hexagon.
        origin_sizes = anonymous["start_gen"].map(lambda cell: origin_counts.get(cell, 1))
        dest_sizes = anonymous["end_gen"].map(lambda cell: destination_counts.get(cell, 1))

        weighted_count_sum = ((origin_sizes + dest_sizes) * anonymous_flows).sum()
        return float(weighted_count_sum / total_volume_anonymous)

    def _build_hexagon_counts(
        self, od_matrix_generalized: pd.DataFrame, od_matrix: pd.DataFrame,
        column_gen: str, column_orig: str
    ) -> dict:
        """
        Count how many distinct original hexagons belong to each generalized hexagon.

        Args:
            od_matrix_generalized: Frame whose ``column_gen`` holds generalized cells.
            od_matrix: Frame whose ``column_orig`` holds original cells.
            column_gen: Column name of the generalized cells.
            column_orig: Column name of the original cells.

        Returns:
            Mapping generalized hexagon -> number of distinct original hexagons
            whose ancestor at the generalized cell's resolution is that cell
            (at least 1).
        """
        generalized_hexagons = od_matrix_generalized[column_gen].unique()
        original_hexagons = od_matrix[column_orig].unique()

        # Find the parents of all originals once per distinct target resolution
        # and tally them, instead of redoing it for every generalized hexagon.
        parent_tallies = {}
        counts = {}
        for gen_hex in generalized_hexagons:
            target_res = h3.get_resolution(gen_hex)
            if target_res not in parent_tallies:
                parent_tallies[target_res] = Counter(
                    h3.cell_to_parent(orig_hex, target_res)
                    for orig_hex in original_hexagons
                    # originals coarser than the target cannot lie inside it
                    if h3.get_resolution(orig_hex) >= target_res
                )
            counts[gen_hex] = max(parent_tallies[target_res][gen_hex], 1)  # fallback to 1

        return counts

In [None]:
# Weighted generalization error. The threshold is scaled by media_peso
# (presumably the mean individual weight; confirm) to match 'total_weight'.
metric = GeneralizationMetric(k_threshold=10 * media_peso)
error = metric.calculate_generalization_error(
    od_matrix_generalized=od_matrix_generalized,
    od_matrix=od_matrix,
)

In [None]:
def fast_reconstruction_loss(original_od_df: pd.DataFrame,
                             od_matrix_generalized: pd.DataFrame,
                             flow_column: str = 'total_weight') -> float:
    """
    Reconstruction loss of a generalized OD matrix, weighted-flow variant:

        E = (1/V) * Σ |ṽ_o→d - v_o→d|

    For every original OD pair (o, d) with true volume v_o→d, ṽ_o→d is the
    volume of the generalized pair that contains it (0 if that generalized
    pair is absent). V is the total true volume of the original pairs whose
    origin and destination both map to a generalized hexagon. Pairs that
    cannot be mapped are skipped.

    Args:
        original_od_df: Original OD matrix with 'start_h3', 'end_h3' and
            ``flow_column`` columns.
        od_matrix_generalized: Generalized OD matrix with 'start_gen',
            'end_gen' and ``flow_column`` columns.
        flow_column: Column holding the flow volume in both frames
            (default 'total_weight').

    Returns:
        The normalized absolute error, or 0.0 when no volume was matched.
    """
    # (start_gen, end_gen) -> generalized volume, for O(1) lookups.
    # Later duplicate pairs overwrite earlier ones.
    generalized_flows = dict(zip(
        zip(od_matrix_generalized['start_gen'], od_matrix_generalized['end_gen']),
        od_matrix_generalized[flow_column],
    ))

    gen_start_hexes = od_matrix_generalized['start_gen'].unique()
    gen_end_hexes = od_matrix_generalized['end_gen'].unique()

    # Resolve each distinct original cell once. The linear scan in
    # _find_generalized_parent used to be repeated for every OD row.
    start_lookup = {
        cell: _find_generalized_parent(cell, gen_start_hexes)
        for cell in original_od_df['start_h3'].unique()
    }
    end_lookup = {
        cell: _find_generalized_parent(cell, gen_end_hexes)
        for cell in original_od_df['end_h3'].unique()
    }

    total_volume = 0
    total_abs_error = 0

    for start_h3, end_h3, true_count in zip(original_od_df['start_h3'],
                                            original_od_df['end_h3'],
                                            original_od_df[flow_column]):
        gen_start = start_lookup[start_h3]
        gen_end = end_lookup[end_h3]

        # Pairs outside the generalized area do not contribute.
        if gen_start is None or gen_end is None:
            continue

        gen_count = generalized_flows.get((gen_start, gen_end), 0)
        total_abs_error += abs(gen_count - true_count)
        total_volume += true_count

    return total_abs_error / total_volume if total_volume > 0 else 0.0


def _find_generalized_parent(original_h3: str, generalized_hexagons: list) -> str:
    """
    Look up which generalized hexagon covers ``original_h3``.

    Candidates are scanned in the order given. The first one is returned whose
    resolution is not finer than the original cell's and which equals the
    original cell's ancestor at that resolution. Returns None when no
    candidate matches.
    """
    target_res = h3.get_resolution(original_h3)
    for candidate in generalized_hexagons:
        candidate_res = h3.get_resolution(candidate)
        if candidate_res > target_res:
            # A finer cell can never contain the original one.
            continue
        if h3.cell_to_parent(original_h3, candidate_res) == candidate:
            return candidate
    return None

In [None]:
# Weighted reconstruction loss (flows measured by 'total_weight').
loss = fast_reconstruction_loss(od_matrix, od_matrix_generalized)