In [8]:
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import utils
from models.trajfm import TrajFM
from data import TrajFMDataset, PretrainPadder, fetch_task_padder, X_COL, Y_COL, coord_transform_GPS_UTM
from torch.utils.data import random_split
import warnings
from pipeline import train_user_model, test_user_model
import torch.multiprocessing as mp
warnings.filterwarnings('ignore')

In [9]:
SETTINGS_CACHE_DIR = os.environ.get('SETTINGS_CACHE_DIR', os.path.join('settings', 'cache'))
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', 'saved_model')
LOG_SAVE_DIR = os.environ.get('LOG_SAVE_DIR', 'logs')
PRED_SAVE_DIR = os.environ.get('PRED_SAVE_DIR', 'predictions')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'


if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
device = f'cuda' if torch.cuda.is_available() else 'cpu'

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
# This key is an indicator of multiple things.
datetime_key = utils.get_datetime_key()


with open(os.path.join('settings', f'local_test.json'), 'r') as fp:
    setting = json.load(fp)
    setting = setting[0]
utils.create_if_noexists(SETTINGS_CACHE_DIR)
with open(os.path.join(SETTINGS_CACHE_DIR, f'{datetime_key}.json'), 'w') as fp:
    json.dump(setting, fp)
    
print("device:", device)

device: cuda


In [10]:
SAVE_NAME = setting["save_name"]

train_traj_df = pd.read_hdf(setting['dataset']['train_traj_df'], key='trips')
print("dataset:", setting['dataset']['train_traj_df'])
user_count = len(train_traj_df['user_id'].unique())
traj_count = len(train_traj_df['traj_id'].unique())
traj_len = len(train_traj_df['seq_i'])
tao = train_traj_df['delta_t'].mean()
setting["finetune"]["padder"]["params"]["num_users"] = user_count

scale = 4000
# if "chengdu" in setting['dataset']['train_traj_df']:
#     UTM_region = 48
# if "xian" in setting['dataset']['train_traj_df']:
#     UTM_region = 49
if "geolife" in setting['dataset']['train_traj_df']:
    UTM_region = 50  
train_dataset = TrajFMDataset(traj_df=train_traj_df, UTM_region=UTM_region, scale = scale)

poi_df = pd.read_hdf(setting['dataset']['poi_df'], key='pois')
poi_embed = torch.from_numpy(np.load(setting['dataset']['poi_embed'])).float().to(device)

poi_coors = poi_df[[X_COL, Y_COL]].to_numpy()
poi_coors = (coord_transform_GPS_UTM(poi_coors, UTM_region) - train_dataset.spatial_middle_coord) / scale
poi_coors = torch.tensor(poi_coors).float().to(device)

# Build the learnable model.
trajfm = TrajFM(poi_embed=poi_embed, 
                poi_coors=poi_coors, 
                UTM_region=UTM_region,
                spatial_middle_coord = train_dataset.spatial_middle_coord, 
                scale = scale, 
                **setting['trajfm'],
                user = user_count).to(device)

dataset: ./dataset/geolife_U89_TrajAll_L1000.h5


In [11]:
data_summary = {
    "users": user_count,
    "total_traj": traj_count,           
    "total_points": traj_len,
    "Data Filtering": [ 
        "25th to 75th quartile based on traj_len", 
        "traj_len > 30 points", 
        "delta_t > 1800s",
        "traj/user > 30 traj",
        "resampled traj to 1000 points max",
        "user_number and seq_i recalculated",
        "only transformer used"]
}

for key, value in data_summary.items():
    print(key, ":", value)

users : 89
total_traj : 8317
total_points : 5941128
Data Filtering : ['25th to 75th quartile based on traj_len', 'traj_len > 30 points', 'delta_t > 1800s', 'traj/user > 30 traj', 'resampled traj to 1000 points max', 'user_number and seq_i recalculated', 'only transformer used']


In [12]:
total_size = len(train_dataset)
train_size = int(0.6 * total_size)
val_test_size = total_size - train_size
val_size = int(0.5 * val_test_size)
test_size = val_test_size - val_size

