In [1]:
# %pip install dataclasses
# %pip install matching-ds-tools
# %pip install pytorch-lightning

In [2]:
import json
import datetime
import re
import pickle
%matplotlib inline
%load_ext autoreload
%autoreload 2

import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from queryrunner_client import Client
USER_EMAIL = 'ssadeghi@uber.com'
qclient = Client(user_email=USER_EMAIL)
CONSUMER_NAME = 'intelligentdispatch'

import os
import warnings
warnings.filterwarnings('ignore')
import multiprocessing
from joblib import Parallel, delayed
#num_cores = multiprocessing.cpu_count()  -- 48
n_cores = 4

In [3]:
from dataclasses import dataclass
import itertools
from typing import *
import numpy as np
import pandas as pd
from queryrunner_client import Client as QRClient
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import mdstk
from mdstk.data_fetcher.data_fetcher import DataFetcher
from mdstk.data_fetcher.cached_data_fetcher import CachedDataFetcher
import datetime
pd.set_option('display.max_columns', None)
import plotly.express as px
import plotly.figure_factory as ff
import h3
import h3pandas

In [4]:
"""
CASE 1: San Francisco
city_id = 1
vvid = 8
MAX_LAT = 38.19
MIN_LAT = 37.09
MAX_LNG = -121.55
MIN_LNG = -122.60
LAT_CENTER = 37.6
LON_CENTER = -122.2


CASE 2: Detroit
city_id = 50
vvid = 425
MAX_LAT = 42.89
MIN_LAT = 42.01
MAX_LNG = -82.68
MIN_LNG = -83.93
LAT_CENTER = 42.420149389121406
LON_CENTER = -83.15996619595755


Case 3: Philadelphia
city_id = 20
vvid = 663
MAX_LAT = 40.22
MIN_LAT = 39.74
MAX_LNG = -75.46
MIN_LNG = -74.86
LAT_CENTER = 39.95201837418434
LON_CENTER = -75.15727285438611

"""


'\nCASE 1: San Francisco\ncity_id = 1\nvvid = 8\nMAX_LAT = 38.19\nMIN_LAT = 37.09\nMAX_LNG = -121.55\nMIN_LNG = -122.60\nLAT_CENTER = 37.6\nLON_CENTER = -122.2\n\n\nCASE 2: Detroit\ncity_id = 50\nvvid = 425\nMAX_LAT = 42.89\nMIN_LAT = 42.01\nMAX_LNG = -82.68\nMIN_LNG = -83.93\nLAT_CENTER = 42.420149389121406\nLON_CENTER = -83.15996619595755\n\n\nCase 3: Philadelphia\ncity_id = 20\nvvid = 663\nMAX_LAT = 40.22\nMIN_LAT = 39.74\nMAX_LNG = -75.46\nMIN_LNG = -74.86\nLAT_CENTER = 39.95201837418434\nLON_CENTER = -75.15727285438611\n\n'

In [5]:
### INPUTS

prefix = 'unifiedQ'

city_id = 20
vvid = 663
MAX_LAT = 40.22
MIN_LAT = 39.74
MAX_LNG = -74.86
MIN_LNG = -75.46
LAT_CENTER = 39.95201837418434
LON_CENTER = -75.15727285438611
        
datestrs = [  # 1 week
    '2022-11-13',
    '2022-11-14',
    '2022-11-15',
    '2022-11-16',
    '2022-11-17',
    '2022-11-18',
    '2022-11-20',
    '2022-11-21',
    '2022-11-22',
    '2022-11-23',
    '2022-11-24',
    '2022-11-25',
    '2022-11-26',
    '2022-11-27',
    '2022-11-28',
    '2022-11-29',
    '2022-11-30',
    '2022-12-01',
    '2022-12-02',
    '2022-12-03',
    '2022-12-04',
    '2022-12-05',
    '2022-12-06',
    '2022-12-07',
    '2022-12-08',
]


## ML TRAINING

BATCH_SIZE = 64
ITERATIONS = 500000
DISCOUNT = 0.995
VALUE_DEGRADE_LEVEL = 0.85


In [6]:
idle_duration_seconds = np.log(VALUE_DEGRADE_LEVEL) / np.log(DISCOUNT)

