In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import utils
from models.trajfm import TrajFM
from data import TrajFMDataset, PretrainPadder, fetch_task_padder, X_COL, Y_COL, coord_transform_GPS_UTM
from torch.utils.data import random_split
import warnings
from pipeline import train_user_model, test_user_model
import torch.multiprocessing as mp
warnings.filterwarnings('ignore')

In [2]:
SETTINGS_CACHE_DIR = os.environ.get('SETTINGS_CACHE_DIR', os.path.join('settings', 'cache'))
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', 'saved_model')
LOG_SAVE_DIR = os.environ.get('LOG_SAVE_DIR', 'logs')
PRED_SAVE_DIR = os.environ.get('PRED_SAVE_DIR', 'predictions')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'


if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
device = f'cuda' if torch.cuda.is_available() else 'cpu'

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
# This key is an indicator of multiple things.
datetime_key = utils.get_datetime_key()


with open(os.path.join('settings', f'local_test.json'), 'r') as fp:
    setting = json.load(fp)
    setting = setting[0]
utils.create_if_noexists(SETTINGS_CACHE_DIR)
with open(os.path.join(SETTINGS_CACHE_DIR, f'{datetime_key}.json'), 'w') as fp:
    json.dump(setting, fp)
    
print("device:", device)

device: cuda


In [3]:
# pyright: ignore[reportIndexIssue]
SAVE_NAME = setting["save_name"]

train_traj_df = pd.read_hdf(setting['dataset']['train_traj_df'], key='trips')
print("dataset:", setting['dataset']['train_traj_df'])
user_count = len(train_traj_df['user_id'].unique())
traj_count = len(train_traj_df['traj_id'].unique())
traj_len = len(train_traj_df['seq_i'])
tao = train_traj_df['delta_t'].mean()
setting["finetune"]["padder"]["params"]["num_users"] = user_count

scale = 4000
# if "chengdu" in setting['dataset']['train_traj_df']:
#     UTM_region = 48
# if "xian" in setting['dataset']['train_traj_df']:
#     UTM_region = 49
if "geolife" in setting['dataset']['train_traj_df']:
    UTM_region = 50  
train_dataset = TrajFMDataset(traj_df=train_traj_df, UTM_region=UTM_region, scale = scale)

poi_df = pd.read_hdf(setting['dataset']['poi_df'], key='pois')
poi_embed = torch.from_numpy(np.load(setting['dataset']['poi_embed'])).float().to(device)

poi_coors = poi_df[[X_COL, Y_COL]].to_numpy()
poi_coors = (coord_transform_GPS_UTM(poi_coors, UTM_region) - train_dataset.spatial_middle_coord) / scale
poi_coors = torch.tensor(poi_coors).float().to(device)

# Build the learnable model.
trajfm = TrajFM(poi_embed=poi_embed, 
                poi_coors=poi_coors, 
                UTM_region=UTM_region,
                spatial_middle_coord = train_dataset.spatial_middle_coord, 
                scale = scale, 
                **setting['trajfm'],
                user = user_count).to(device)

dataset: ./dataset/geolife_U56_TrajAll_L1000.h5


In [4]:
data_summary = {
    "users": user_count,
    "total_traj": traj_count,           
    "total_points": traj_len,
    "avg_traj_len": f"{round(tao/3600, 2)} hours",
    "Data Filtering": [ 
        "25th to 75th quartile based on traj_len", 
        "traj_len > 30 points", 
        "delta_t > 1800s",
        "traj/user > 35 traj",
        "resampled traj to 1000 points max",
        "user_number and seq_i recalculated",]
}

for key, value in data_summary.items():
    print(key, ":", value)

users : 56
total_traj : 7602
total_points : 5450141
avg_traj_len : 1.62 hours
Data Filtering : ['25th to 75th quartile based on traj_len', 'traj_len > 30 points', 'delta_t > 1800s', 'traj/user > 35 traj', 'resampled traj to 1000 points max', 'user_number and seq_i recalculated']


In [5]:
total_size = len(train_dataset)
train_size = int(0.6 * total_size)
val_test_size = total_size - train_size
val_size = int(0.5 * val_test_size)
test_size = val_test_size - val_size

train_dataset, val_test_dataset = random_split(train_dataset, [train_size, val_test_size])
val_dataset, test_dataset = random_split(val_test_dataset, [val_size, test_size])