train_dataset, val_test_dataset = random_split(train_dataset, [train_size, val_test_size])
val_dataset, test_dataset = random_split(val_test_dataset, [val_size, test_size])

In [13]:
downstreamtask = setting['finetune']['padder']['name']

file_path = f"{SAVE_NAME}.{downstreamtask}"
if os.path.exists(file_path):
    print(f"Loading model {file_path}")
    trajfm.load_state_dict(torch.load(os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'), map_location=device))
else:
    print("Model not found, starting new")
    
padder = fetch_task_padder(padder_name=setting['finetune']['padder']['name'], padder_params=setting['finetune']['padder']['params'])

train_dataloader = DataLoader(train_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
val_dataloader = DataLoader(val_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
test_dataloader = DataLoader(test_dataset, collate_fn=padder, **setting['finetune']['dataloader'])

Model not found, starting new


In [14]:
train_log, saved_model_state_dict = train_user_model(model=trajfm, 
                                                    train_dataloader=train_dataloader, 
                                                    val_dataloader=val_dataloader,
                                                    device = device, 
                                                    **setting['finetune']['config'],
                                                    data_summary = data_summary)

if setting['finetune'].get('save', False):
    # save model
    utils.create_if_noexists(MODEL_CACHE_DIR)
    torch.save(saved_model_state_dict, os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'))
    
    # save log
    log_dir = os.path.join(LOG_SAVE_DIR, SAVE_NAME)
    utils.create_if_noexists(log_dir)
    log_path = os.path.join(log_dir, f'{SAVE_NAME}_{downstreamtask}.csv')
    file_exists = os.path.exists(log_path)
    train_log.to_csv(log_path, mode='a', header=not file_exists, index=False)

The run id is U89_TrajAll_L1000_v3.3_noROPE


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.77it/s]
Training, avg loss: 3.907:   3%|▎         | 1/30 [01:17<37:33, 77.72s/it]

ACC@1: 18.94%,
ACC@5: 37.28%,
Macro-R: 6.78%,
Macro-P: 1.79%,
Macro-F1: 2.7%,
val_loss 3.755


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]2s/it]
Training, avg loss: 3.738:   7%|▋         | 2/30 [02:34<36:02, 77.24s/it]

ACC@1: 18.82%,
ACC@5: 36.93%,
Macro-R: 6.55%,
Macro-P: 1.76%,
Macro-F1: 2.68%,
val_loss 3.715


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]4s/it]
Training, avg loss: 3.682:  10%|█         | 3/30 [03:51<34:36, 76.91s/it]

ACC@1: 20.99%,
ACC@5: 38.54%,
Macro-R: 7.9%,
Macro-P: 2.54%,
Macro-F1: 3.36%,
val_loss 3.617


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]1s/it]
Training, avg loss: 3.577:  13%|█▎        | 4/30 [05:07<33:16, 76.78s/it]

ACC@1: 20.51%,
ACC@5: 42.58%,
Macro-R: 7.14%,
Macro-P: 2.16%,
Macro-F1: 3.03%,
val_loss 3.524


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]8s/it]
Training, avg loss: 3.484:  17%|█▋        | 5/30 [06:24<31:59, 76.76s/it]

ACC@1: 20.93%,
ACC@5: 45.23%,
Macro-R: 8.0%,
Macro-P: 2.9%,
Macro-F1: 3.59%,
val_loss 3.44


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]6s/it]
Training, avg loss: 3.384:  20%|██        | 6/30 [07:41<30:43, 76.80s/it]

ACC@1: 24.3%,
ACC@5: 47.39%,
Macro-R: 10.06%,
Macro-P: 6.1%,
Macro-F1: 6.28%,
val_loss 3.351


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]0s/it]
Training, avg loss: 3.317:  23%|██▎       | 7/30 [08:58<29:25, 76.76s/it]

ACC@1: 28.75%,
ACC@5: 51.85%,
Macro-R: 13.99%,
Macro-P: 9.33%,
Macro-F1: 9.79%,
val_loss 3.225


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]6s/it]
Training, avg loss: 3.102:  27%|██▋       | 8/30 [10:14<28:07, 76.71s/it]