### Query trip data

In [7]:
# data collection

QUERY = """
WITH
completed as (
SELECT
  ft.datestr as datestr,
  ft.city_id as city_id, 
  ft.trip_uuid as trip_uuid,
  ft.session_uuid as session_uuid,
  ft.driver_uuid as driver_uuid,
  ft.request_timestamp_local as local_time,
  ft.eta as eta,
  ft.fare as fare,
  ft.duration_min as duration_min,
  --   Driver origin information
  mez.driver_origin_lat as driver_origin_lat,
  mez.driver_origin_lng as driver_origin_lng,
  mez.pickup_lat as pickup_lat,
  mez.pickup_lng as pickup_lng,
  mez.dropoff_lat as dropoff_lat,
  mez.dropoff_lng as dropoff_lng
FROM
  (
    SELECT
      base.uuid as uuid,
      base.accepted_lat as driver_origin_lat,
      base.accepted_lng as driver_origin_lng,
      base.begintrip_lat as pickup_lat,
      base.begintrip_lng as pickup_lng,
      base.dropoff_lat as dropoff_lat,
      base.dropoff_lng as dropoff_lng
    FROM
      rawdata.schemaless_mezzanine_trips_rows
    WHERE
      datestr = '{datestr}'
      and base.city_id = {city_id}
      and LOWER(base.status) = 'completed'
  ) mez
  JOIN (
    SELECT
      --Request Information
      datestr,
      city_id,
      uuid as trip_uuid,
      session_id as session_uuid,
      driver_uuid,
      --Time Information,
      request_timestamp_local,
      request_timestamp_utc,
      eta,
      client_upfront_fare_usd as fare,
      --Distances and duration of the request,
      trip_duration_seconds / 60 as duration_min
    FROM
      restricted_dwh.fact_trip
    WHERE
      lower(global_product_name) = 'uberx'
      and lower(status) = 'completed'
      and city_id = {city_id}
      and datestr = '{datestr}'
--      and dispatch_type in (NULL, 'fifo')
  ) ft 
ON mez.uuid = ft.trip_uuid
),

idle as (
SELECT
sp.datestr as datestr,
sp.city_id as city_id,
'NA' as trip_uuid,
'NA' as session_uuid,
sp.earner_uuid as driver_uuid,
sp.start_timestamp.`local` as local_time,
CAST(0 as DOUBLE) as eta,
CAST(0 as DOUBLE) as fare,
CAST(0.5 as DOUBLE) as duration_min,
sp.location.lat as driver_origin_lat,
sp.location.lng as driver_origin_lng,
sp.location.lat as pickup_lat,
sp.location.lng as pickup_lng,
sp.location.lat as dropoff_lat,
sp.location.lng as dropoff_lng
FROM
driver.fact_earner_supply_minute as sp
WHERE
sp.datestr = '{datestr}'
and LOWER(sp.flow_type) IN ('uberx', 'p2p')
and sp.city_id = {city_id}
and LOWER(sp.earner_state)='open'
and substr(sp.earner_uuid, 1, length('3')) = '3'
)

SELECT
    datestr,
    city_id, 
    trip_uuid,
    session_uuid,
    driver_uuid,
    local_time,
    eta,
    fare,
    duration_min,
    --   Driver origin information
    driver_origin_lat,
    driver_origin_lng,
    pickup_lat,
    pickup_lng,
    dropoff_lat,
    dropoff_lng
FROM
    completed
UNION
SELECT
    datestr,
    city_id,
    trip_uuid,
    session_uuid,
    driver_uuid,
    local_time,
    eta,
    fare,
    duration_min,
    driver_origin_lat,
    driver_origin_lng,
    pickup_lat,
    pickup_lng,
    dropoff_lat,
    dropoff_lng
FROM
    idle


"""

In [8]:
# city_id, num_days, datestr

@dataclass
class Query:
    prefix: str
    city_id: int
    datestr: str
    num_days: int
    
    def __post_init__(self):
        self.name = f'{self.prefix}_city{self.city_id}_{self.datestr}'
        self.qry = QUERY.format(city_id=self.city_id, datestr=self.datestr)
        
