In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# JDK pinned for this notebook (macOS OpenJDK 17 install location).
# NOTE(review): presumably required by the Spark session created further down — confirm.
DEFAULT_JAVA_HOME = '/Library/Java/JavaVirtualMachines/openjdk-17.jdk/Contents/Home'


def configure_java(java_home=DEFAULT_JAVA_HOME):
    """Point JAVA_HOME and PATH at a JDK, safely across machines and re-runs.

    Args:
        java_home: JDK home directory to pin. It is only applied when the
            directory exists; otherwise a JAVA_HOME already present in the
            environment is left untouched instead of being replaced by a
            path that does not exist on this machine.

    Returns:
        None. The function only updates ``os.environ``.
    """
    if os.path.isdir(java_home):
        os.environ['JAVA_HOME'] = java_home

    active_home = os.environ.get('JAVA_HOME')
    if not active_home:
        # No usable JDK is known; leave PATH alone rather than adding a bogus entry.
        return

    java_bin = f"{active_home}/bin"
    current_path = os.environ.get('PATH', '')
    # Drop earlier copies of the same entry so re-running the cell does not
    # keep growing PATH, then put the JDK first so its `java` takes precedence.
    remaining = (
        [entry for entry in current_path.split(os.pathsep) if entry != java_bin]
        if current_path
        else []
    )
    os.environ['PATH'] = os.pathsep.join([java_bin, *remaining])


configure_java()

In [None]:
# Imports
import pandas as pd
import geopandas as gpd
from pathlib import Path
import deepdiff

from cider.homelocation.core import get_home_locations, get_accuracy
from cider.schemas import CallDataRecordData, AntennaData
from cider.homelocation.schemas import GeographicUnit, GetHomeLocationAlgorithm
from cider.homelocation.plotting import make_location_map
from cider.utils import validate_dataframe
# WILL BE DEPRECATED IN FUTURE VERSIONS
# The legacy implementation is still imported because the notebook uses it to
# obtain the Spark session and to compare old vs new results.
import sys
# Put '..' first on the module search path so the `deprecated` package resolves.
# NOTE(review): '..' is relative to the kernel's working directory — presumably
# the notebook's own folder; confirm when running from elsewhere.
sys.path.insert(0,'..')
from deprecated.home_location import DataStore, HomeLocator

In [None]:
# Load data
DATA_DIR = Path('../synthetic_data/')

# Rename the synthetic-data columns to the names the rest of this notebook
# (and the cider schemas it validates against) refers to.
CDR_COLUMN_RENAMES = {
    "caller_antenna": "caller_antenna_id",
    "recipient_antenna": "recipient_antenna_id",
    "txn_type": "transaction_type",
    "international": "transaction_scope",
}
GROUND_TRUTH_COLUMN_RENAMES = {
    "subscriber_id": "caller_id",
    "antenna_id": "caller_antenna_id",
    "regions": "region",
}

# Call detail records: keep only rows that have a caller antenna.
cdr = (
    pd.read_csv(DATA_DIR / 'cdr.csv')
    .rename(columns=CDR_COLUMN_RENAMES)
    .dropna(subset=['caller_antenna_id'])
)

# Antennas: drop any row with a missing value in any column.
antennas = pd.read_csv(DATA_DIR / 'antennas.csv').dropna()

# Ground-truth home locations, aligned to the inferred output's column names.
home_locations_gt = (
    pd.read_csv(DATA_DIR / 'home_locations.csv')
    .rename(columns=GROUND_TRUTH_COLUMN_RENAMES)
)

# Region boundaries; read through DATA_DIR like the other inputs so the data
# location only needs to be changed in one place.
shapefile = gpd.read_file(DATA_DIR / 'regions.geojson')

In [None]:
# Display the region boundaries GeoDataFrame for a quick visual check.
shapefile

In [None]:
# WILL BE DEPRECATED IN FUTURE VERSIONS
# The legacy DataStore is still needed for two things: it exposes the Spark
# session, and it feeds the old-vs-new comparison at the end of the notebook.
config_file = '../configs/config_new.yml'
datastore = DataStore(config_file)

# Set the datastore's `outputs` attribute to its configured working directory.
datastore.outputs = str(datastore.cfg.path.working.directory_path)

spark = datastore.spark


In [None]:
# Validate data using schemas
# Invalid-data errors CAN be ignored for now -- rerun after fixing the data generation scripts.
# Each frame is checked against its schema class with data-point checks enabled.
validate_dataframe(cdr, CallDataRecordData, check_data_points=True)
validate_dataframe(antennas, AntennaData, check_data_points=True)