ACC@1: 33.79%,
ACC@5: 55.8%,
Macro-R: 18.58%,
Macro-P: 13.72%,
Macro-F1: 14.32%,
val_loss 3.022


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]1s/it]
Training, avg loss: 2.956:  30%|███       | 9/30 [11:31<26:50, 76.71s/it]

ACC@1: 36.85%,
ACC@5: 58.09%,
Macro-R: 21.76%,
Macro-P: 16.48%,
Macro-F1: 17.36%,
val_loss 2.905


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]1s/it]
Training, avg loss: 2.850:  33%|███▎      | 10/30 [12:48<25:35, 76.77s/it]

ACC@1: 37.4%,
ACC@5: 59.78%,
Macro-R: 22.36%,
Macro-P: 16.01%,
Macro-F1: 17.06%,
val_loss 2.81


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]77s/it]
Training, avg loss: 2.748:  37%|███▋      | 11/30 [14:05<24:20, 76.85s/it]

ACC@1: 39.51%,
ACC@5: 60.08%,
Macro-R: 23.71%,
Macro-P: 18.01%,
Macro-F1: 18.91%,
val_loss 2.77


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.77it/s]85s/it]
Training, avg loss: 2.668:  40%|████      | 12/30 [15:22<23:03, 76.86s/it]

ACC@1: 41.25%,
ACC@5: 62.77%,
Macro-R: 25.9%,
Macro-P: 20.37%,
Macro-F1: 21.18%,
val_loss 2.688


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.79it/s]86s/it]
Training, avg loss: 2.584:  43%|████▎     | 13/30 [16:39<21:47, 76.89s/it]

ACC@1: 41.32%,
ACC@5: 66.03%,
Macro-R: 26.55%,
Macro-P: 20.93%,
Macro-F1: 21.9%,
val_loss 2.574


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]89s/it]
Training, avg loss: 2.534:  47%|████▋     | 14/30 [17:56<20:30, 76.91s/it]

ACC@1: 43.77%,
ACC@5: 66.32%,
Macro-R: 28.38%,
Macro-P: 21.81%,
Macro-F1: 23.16%,
val_loss 2.518


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]91s/it]
Training, avg loss: 2.442:  50%|█████     | 15/30 [19:12<19:13, 76.89s/it]

ACC@1: 46.12%,
ACC@5: 68.42%,
Macro-R: 31.09%,
Macro-P: 24.94%,
Macro-F1: 26.11%,
val_loss 2.47


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]89s/it]
Training, avg loss: 2.388:  53%|█████▎    | 16/30 [20:30<17:57, 76.96s/it]

ACC@1: 45.46%,
ACC@5: 69.15%,
Macro-R: 30.52%,
Macro-P: 25.56%,
Macro-F1: 26.13%,
val_loss 2.418


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]96s/it]
Training, avg loss: 2.331:  57%|█████▋    | 17/30 [21:47<16:40, 76.98s/it]

ACC@1: 49.07%,
ACC@5: 69.88%,
Macro-R: 33.51%,
Macro-P: 28.37%,
Macro-F1: 29.19%,
val_loss 2.357


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]98s/it]
Training, avg loss: 2.271:  60%|██████    | 18/30 [23:04<15:24, 77.01s/it]

ACC@1: 47.86%,
ACC@5: 71.19%,
Macro-R: 32.35%,
Macro-P: 28.1%,
Macro-F1: 28.52%,
val_loss 2.317


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]01s/it]
Training, avg loss: 2.224:  63%|██████▎   | 19/30 [24:21<14:07, 77.04s/it]

ACC@1: 49.37%,
ACC@5: 70.23%,
Macro-R: 33.72%,
Macro-P: 29.62%,
Macro-F1: 30.01%,
val_loss 2.286


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]04s/it]
Training, avg loss: 2.173:  67%|██████▋   | 20/30 [25:38<12:49, 76.97s/it]

ACC@1: 50.08%,
ACC@5: 72.1%,
Macro-R: 33.84%,
Macro-P: 28.65%,
Macro-F1: 29.54%,
val_loss 2.222


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]97s/it]
Training, avg loss: 2.083:  70%|███████   | 21/30 [26:55<11:33, 77.05s/it]