class MyDataFetcher(DataFetcher):
    def query_many_presto(self, *args, **kwargs):
        return super().query_many_presto(*args, **kwargs)        

In [9]:

queries = [Query(prefix=prefix, city_id=city_id, datestr=datestr, num_days=1) for datestr in datestrs]

cache_qry_map = {q.name: q.qry for q in queries}

cdf = CachedDataFetcher(
    data_fetcher=MyDataFetcher(
        user_email=USER_EMAIL,
        consumer_name=CONSUMER_NAME,
    ),
    cache_qry_map=cache_qry_map,
    datacenter='phx2',
    datasource='hive-secure',
)

cdf.fetch(bust_cache=False)

Loaded 25/25 dataframes from cache!


In [10]:
scans = pd.concat(cdf.dfs.values(), axis=0, ignore_index=True) 


In [11]:
cols = scans.columns
cols = [i.split('.')[1] for i in cols]
scans.columns = cols

In [12]:
scans['duration_min'][scans['eta'] == 0.0] = idle_duration_seconds / 60

In [13]:
# Calculate new objective function
def clean_df(df):
    df = df[df['fare'].notnull()]
    df['driver_origin_geohash8'] = scans.h3.geo_to_h3(8, lat_col = 'driver_origin_lat', lng_col = 'driver_origin_lng').index
    df['pickup_geohash8'] = scans.h3.geo_to_h3(8, lat_col = 'pickup_lat', lng_col = 'pickup_lng').index
    df['dropoff_geohash8'] = scans.h3.geo_to_h3(8, lat_col = 'dropoff_lat', lng_col = 'dropoff_lng').index
    df['driver_origin_geohash7'] = scans.h3.geo_to_h3(7, lat_col = 'driver_origin_lat', lng_col = 'driver_origin_lng').index
    df['pickup_geohash7'] = scans.h3.geo_to_h3(7, lat_col = 'pickup_lat', lng_col = 'pickup_lng').index
    df['dropoff_geohash7'] = scans.h3.geo_to_h3(7, lat_col = 'dropoff_lat', lng_col = 'dropoff_lng').index    
    df = df[df['driver_origin_lng'] < MAX_LNG]
    df = df[df['driver_origin_lng'] > MIN_LNG]
    df = df[df['driver_origin_lat'] < MAX_LAT]
    df = df[df['driver_origin_lat'] > MIN_LAT]
    df['local_time'] = pd.to_datetime(df.local_time)
    df['weekday_origin'] = df.local_time.dt.dayofweek
    df['weekday_dropoff'] = df.local_time.dt.dayofweek
    df['second_in_day'] = df.local_time.dt.hour * 3600 + \
                          df.local_time.dt.minute * 60 + \
                          df.local_time.dt.second
    df['trip_duration_seconds'] = df.duration_min * 60
    df['total_driver_trip_time'] = df.trip_duration_seconds + df.eta
    df['destination_arrival_time'] = df.total_driver_trip_time + df.second_in_day
    df['destination_arrival_time'][df['eta'] == 0] = df['second_in_day']
    mask = df['destination_arrival_time'] > 24 * 3600
    df['destination_arrival_time'] = df['destination_arrival_time'].mod(24 * 3600)
    df['weekday_dropoff'][mask] = (df['weekday_dropoff'][mask] + 1) % 7
#     df['trip_length'][df['trip_length'] <= 100] = 100
#     df = df.drop_duplicates(subset=['job_uuid', 'supply_uuid'])
#     df = df.dropna()
    return df


In [14]:
df = clean_df(scans)

INFO:jaeger_tracing:Tracing sampler started with sampling refresh interval 60 sec


In [15]:
geohashes8 = list(set(pd.concat([df['pickup_geohash8'], df['dropoff_geohash8']])))
geohashes8 += ['UNK']

geohash8_to_int = {geohashes8[i]: i for i in range(len(geohashes8))}

df['driver_origin_geohash8'] = df['driver_origin_geohash8'].apply(lambda x: geohash8_to_int.get(x, len(geohashes8) - 1))
df['pickup_geohash8'] = df['pickup_geohash8'].apply(lambda x: geohash8_to_int.get(x, len(geohashes8) - 1))
df['dropoff_geohash8'] = df['dropoff_geohash8'].apply(lambda x: geohash8_to_int.get(x, len(geohashes8) - 1))

