# Test the new structure of grid search

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import subprocess
import sys
def get_repo_root():
    """Return the top-level directory of the enclosing git repository.

    Runs ``git rev-parse --show-toplevel`` and returns its output with
    trailing whitespace removed. ``subprocess.CalledProcessError`` propagates
    if git exits with a non-zero status.
    """
    # '__file__' is a quoted literal here, not the module attribute, so
    # abspath() resolves it against the current working directory and
    # dirname() yields that directory.
    start_dir = os.path.dirname(os.path.abspath('__file__'))
    git_cmd = ['git', 'rev-parse', '--show-toplevel']
    toplevel = subprocess.check_output(git_cmd,
                                       cwd=start_dir,
                                       universal_newlines=True)
    return toplevel.rstrip()
# Resolve the repo root once (every get_repo_root() call spawns a git
# subprocess) and make the repo importable. The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
ROOT_dir = get_repo_root()
if ROOT_dir not in sys.path:
    sys.path.append(ROOT_dir)

In [2]:
# load libraries
import gs_model
import json
import multiprocessing as mp
import pprint
import pandas as pd

In [3]:
class RegionParaSearch:
    """Evaluate one (p, gamma, beta) parameter setting of the visit model for a region.

    Attributes (filled by ``region_data_load`` unless passed in explicitly):
        res: path of the text file that results are appended to, one JSON
            object per line.
        region: region name handed to ``gs_model`` and used in the results
            file name.
        rg: ``gs_model.RegionDataPrep`` instance holding the prepared region data.
        visits: ``gs_model.VisitsGeneration`` instance used to generate visits.
    """

    def __init__(self, res=None, region=None, rg=None, visits=None):
        self.res = res
        self.region = region
        self.rg = rg
        self.visits = visits

    def region_data_load(self):
        """Prepare the region data and the visit generator for ``self.region``.

        Sets ``self.res`` to ``<ROOT_dir>/results/gridsearch-n_<region>.txt``,
        builds a ``gs_model.RegionDataPrep`` (zones/ODM, geotweets, KL
        baseline) and stores it in ``self.rg``, then builds the
        ``gs_model.VisitsGeneration`` stored in ``self.visits`` from it.
        """
        self.res = ROOT_dir + '/results/gridsearch-n_' + self.region + '.txt'
        rg_ = gs_model.RegionDataPrep(region=self.region)
        rg_.load_zones_odm()
        rg_.load_geotweets()
        rg_.kl_baseline_compute()
        self.rg = rg_
        self.visits = gs_model.VisitsGeneration(region=self.region, bbox=self.rg.bbox,
                                                zones=self.rg.zones, odm=self.rg.gt_odm,
                                                distances=self.rg.distances,
                                                distance_quantiles=self.rg.distance_quantiles, gt_dms=self.rg.dms)

    def gs_para(self, p, gamma, beta):
        """Generate visits for one parameter setting and score them.

        Visits are generated in parallel worker processes, concatenated,
        scored with ``self.visits.visits2measure`` and the result is appended
        as one JSON line to ``self.res``.

        Args:
            p, gamma, beta: model parameters forwarded unchanged to
                ``self.visits.visits_gen_chunk``.

        Returns:
            The negated divergence measure (presumably so the value can be
            fed to a maximising optimiser -- TODO confirm with the caller).
        """
        # 20 tasks, each passing 7 as the last argument of visits_gen_chunk
        # (presumably the number of days per chunk -- TODO confirm in gs_model).
        tasks = [(self.rg.tweets_calibration, p, gamma, beta, x) for x in [7] * 20]
        # parallelize the generation of visits; the context manager releases
        # the worker processes even when a task raises.
        with mp.Pool(mp.cpu_count()) as pool:
            visits_list = pool.starmap(self.visits.visits_gen_chunk, tasks)
        visits_total = pd.concat(visits_list).set_index('userid')
        print('Visits generated:', len(visits_total))
        divergence_measure = self.visits.visits2measure(visits=visits_total, home_locations=self.rg.home_locations)
        # append the result to the gridsearch file
        dic = {'region': self.region, 'p': p, 'beta': beta, 'gamma': gamma,
               'kl-baseline': self.rg.kl_baseline, 'kl': divergence_measure}
        pprint.pprint(dic)
        with open(self.res, 'a') as outfile:
            json.dump(dic, outfile)
            outfile.write('\n')
        return -divergence_measure

## 1 Load data according to the specified region

In [None]:
# Load region data: pick the region, then let RegionParaSearch build the
# prepared region data (gs.rg) and the visit generator (gs.visits).
# NOTE(review): the saved output of the gs_para cell further down reports
# region 'sweden-national', so that output was produced with a different
# region than the one set here -- re-run both cells to refresh it.
region='netherlands'
gs = RegionParaSearch(region=region)
gs.region_data_load()

In [4]:
# Load the ground-truth data through the region-specific loader module.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell so dependencies are visible in one place.
import netherlands
# This global `rg` is separate from `gs.rg` (set by gs.region_data_load()).
rg = netherlands.GroundTruthLoader()
rg.load_zones()
# presumably loads the origin-destination matrix printed below -- TODO confirm
rg.load_odm()

origin_zip  dest_zip
0           0           1.921383e+08
            1011        3.320952e+05
            1012        3.363486e+05
            1013        1.961443e+05
            1015        5.766617e+05
Name: weight_trip, dtype: float64
zone  zone
1011  1011    2.084926e+06
      1012    8.106812e+05
      1013    6.175562e+04
      1014    0.000000e+00
      1015    0.000000e+00
Name: weight_trip, dtype: float64


## 2 Generate visits

In [11]:
# Evaluate a single parameter setting; the cell displays the negated divergence.
gs.gs_para(p=0.010000000000000009, gamma=0.01, beta=0.01)

Visits generated: 1566562
removed 6572 visits due to sampling bbox
Convering visits to zone CRS
Aligning region-visits to zones...
removed 5177 region-visits due to missing zone geom
Aligning point-visits to zones...
removed 1784 point-visits due to missing zone geom
1553029 visits left after alignment
Creating odm...
{'beta': 0.01,
 'gamma': 0.01,
 'kl': 1.710174804553334,
 'kl-baseline': 0.31464927288956873,
 'p': 0.010000000000000009,
 'region': 'sweden-national'}


-1.710174804553334