ACC@1: 51.89%,
ACC@5: 73.42%,
Macro-R: 36.3%,
Macro-P: 32.52%,
Macro-F1: 32.74%,
val_loss 2.182


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]05s/it]
Training, avg loss: 2.073:  73%|███████▎  | 22/30 [28:12<10:16, 77.03s/it]

ACC@1: 52.2%,
ACC@5: 73.9%,
Macro-R: 35.77%,
Macro-P: 31.16%,
Macro-F1: 31.85%,
val_loss 2.176


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]03s/it]
Training, avg loss: 2.065:  77%|███████▋  | 23/30 [29:29<09:00, 77.16s/it]

ACC@1: 52.25%,
ACC@5: 74.2%,
Macro-R: 36.68%,
Macro-P: 32.27%,
Macro-F1: 33.0%,
val_loss 2.173


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]16s/it]
Training, avg loss: 2.061:  80%|████████  | 24/30 [30:46<07:43, 77.17s/it]

ACC@1: 52.38%,
ACC@5: 74.21%,
Macro-R: 36.89%,
Macro-P: 32.48%,
Macro-F1: 33.09%,
val_loss 2.171


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]17s/it]
Training, avg loss: 2.055:  83%|████████▎ | 25/30 [32:04<06:26, 77.20s/it]

ACC@1: 53.27%,
ACC@5: 74.56%,
Macro-R: 37.19%,
Macro-P: 32.95%,
Macro-F1: 33.27%,
val_loss 2.165


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]20s/it]
Training, avg loss: 2.049:  87%|████████▋ | 26/30 [33:21<05:08, 77.15s/it]

ACC@1: 52.74%,
ACC@5: 74.51%,
Macro-R: 37.59%,
Macro-P: 32.51%,
Macro-F1: 33.43%,
val_loss 2.159


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]15s/it]
Training, avg loss: 2.042:  90%|█████████ | 27/30 [34:37<03:51, 77.02s/it]

ACC@1: 53.34%,
ACC@5: 74.39%,
Macro-R: 37.34%,
Macro-P: 32.76%,
Macro-F1: 33.54%,
val_loss 2.16


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]02s/it]
Training, avg loss: 2.039:  93%|█████████▎| 28/30 [35:54<02:34, 77.00s/it]

ACC@1: 53.64%,
ACC@5: 74.68%,
Macro-R: 37.78%,
Macro-P: 32.94%,
Macro-F1: 33.64%,
val_loss 2.148


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.78it/s]00s/it]
Training, avg loss: 2.034:  97%|█████████▋| 29/30 [37:11<01:17, 77.01s/it]

ACC@1: 52.97%,
ACC@5: 74.25%,
Macro-R: 37.5%,
Macro-P: 32.79%,
Macro-F1: 33.57%,
val_loss 2.145


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]01s/it]
Training, avg loss: 2.028: 100%|██████████| 30/30 [38:28<00:00, 76.96s/it]

ACC@1: 53.1%,
ACC@5: 74.69%,
Macro-R: 37.13%,
Macro-P: 32.77%,
Macro-F1: 33.44%,
val_loss 2.142





In [15]:
print(f"loading {SAVE_NAME}.{downstreamtask}")
trajfm.load_state_dict(torch.load(os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'), map_location=device))

metrics, _ = test_user_model(model=trajfm, dataloader=test_dataloader, device = device)
for key, value in metrics.items():
    print(f"{key}: {round(value * 100, 2)}%,")

loading U89_TrajAll_L1000_v3.3_noROPE.tul


Testing/Validating: 100%|██████████| 52/52 [00:29<00:00,  1.79it/s]

ACC@1: 50.12%,
ACC@5: 74.7%,
Macro-R: 35.94%,
Macro-P: 30.87%,
Macro-F1: 31.64%,





In [16]:
df = pd.DataFrame([{
    "Model": f"{SAVE_NAME}",
    **{key: round(value * 100, 2) for key, value in metrics.items()}
}])

csv_path = "logs/test.csv"
if os.path.exists(csv_path):
    df.to_csv(csv_path, mode='a', header=False, index=False)
else:
    df.to_csv(csv_path, index=False)