print(f'number of unique geohash8: {len(geohashes8)}')

with open(f'geohash8Map_{city_id}.pkl', 'wb') as file:
    pickle.dump(geohash8_to_int, file)

number of unique geohash8: 10057


In [16]:
geohashes7 = list(set(pd.concat([df['pickup_geohash7'], df['dropoff_geohash7']])))
geohashes7 += ['UNK']

geohash7_to_int = {geohashes7[i]: i for i in range(len(geohashes7))}

df['driver_origin_geohash7'] = df['driver_origin_geohash7'].apply(lambda x: geohash7_to_int.get(x, len(geohashes7) - 1))
df['pickup_geohash7'] = df['pickup_geohash7'].apply(lambda x: geohash7_to_int.get(x, len(geohashes7) - 1))
df['dropoff_geohash7'] = df['dropoff_geohash7'].apply(lambda x: geohash7_to_int.get(x, len(geohashes7) - 1))

print(f'number of unique geohash7: {len(geohashes7)}')

with open(f'geohash7Map_{city_id}.pkl', 'wb') as file:
    pickle.dump(geohash7_to_int, file)

number of unique geohash7: 3119


In [17]:
df.head()

Unnamed: 0,datestr,city_id,trip_uuid,session_uuid,driver_uuid,local_time,eta,fare,duration_min,driver_origin_lat,driver_origin_lng,pickup_lat,pickup_lng,dropoff_lat,dropoff_lng,driver_origin_geohash8,pickup_geohash8,dropoff_geohash8,driver_origin_geohash7,pickup_geohash7,dropoff_geohash7,weekday_origin,weekday_dropoff,second_in_day,trip_duration_seconds,total_driver_trip_time,destination_arrival_time
0,2022-11-13,20,,,3d5dd9b0-a590-4c7a-96d6-772a0459279d,2022-11-13 08:31:00,0.0,0.0,0.540374,39.947678,-75.150772,39.947678,-75.150772,39.947678,-75.150772,7192,7192,7192,1679,1679,1679,6,6,30660,32.422459,32.422459,30660.0
1,2022-11-13,20,,,3c2f73fb-2f96-43ce-ba11-0e092f5d9040,2022-11-12 20:54:00,0.0,0.0,0.540374,39.96062,-75.137718,39.96062,-75.137718,39.96062,-75.137718,2986,2986,2986,1609,1609,1609,5,5,75240,32.422459,32.422459,75240.0
2,2022-11-13,20,9f0163ce-e7d0-4ad4-a78d-6bf69a993788,5cf77049-e3a3-42c5-817a-7dc9826cd313,59e76dd3-4d22-4ca3-ae8d-8939f66cd655,2022-11-13 16:33:55,316.0,12.99,11.116667,39.97491,-75.24835,39.97379,-75.2723,39.99556,-75.23424,5130,2304,5480,1586,2804,1050,6,6,59635,667.0,983.0,60618.0
3,2022-11-13,20,,,30f459d2-2247-4b74-b806-f3844c73bd58,2022-11-13 06:56:00,0.0,0.0,0.540374,39.903767,-75.197037,39.903767,-75.197037,39.903767,-75.197037,8066,8066,8066,1687,1687,1687,6,6,24960,32.422459,32.422459,24960.0
4,2022-11-13,20,,,35a64076-da8e-40fe-861f-18a44899faab,2022-11-13 11:30:00,0.0,0.0,0.540374,39.885044,-75.242429,39.885044,-75.242429,39.885044,-75.242429,7210,7210,7210,653,653,653,6,6,41400,32.422459,32.422459,41400.0


In [18]:
len(df[df.weekday_dropoff != df.weekday_origin])

22143

### training offline DQN

In [19]:
from mini_sim.util import *
from mini_sim.DQN_offlineData import *
import torch
from torch import Tensor, nn
from torch.optim import Adam, Optimizer
import torch.nn.functional as F
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import DataLoader
from collections import defaultdict, deque
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from typing import Iterator, List, Tuple