In [None]:
# Run home location inference
geographic_unit = GeographicUnit.SHAPEFILE
algorithm = GetHomeLocationAlgorithm.COUNT_TRANSACTIONS

# todo: ideally we should validate the synthetic data (cdr and antennas) using the schemas
inference_inputs = dict(
    cdr_data=cdr,
    antenna_data=antennas,
    shapefile_data=shapefile,
    geographic_unit=geographic_unit,
    algorithm=algorithm,
    additional_columns_to_keep=[],
)
homes = get_home_locations(**inference_inputs)

# Preview the first ten inferred homes.
homes.head(10)

In [None]:
# Compare inferred vs ground-truth home antennas per subscriber.
# The two frames are joined on caller_id first: comparing the raw columns
# directly pairs rows by index label rather than by subscriber, and pandas
# raises when the two indexes are not identically labelled.
antenna_comparison = homes[['caller_id', 'caller_antenna_id']].merge(
    home_locations_gt[['caller_id', 'caller_antenna_id']],
    on='caller_id',
    how='inner',
    suffixes=('_inferred', '_ground_truth'),
)
(
    antenna_comparison['caller_antenna_id_inferred']
    == antenna_comparison['caller_antenna_id_ground_truth']
).rename('antenna_matches_ground_truth')

In [None]:
# Calculate accuracy
# Inferred and ground-truth homes are matched on caller_id and scored on the antenna id.
metrics_table = get_accuracy(
    homes,
    home_locations_gt,
    column_to_merge_on="caller_id",
    column_to_measure_on="caller_antenna_id",
)
metrics_table

In [None]:
# Plotting results

# Left-join the antenna table onto the inferred homes; longitude/latitude are
# then read from the joined frame to build one point per row.
merged = homes.merge(
    antennas,
    how='left',
    left_on='caller_antenna_id',
    right_on='antenna_id',
)
home_points = gpd.points_from_xy(merged.longitude, merged.latitude)
homes_gdf = gpd.GeoDataFrame(merged, geometry=home_points, crs="EPSG:4326")

# Label by antenna id; marker size comes from the column named after the algorithm's value.
fig = make_location_map(
    inferred_home_locations=homes_gdf,
    boundaries_shapefile=shapefile,
    column_to_plot_label='caller_antenna_id',
    column_to_plot_markersize=algorithm.value,
)

In [None]:
# Compare with previous way of doing things:
# WILL BE DEPRECATED IN FUTURE VERSIONS

# Reload data to ensure no contamination
old_cdr_raw = pd.read_csv('../synthetic_data/cdr.csv')

# Stack both sides of every record into one (name, antenna_id, timestamp) table.
old_outgoing = (
    old_cdr_raw[['caller_id', 'caller_antenna', 'timestamp']]
    .rename(columns={'caller_id': 'name', 'caller_antenna': 'antenna_id'})
)
old_incoming = (
    old_cdr_raw[['recipient_id', 'recipient_antenna', 'timestamp']]
    .rename(columns={'recipient_id': 'name', 'recipient_antenna': 'antenna_id'})
)
old_cdr_stacked = pd.concat([old_outgoing, old_incoming])

# Inner join: rows whose antenna is absent from the antenna table are dropped.
old_antennas = pd.read_csv('../synthetic_data/antennas.csv')
old_cdr = old_cdr_stacked.merge(old_antennas, how='inner', on='antenna_id')

old_home_locations = pd.read_csv('../synthetic_data/home_locations.csv')
old_shapefile = gpd.read_file('../synthetic_data/regions.geojson')

In [None]:
# Infer home locations using old method
# The legacy locator takes the algorithm by its string value and works at antenna level.
algo = algorithm.value
home_locator = HomeLocator(datastore)
old_homes = (
    home_locator
    .get_home_locations(geo="antenna_id", algo=algo)
    .rename(columns={'subscriber_id': 'name'})
)

In [None]:
# Compare old and new
# Align the legacy column names with the ones used by the new output
old_homes = old_homes.rename(
    columns={"name": "caller_id", "antenna_id": "caller_antenna_id"}
)

# Sort both results by subscriber and rebuild the index so rows line up.
old_sorted = old_homes.sort_values(by=['caller_id']).reset_index(drop=True)
new_sorted = homes.sort_values(by=['caller_id']).reset_index(drop=True)

ddiff = deepdiff.DeepDiff(old_sorted, new_sorted, ignore_order=True)
ddiff # <-- should be empty dictionary if two things are the same