In [6]:
downstreamtask = setting['finetune']['padder']['name']    
padder = fetch_task_padder(padder_name=setting['finetune']['padder']['name'], padder_params=setting['finetune']['padder']['params'])

train_dataloader = DataLoader(train_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
val_dataloader = DataLoader(val_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
test_dataloader = DataLoader(test_dataset, collate_fn=padder, **setting['finetune']['dataloader'])

In [None]:
file_path = f"saved_model/{SAVE_NAME}.{downstreamtask}"
if os.path.exists(file_path):
    print(f"Loading model {file_path}")
    trajfm.load_state_dict(torch.load(os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'), map_location=device))
else:
    print("Model not found, starting new")

train_log, saved_model_state_dict = train_user_model(model=trajfm, 
                                                    train_dataloader=train_dataloader, 
                                                    val_dataloader=val_dataloader,
                                                    device = device, 
                                                    **setting['finetune']['config'],
                                                    data_summary = data_summary)

if setting['finetune'].get('save', False):
    # save model
    utils.create_if_noexists(MODEL_CACHE_DIR)
    torch.save(saved_model_state_dict, os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'))
    
    # save log
    log_dir = os.path.join(LOG_SAVE_DIR, SAVE_NAME)
    utils.create_if_noexists(log_dir)
    log_path = os.path.join(log_dir, f'{SAVE_NAME}_{downstreamtask}.csv')
    file_exists = os.path.exists(log_path)
    train_log.to_csv(log_path, mode='a', header=not file_exists, index=False)

Model not found, starting new


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meuj01[0m ([33mSP_001[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The run id is U56_TrajAll_L1000_v5.2


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.75it/s]
Training, avg loss: 3.668:   2%|▏         | 1/60 [01:19<1:17:42, 79.02s/it]

ACC@1: 18.36%,
ACC@5: 38.48%,
Macro-R: 6.74%,
Macro-P: 1.93%,
Macro-F1: 2.9%,
val_loss 3.526


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.75it/s].02s/it]
Training, avg loss: 3.454:   3%|▎         | 2/60 [02:36<1:15:43, 78.34s/it]

ACC@1: 18.36%,
ACC@5: 39.52%,
Macro-R: 5.62%,
Macro-P: 1.62%,
Macro-F1: 2.26%,
val_loss 3.455


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.76it/s].34s/it]
Training, avg loss: 3.355:   5%|▌         | 3/60 [03:54<1:14:00, 77.90s/it]

ACC@1: 19.92%,
ACC@5: 40.49%,
Macro-R: 6.68%,
Macro-P: 2.1%,
Macro-F1: 2.94%,
val_loss 3.342


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].90s/it]
Training, avg loss: 3.275:   7%|▋         | 4/60 [05:12<1:12:47, 77.99s/it]

ACC@1: 21.74%,
ACC@5: 43.1%,
Macro-R: 8.03%,
Macro-P: 3.22%,
Macro-F1: 3.91%,
val_loss 3.244


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s].99s/it]
Training, avg loss: 3.224:   8%|▊         | 5/60 [06:30<1:11:29, 78.00s/it]

ACC@1: 22.59%,
ACC@5: 43.55%,
Macro-R: 8.51%,
Macro-P: 3.68%,
Macro-F1: 4.64%,
val_loss 3.215


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s].00s/it]
Training, avg loss: 3.175:  10%|█         | 6/60 [07:48<1:10:20, 78.16s/it]

ACC@1: 19.73%,
ACC@5: 42.58%,
Macro-R: 7.7%,
Macro-P: 3.4%,
Macro-F1: 4.33%,
val_loss 3.292




In [None]:
metrics, _ = test_user_model(model=trajfm, dataloader=test_dataloader, device = device)
for key, value in metrics.items():
    print(f"{key}: {round(value * 100, 2)}%,")

In [9]:
df = pd.DataFrame([{
    "Model": f"{SAVE_NAME}",
    **{key: round(value * 100, 2) for key, value in metrics.items()}
}])

csv_path = "logs/test.csv"
if os.path.exists(csv_path):
    df.to_csv(csv_path, mode='a', header=False, index=False)
else:
    df.to_csv(csv_path, index=False)