In [20]:
class batchMaker:
    def __init__(self, table) -> None:
        self.table = table
        self.lenDF = len(table)
    
    def __len__(self) -> None:
        return len(self.table)
    
    def sample(self, batch_size: int) -> Tuple:
        sample = self.table.iloc[np.random.choice(self.lenDF, batch_size)]
        batch_v1 = torch.zeros(batch_size, 6)
        batch_v2 = torch.zeros(batch_size, 6)
        batch_eta = torch.zeros(batch_size, 1)
        batch_tripDuration = torch.zeros(batch_size, 1)
        batch_fares = torch.zeros(batch_size, 1)

        ### make samples based on en routes
        batch_v1[:, 0] = torch.tensor(sample.weekday_origin.to_numpy(), dtype=torch.float)
        batch_v1[:, 1] = torch.tensor(sample.second_in_day.to_numpy(), dtype=torch.float)
        batch_v1[:, 2] = torch.tensor(sample.driver_origin_geohash7.to_numpy(), dtype=torch.float)
        batch_v1[:, 3] = torch.tensor(sample.driver_origin_geohash8.to_numpy(), dtype=torch.float)
        batch_v1[:, 4] = torch.tensor(sample.driver_origin_lat.to_numpy(), dtype=torch.float)
        batch_v1[:, 5] = torch.tensor(sample.driver_origin_lng.to_numpy(), dtype=torch.float)
        batch_eta[:,0] = torch.tensor(sample.eta.to_numpy(), dtype=torch.float)
        batch_tripDuration[:,0] = torch.tensor(sample.trip_duration_seconds.to_numpy(), dtype=torch.float)
        batch_fares[:,0] = torch.tensor(sample.fare.to_numpy(), dtype=torch.float)
        batch_v2[:, 0] = torch.tensor(sample.weekday_dropoff.to_numpy(), dtype=torch.float)
        batch_v2[:, 1] = torch.tensor(sample.destination_arrival_time.to_numpy() , dtype=torch.float)
        batch_v2[:, 2] = torch.tensor(sample.dropoff_geohash7.to_numpy(), dtype=torch.float)
        batch_v2[:, 3] = torch.tensor(sample.dropoff_geohash8.to_numpy(), dtype=torch.float)
        batch_v2[:, 4] = torch.tensor(sample.dropoff_lat.to_numpy(), dtype=torch.float)
        batch_v2[:, 5] = torch.tensor(sample.dropoff_lng.to_numpy(), dtype=torch.float)

        return (batch_v1, batch_v2, batch_eta, batch_tripDuration, batch_fares)


In [21]:
class RLDataset(IterableDataset):
    """Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training.

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, batch_maker, sample_size: int = 200) -> None:
        self.sample_size = sample_size
        self.batch_maker = batch_maker

    def __iter__(self) -> Iterator[Tuple]:
        batch_v1, batch_v2, batch_eta, batch_tripDuration, batch_fares = self.batch_maker.sample(self.sample_size)
        for i in range(len(batch_eta)):
            yield batch_v1[i], batch_v2[i], batch_eta[i], batch_tripDuration[i], batch_fares[i]

In [22]:
class DQNLightning(LightningModule):
    """Basic DQN Model."""

    def __init__(
        self,
        batch_size: int = 32,
        lr: float = 1e-4,
        gamma: float = DISCOUNT,
        sync_rate: int = 200,
    ) -> None:
        """
        Args:
            batch_size: size of the batches")
            lr: learning rate
            env: gym environment tag
            gamma: discount factor
            sync_rate: how many frames do we update the target network
        """
        super().__init__()
        self.batch_maker = batchMaker(df)
        
        self.gamma = gamma
        self.batch_size = batch_size
        self.lr = lr
        self.sync_rate = sync_rate
        
        self.net = Net(len(geohashes7), len(geohashes8))
        self.target_net = Net(len(geohashes7), len(geohashes8))

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Calculates the mse loss using a mini batch from the replay buffer.

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, next_states, batch_eta, batch_tripDuration, batch_fares = batch        

        state_values, variance = self.net(states)

        with torch.no_grad():
            next_state_values, _ = self.target_net(next_states)
            next_state_values = next_state_values.detach()

        timeNextState = batch_eta + batch_tripDuration
        discountedNextState = torch.pow(self.gamma, timeNextState) * next_state_values
        timeFareCollected = batch_eta
        discountedFare = torch.pow(self.gamma, timeFareCollected) * batch_fares
        expected_state_values = discountedFare + discountedNextState            

        return nn.GaussianNLLLoss()(state_values, expected_state_values, variance)

    def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch recieved.

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)

        # calculates training loss
        loss = self.dqn_loss(batch)

        # Soft update of target network
        if self.global_step % self.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        self.log_dict(
            {
                "train_loss": loss,
            }
        )
        self.log("steps", self.global_step, logger=False, prog_bar=True)

        return loss

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer = torch.optim.RMSprop(self.net.parameters(), lr=self.lr)
        return optimizer

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.batch_maker, self.batch_size)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0].device.index if self.on_gpu else "cpu"

In [None]:
model = DQNLightning()

trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
    max_epochs=100000,
    val_check_interval=100,
    logger=CSVLogger(save_dir=f"logs/city_id{city_id}"),
)

trainer.fit(model)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type | Params
------------------------------------
0 | net        | Net  | 1.6 M 
1 | target_net | Net  | 1.6 M 
------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.654    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [None]:
!ls logs/city_id20/lightning_logs/version_9/checkpoints


In [None]:
PATH = 'logs/city_id20/lightning_logs/version_9/checkpoints/epoch=99999-step=100000.ckpt'

In [None]:
def get_target_dict(path):
    trained_dict = torch.load(path)['state_dict']
    target_dict = {}
    for k,v in trained_dict.items():
        if k.startswith('target'):
            target_dict[k[11:]] = v
    return target_dict

In [None]:
# model evaluation

tr1 = DQNLightning()
LTSV = Net(len(geohashes7), len(geohashes8))
target_dict = get_target_dict(PATH)
LTSV.load_state_dict(target_dict)


In [None]:

def plotMap(tr, n = 5000, weekday = 2, stat = 'mean'):
    LTSV.eval()
    data = np.zeros((24 * n, 4))
    batch_v1, _, _, _, _ = tr.batch_maker.sample(n)
    batch_v1[:, 0] = weekday
    data[:, 1:3] = batch_v1[:, -2:].repeat(24, 1)
    with torch.no_grad():
        for idx, selectedTime in enumerate(range(0, 24 * 3600, 3600)):
            data[idx * n: (idx + 1) * n, 0] = idx + 1
            batch_v1[:, 1] = selectedTime
            if stat == 'mean':
                out, _ = LTSV(batch_v1)
            elif stat == 'zscore':
                mean, out = LTSV(batch_v1)
                out = mean / torch.sqrt(out)
            elif stat == 'cov':
                mean, out = LTSV(batch_v1)
                out = torch.sqrt(out) / mean
                out = out
            else:
                raise Exception('stat invalid')
            data[idx * n: (idx + 1) * n, -1] = out[:,0]

    hours_df = pd.DataFrame(data = data, columns = ['hour', 'lat', 'lng', 'value'])

    fig = ff.create_hexbin_mapbox(
        data_frame=hours_df, 
        lat="lat",
        lon="lng",
        color='value',
        animation_frame='hour',
        nx_hexagon=100, opacity=0.8,
        min_count=20
    )    

    fig.update_layout(mapbox_style="carto-positron", mapbox_zoom=8, mapbox_center = {"lat": LAT_CENTER, "lon": LON_CENTER},)
    fig.update_layout(margin={"r":0,"t":0,"l":0,"b":0})

    fig.update_layout(autosize=False,width=700,height=700)
    fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 600
    fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 600
    fig.layout.coloraxis.showscale = True   
    fig.layout.sliders[0].pad.t = 10
    fig.layout.updatemenus[0].pad.t= 10             

    fig.show()
    

In [None]:
plotMap(tr1, n = 100000, weekday = 4, stat = 'mean')

In [None]:
# plotMap(tr1, n = 100000, weekday = 4, stat = 'cov')

In [None]:
# (1am, x1, y1) --> (2am, x2, y2) w/ $10
# (4am, x1, y1) --> (4:30am, x3, y3) w/ $8

