In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import utils
from models.trajfm import TrajFM
from data import TrajFMDataset, PretrainPadder, fetch_task_padder, X_COL, Y_COL, coord_transform_GPS_UTM
from torch.utils.data import random_split
import warnings
from pipeline import train_user_model, test_user_model
import torch.multiprocessing as mp
warnings.filterwarnings('ignore')

In [2]:
SETTINGS_CACHE_DIR = os.environ.get('SETTINGS_CACHE_DIR', os.path.join('settings', 'cache'))
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', 'saved_model')
LOG_SAVE_DIR = os.environ.get('LOG_SAVE_DIR', 'logs')
PRED_SAVE_DIR = os.environ.get('PRED_SAVE_DIR', 'predictions')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'


if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
device = f'cuda' if torch.cuda.is_available() else 'cpu'

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
# This key is an indicator of multiple things.
datetime_key = utils.get_datetime_key()


with open(os.path.join('settings', f'local_test.json'), 'r') as fp:
    setting = json.load(fp)
    setting = setting[0]
utils.create_if_noexists(SETTINGS_CACHE_DIR)
with open(os.path.join(SETTINGS_CACHE_DIR, f'{datetime_key}.json'), 'w') as fp:
    json.dump(setting, fp)
    
print("device:", device)

device: cuda


In [3]:
# pyright: ignore[reportIndexIssue]
SAVE_NAME = setting["save_name"]

train_traj_df = pd.read_hdf(setting['dataset']['train_traj_df'], key='trips')
print("dataset:", setting['dataset']['train_traj_df'])
user_count = len(train_traj_df['user_id'].unique())
traj_count = len(train_traj_df['traj_id'].unique())
traj_len = len(train_traj_df['seq_i'])
tao = train_traj_df['delta_t'].mean()
setting["finetune"]["padder"]["params"]["num_users"] = user_count

scale = 4000
# if "chengdu" in setting['dataset']['train_traj_df']:
#     UTM_region = 48
# if "xian" in setting['dataset']['train_traj_df']:
#     UTM_region = 49
if "geolife" in setting['dataset']['train_traj_df']:
    UTM_region = 50  
train_dataset = TrajFMDataset(traj_df=train_traj_df, UTM_region=UTM_region, scale = scale)

poi_df = pd.read_hdf(setting['dataset']['poi_df'], key='pois')
poi_embed = torch.from_numpy(np.load(setting['dataset']['poi_embed'])).float().to(device)

poi_coors = poi_df[[X_COL, Y_COL]].to_numpy()
poi_coors = (coord_transform_GPS_UTM(poi_coors, UTM_region) - train_dataset.spatial_middle_coord) / scale
poi_coors = torch.tensor(poi_coors).float().to(device)

# Build the learnable model.
trajfm = TrajFM(poi_embed=poi_embed, 
                poi_coors=poi_coors, 
                UTM_region=UTM_region,
                spatial_middle_coord = train_dataset.spatial_middle_coord, 
                scale = scale, 
                **setting['trajfm'],
                user = user_count).to(device)

dataset: ./dataset/geolife_U89_TrajAll_L1000.h5


In [4]:
data_summary = {
    "users": user_count,
    "total_traj": traj_count,           
    "total_points": traj_len,
    "Data Filtering": [ 
        "25th to 75th quartile based on traj_len", 
        "traj_len > 30 points", 
        "delta_t > 1800s",
        "traj/user > 30 traj",
        "resampled traj to 1000 points max",
        "user_number and seq_i recalculated",
        "only transformer used"]
}

for key, value in data_summary.items():
    print(key, ":", value)

users : 89
total_traj : 8317
total_points : 5941128
Data Filtering : ['25th to 75th quartile based on traj_len', 'traj_len > 30 points', 'delta_t > 1800s', 'traj/user > 30 traj', 'resampled traj to 1000 points max', 'user_number and seq_i recalculated', 'only transformer used']


In [5]:
total_size = len(train_dataset)
train_size = int(0.6 * total_size)
val_test_size = total_size - train_size
val_size = int(0.5 * val_test_size)
test_size = val_test_size - val_size

train_dataset, val_test_dataset = random_split(train_dataset, [train_size, val_test_size])
val_dataset, test_dataset = random_split(val_test_dataset, [val_size, test_size])

In [6]:
downstreamtask = setting['finetune']['padder']['name']

file_path = f"saved_model/{SAVE_NAME}.{downstreamtask}"
if os.path.exists(file_path):
    print(f"Loading model {file_path}")
    trajfm.load_state_dict(torch.load(os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'), map_location=device))
else:
    print("Model not found, starting new")
    
padder = fetch_task_padder(padder_name=setting['finetune']['padder']['name'], padder_params=setting['finetune']['padder']['params'])

train_dataloader = DataLoader(train_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
val_dataloader = DataLoader(val_dataset, collate_fn=padder, **setting['finetune']['dataloader'])
test_dataloader = DataLoader(test_dataset, collate_fn=padder, **setting['finetune']['dataloader'])

Model not found, starting new


In [7]:
train_log, saved_model_state_dict = train_user_model(model=trajfm, 
                                                    train_dataloader=train_dataloader, 
                                                    val_dataloader=val_dataloader,
                                                    device = device, 
                                                    **setting['finetune']['config'],
                                                    data_summary = data_summary)

if setting['finetune'].get('save', False):
    # save model
    utils.create_if_noexists(MODEL_CACHE_DIR)
    torch.save(saved_model_state_dict, os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}'))
    
    # save log
    log_dir = os.path.join(LOG_SAVE_DIR, SAVE_NAME)
    utils.create_if_noexists(log_dir)
    log_path = os.path.join(log_dir, f'{SAVE_NAME}_{downstreamtask}.csv')
    file_exists = os.path.exists(log_path)
    train_log.to_csv(log_path, mode='a', header=not file_exists, index=False)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meuj01[0m ([33mSP_001[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The run id is U89_TrajAll_L1000_v3.7_noROPE


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]
Training, avg loss: 3.907:   1%|▏         | 1/70 [01:17<1:28:48, 77.22s/it]

ACC@1: 18.94%,
ACC@5: 37.28%,
Macro-R: 6.78%,
Macro-P: 1.79%,
Macro-F1: 2.7%,
val_loss 3.755


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s].22s/it]
Training, avg loss: 3.738:   3%|▎         | 2/70 [02:33<1:27:11, 76.94s/it]

ACC@1: 18.82%,
ACC@5: 36.93%,
Macro-R: 6.55%,
Macro-P: 1.76%,
Macro-F1: 2.68%,
val_loss 3.715


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s].94s/it]
Training, avg loss: 3.682:   4%|▍         | 3/70 [03:50<1:25:41, 76.74s/it]

ACC@1: 20.99%,
ACC@5: 38.54%,
Macro-R: 7.9%,
Macro-P: 2.54%,
Macro-F1: 3.36%,
val_loss 3.617


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s].74s/it]
Training, avg loss: 3.577:   6%|▌         | 4/70 [05:06<1:24:15, 76.60s/it]

ACC@1: 20.51%,
ACC@5: 42.58%,
Macro-R: 7.14%,
Macro-P: 2.16%,
Macro-F1: 3.03%,
val_loss 3.524


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s].60s/it]
Training, avg loss: 3.484:   7%|▋         | 5/70 [06:23<1:22:50, 76.47s/it]

ACC@1: 20.93%,
ACC@5: 45.23%,
Macro-R: 8.0%,
Macro-P: 2.9%,
Macro-F1: 3.59%,
val_loss 3.44


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s].47s/it]
Training, avg loss: 3.384:   9%|▊         | 6/70 [07:39<1:21:39, 76.55s/it]

ACC@1: 24.3%,
ACC@5: 47.39%,
Macro-R: 10.06%,
Macro-P: 6.1%,
Macro-F1: 6.28%,
val_loss 3.351


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s].55s/it]
Training, avg loss: 3.317:  10%|█         | 7/70 [08:55<1:20:14, 76.41s/it]

ACC@1: 28.75%,
ACC@5: 51.85%,
Macro-R: 13.99%,
Macro-P: 9.33%,
Macro-F1: 9.79%,
val_loss 3.225


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s].41s/it]
Training, avg loss: 3.102:  11%|█▏        | 8/70 [10:12<1:19:00, 76.45s/it]

ACC@1: 33.79%,
ACC@5: 55.8%,
Macro-R: 18.58%,
Macro-P: 13.72%,
Macro-F1: 14.32%,
val_loss 3.022


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s].45s/it]
Training, avg loss: 2.956:  13%|█▎        | 9/70 [11:28<1:17:44, 76.46s/it]

ACC@1: 36.85%,
ACC@5: 58.09%,
Macro-R: 21.76%,
Macro-P: 16.48%,
Macro-F1: 17.36%,
val_loss 2.905


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s].46s/it]
Training, avg loss: 2.850:  14%|█▍        | 10/70 [12:45<1:16:30, 76.51s/it]

ACC@1: 37.4%,
ACC@5: 59.78%,
Macro-R: 22.36%,
Macro-P: 16.01%,
Macro-F1: 17.06%,
val_loss 2.81


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.51s/it]
Training, avg loss: 2.748:  16%|█▌        | 11/70 [14:01<1:15:10, 76.45s/it]

ACC@1: 39.51%,
ACC@5: 60.08%,
Macro-R: 23.71%,
Macro-P: 18.01%,
Macro-F1: 18.91%,
val_loss 2.77


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.45s/it]
Training, avg loss: 2.668:  17%|█▋        | 12/70 [15:17<1:13:47, 76.34s/it]

ACC@1: 41.25%,
ACC@5: 62.77%,
Macro-R: 25.9%,
Macro-P: 20.37%,
Macro-F1: 21.18%,
val_loss 2.688


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]6.34s/it]
Training, avg loss: 2.584:  19%|█▊        | 13/70 [16:34<1:12:28, 76.30s/it]

ACC@1: 41.32%,
ACC@5: 66.03%,
Macro-R: 26.55%,
Macro-P: 20.93%,
Macro-F1: 21.9%,
val_loss 2.574


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.30s/it]
Training, avg loss: 2.534:  20%|██        | 14/70 [17:50<1:11:17, 76.38s/it]

ACC@1: 43.77%,
ACC@5: 66.32%,
Macro-R: 28.38%,
Macro-P: 21.81%,
Macro-F1: 23.16%,
val_loss 2.518


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]6.38s/it]
Training, avg loss: 2.442:  21%|██▏       | 15/70 [19:07<1:10:07, 76.49s/it]

ACC@1: 46.12%,
ACC@5: 68.42%,
Macro-R: 31.09%,
Macro-P: 24.94%,
Macro-F1: 26.11%,
val_loss 2.47


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.49s/it]
Training, avg loss: 2.388:  23%|██▎       | 16/70 [20:23<1:08:45, 76.40s/it]

ACC@1: 45.46%,
ACC@5: 69.15%,
Macro-R: 30.52%,
Macro-P: 25.56%,
Macro-F1: 26.13%,
val_loss 2.418


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.40s/it]
Training, avg loss: 2.331:  24%|██▍       | 17/70 [21:40<1:07:31, 76.44s/it]

ACC@1: 49.07%,
ACC@5: 69.88%,
Macro-R: 33.51%,
Macro-P: 28.37%,
Macro-F1: 29.19%,
val_loss 2.357


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.44s/it]
Training, avg loss: 2.271:  26%|██▌       | 18/70 [22:56<1:06:13, 76.41s/it]

ACC@1: 47.86%,
ACC@5: 71.19%,
Macro-R: 32.35%,
Macro-P: 28.1%,
Macro-F1: 28.52%,
val_loss 2.317


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.41s/it]
Training, avg loss: 2.224:  27%|██▋       | 19/70 [24:13<1:04:57, 76.43s/it]

ACC@1: 49.37%,
ACC@5: 70.23%,
Macro-R: 33.72%,
Macro-P: 29.62%,
Macro-F1: 30.01%,
val_loss 2.286


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.43s/it]
Training, avg loss: 2.173:  29%|██▊       | 20/70 [25:29<1:03:43, 76.47s/it]

ACC@1: 50.08%,
ACC@5: 72.1%,
Macro-R: 33.84%,
Macro-P: 28.65%,
Macro-F1: 29.54%,
val_loss 2.222


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.84it/s]6.47s/it]
Training, avg loss: 2.117:  30%|███       | 21/70 [26:45<1:02:16, 76.25s/it]

ACC@1: 52.97%,
ACC@5: 74.27%,
Macro-R: 36.75%,
Macro-P: 32.39%,
Macro-F1: 32.97%,
val_loss 2.195


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]6.25s/it]
Training, avg loss: 2.109:  31%|███▏      | 22/70 [28:00<1:00:50, 76.06s/it]

ACC@1: 52.62%,
ACC@5: 74.26%,
Macro-R: 36.26%,
Macro-P: 31.98%,
Macro-F1: 32.51%,
val_loss 2.168


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.06s/it]
Training, avg loss: 2.042:  33%|███▎      | 23/70 [29:16<59:32, 76.00s/it]  

ACC@1: 52.01%,
ACC@5: 75.34%,
Macro-R: 36.91%,
Macro-P: 32.66%,
Macro-F1: 33.31%,
val_loss 2.111


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]00s/it]
Training, avg loss: 2.002:  34%|███▍      | 24/70 [30:32<58:15, 75.98s/it]

ACC@1: 53.64%,
ACC@5: 75.17%,
Macro-R: 38.12%,
Macro-P: 33.7%,
Macro-F1: 34.37%,
val_loss 2.077


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]98s/it]
Training, avg loss: 1.960:  36%|███▌      | 25/70 [31:48<56:56, 75.93s/it]

ACC@1: 52.37%,
ACC@5: 75.22%,
Macro-R: 37.37%,
Macro-P: 33.44%,
Macro-F1: 33.56%,
val_loss 2.065


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.84it/s]93s/it]
Training, avg loss: 1.921:  37%|███▋      | 26/70 [33:04<55:37, 75.84s/it]

ACC@1: 54.18%,
ACC@5: 76.37%,
Macro-R: 38.62%,
Macro-P: 34.34%,
Macro-F1: 34.95%,
val_loss 2.032


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]84s/it]
Training, avg loss: 1.882:  39%|███▊      | 27/70 [34:20<54:21, 75.84s/it]

ACC@1: 54.84%,
ACC@5: 77.94%,
Macro-R: 38.78%,
Macro-P: 35.13%,
Macro-F1: 35.58%,
val_loss 2.013


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]84s/it]
Training, avg loss: 1.856:  40%|████      | 28/70 [35:35<53:03, 75.79s/it]

ACC@1: 54.84%,
ACC@5: 78.77%,
Macro-R: 39.45%,
Macro-P: 34.35%,
Macro-F1: 35.17%,
val_loss 1.99


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]79s/it]
Training, avg loss: 1.826:  41%|████▏     | 29/70 [36:51<51:45, 75.74s/it]

ACC@1: 55.56%,
ACC@5: 79.07%,
Macro-R: 39.75%,
Macro-P: 36.5%,
Macro-F1: 36.73%,
val_loss 1.919


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]74s/it]
Training, avg loss: 1.805:  43%|████▎     | 30/70 [38:06<50:27, 75.68s/it]

ACC@1: 55.81%,
ACC@5: 79.02%,
Macro-R: 40.22%,
Macro-P: 36.05%,
Macro-F1: 36.64%,
val_loss 1.916


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]68s/it]
Training, avg loss: 1.777:  44%|████▍     | 31/70 [39:22<49:13, 75.74s/it]

ACC@1: 57.37%,
ACC@5: 78.83%,
Macro-R: 41.52%,
Macro-P: 38.6%,
Macro-F1: 38.69%,
val_loss 1.888


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.84it/s]74s/it]
Training, avg loss: 1.743:  46%|████▌     | 32/70 [40:38<47:57, 75.72s/it]

ACC@1: 56.71%,
ACC@5: 79.02%,
Macro-R: 41.09%,
Macro-P: 37.57%,
Macro-F1: 37.83%,
val_loss 1.946


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]72s/it]
Training, avg loss: 1.719:  47%|████▋     | 33/70 [41:54<46:41, 75.71s/it]

ACC@1: 56.76%,
ACC@5: 80.81%,
Macro-R: 39.84%,
Macro-P: 37.03%,
Macro-F1: 37.11%,
val_loss 1.872


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]71s/it]
Training, avg loss: 1.693:  49%|████▊     | 34/70 [43:09<45:26, 75.75s/it]

ACC@1: 57.01%,
ACC@5: 80.58%,
Macro-R: 40.2%,
Macro-P: 37.49%,
Macro-F1: 37.55%,
val_loss 1.821


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]75s/it]
Training, avg loss: 1.659:  50%|█████     | 35/70 [44:25<44:09, 75.70s/it]

ACC@1: 56.47%,
ACC@5: 80.39%,
Macro-R: 41.0%,
Macro-P: 37.17%,
Macro-F1: 37.81%,
val_loss 1.831


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]70s/it]
Training, avg loss: 1.646:  51%|█████▏    | 36/70 [45:41<42:54, 75.72s/it]

ACC@1: 57.43%,
ACC@5: 81.66%,
Macro-R: 41.78%,
Macro-P: 39.07%,
Macro-F1: 38.97%,
val_loss 1.823


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]72s/it]
Training, avg loss: 1.623:  53%|█████▎    | 37/70 [46:56<41:37, 75.68s/it]

ACC@1: 58.15%,
ACC@5: 81.54%,
Macro-R: 42.43%,
Macro-P: 39.25%,
Macro-F1: 39.45%,
val_loss 1.772


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]68s/it]
Training, avg loss: 1.605:  54%|█████▍    | 38/70 [48:12<40:22, 75.69s/it]

ACC@1: 56.52%,
ACC@5: 82.14%,
Macro-R: 40.62%,
Macro-P: 37.59%,
Macro-F1: 37.55%,
val_loss 1.765


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]69s/it]
Training, avg loss: 1.573:  56%|█████▌    | 39/70 [49:28<39:06, 75.71s/it]

ACC@1: 59.3%,
ACC@5: 81.72%,
Macro-R: 42.03%,
Macro-P: 39.45%,
Macro-F1: 39.37%,
val_loss 1.728


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]71s/it]
Training, avg loss: 1.551:  57%|█████▋    | 40/70 [50:44<37:52, 75.74s/it]

ACC@1: 59.29%,
ACC@5: 81.9%,
Macro-R: 42.84%,
Macro-P: 39.95%,
Macro-F1: 40.12%,
val_loss 1.729


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]74s/it]
Training, avg loss: 1.536:  59%|█████▊    | 41/70 [52:00<36:38, 75.80s/it]

ACC@1: 59.89%,
ACC@5: 81.6%,
Macro-R: 44.89%,
Macro-P: 40.53%,
Macro-F1: 41.15%,
val_loss 1.696


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]80s/it]
Training, avg loss: 1.512:  60%|██████    | 42/70 [53:15<35:20, 75.73s/it]

ACC@1: 60.19%,
ACC@5: 81.84%,
Macro-R: 42.57%,
Macro-P: 40.23%,
Macro-F1: 40.25%,
val_loss 1.728


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]73s/it]
Training, avg loss: 1.503:  61%|██████▏   | 43/70 [54:31<34:06, 75.79s/it]

ACC@1: 59.96%,
ACC@5: 82.62%,
Macro-R: 43.85%,
Macro-P: 41.45%,
Macro-F1: 41.19%,
val_loss 1.679


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]79s/it]
Training, avg loss: 1.478:  63%|██████▎   | 44/70 [55:47<32:49, 75.75s/it]

ACC@1: 59.89%,
ACC@5: 82.62%,
Macro-R: 44.21%,
Macro-P: 41.79%,
Macro-F1: 41.59%,
val_loss 1.675


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]75s/it]
Training, avg loss: 1.484:  64%|██████▍   | 45/70 [57:03<31:33, 75.75s/it]

ACC@1: 60.02%,
ACC@5: 82.57%,
Macro-R: 45.01%,
Macro-P: 42.66%,
Macro-F1: 42.56%,
val_loss 1.675


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]75s/it]
Training, avg loss: 1.443:  66%|██████▌   | 46/70 [58:18<30:18, 75.77s/it]

ACC@1: 60.32%,
ACC@5: 83.34%,
Macro-R: 44.26%,
Macro-P: 41.26%,
Macro-F1: 41.47%,
val_loss 1.63


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]77s/it]
Training, avg loss: 1.413:  67%|██████▋   | 47/70 [59:34<29:01, 75.72s/it]

ACC@1: 59.66%,
ACC@5: 83.53%,
Macro-R: 44.27%,
Macro-P: 41.58%,
Macro-F1: 41.59%,
val_loss 1.629


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]5.72s/it]
Training, avg loss: 1.399:  69%|██████▊   | 48/70 [1:00:50<27:47, 75.79s/it]

ACC@1: 60.37%,
ACC@5: 83.4%,
Macro-R: 45.48%,
Macro-P: 43.25%,
Macro-F1: 43.03%,
val_loss 1.647


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]5.79s/it]
Training, avg loss: 1.390:  70%|███████   | 49/70 [1:02:05<26:30, 75.71s/it]

ACC@1: 62.36%,
ACC@5: 84.55%,
Macro-R: 47.61%,
Macro-P: 44.69%,
Macro-F1: 44.83%,
val_loss 1.601


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]5.71s/it]
Training, avg loss: 1.370:  71%|███████▏  | 50/70 [1:03:21<25:14, 75.72s/it]

ACC@1: 60.25%,
ACC@5: 83.64%,
Macro-R: 44.39%,
Macro-P: 42.11%,
Macro-F1: 42.11%,
val_loss 1.603


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]5.72s/it]
Training, avg loss: 1.352:  73%|███████▎  | 51/70 [1:04:38<24:02, 75.92s/it]

ACC@1: 62.05%,
ACC@5: 85.2%,
Macro-R: 46.01%,
Macro-P: 43.05%,
Macro-F1: 43.29%,
val_loss 1.598


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]5.92s/it]
Training, avg loss: 1.341:  74%|███████▍  | 52/70 [1:05:54<22:50, 76.11s/it]

ACC@1: 62.36%,
ACC@5: 84.67%,
Macro-R: 46.13%,
Macro-P: 44.54%,
Macro-F1: 44.1%,
val_loss 1.615


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.11s/it]
Training, avg loss: 1.332:  76%|███████▌  | 53/70 [1:07:11<21:37, 76.31s/it]

ACC@1: 62.78%,
ACC@5: 83.83%,
Macro-R: 47.45%,
Macro-P: 45.41%,
Macro-F1: 45.2%,
val_loss 1.588


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.31s/it]
Training, avg loss: 1.303:  77%|███████▋  | 54/70 [1:08:28<20:22, 76.40s/it]

ACC@1: 60.97%,
ACC@5: 84.13%,
Macro-R: 45.42%,
Macro-P: 42.76%,
Macro-F1: 42.74%,
val_loss 1.617


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.40s/it]
Training, avg loss: 1.305:  79%|███████▊  | 55/70 [1:09:44<19:07, 76.51s/it]

ACC@1: 62.41%,
ACC@5: 85.39%,
Macro-R: 45.7%,
Macro-P: 43.5%,
Macro-F1: 43.51%,
val_loss 1.566


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]6.51s/it]
Training, avg loss: 1.336:  80%|████████  | 56/70 [1:11:01<17:51, 76.56s/it]

ACC@1: 62.3%,
ACC@5: 84.67%,
Macro-R: 46.88%,
Macro-P: 44.89%,
Macro-F1: 44.67%,
val_loss 1.557


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.56s/it]
Training, avg loss: 1.282:  81%|████████▏ | 57/70 [1:12:18<16:36, 76.67s/it]

ACC@1: 62.78%,
ACC@5: 84.79%,
Macro-R: 46.27%,
Macro-P: 43.73%,
Macro-F1: 43.74%,
val_loss 1.542


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]6.67s/it]
Training, avg loss: 1.263:  83%|████████▎ | 58/70 [1:13:35<15:21, 76.80s/it]

ACC@1: 63.63%,
ACC@5: 85.33%,
Macro-R: 48.15%,
Macro-P: 45.47%,
Macro-F1: 45.63%,
val_loss 1.509


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.80s/it]
Training, avg loss: 1.238:  84%|████████▍ | 59/70 [1:14:52<14:03, 76.72s/it]

ACC@1: 63.56%,
ACC@5: 85.69%,
Macro-R: 47.92%,
Macro-P: 45.93%,
Macro-F1: 45.66%,
val_loss 1.488


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]6.72s/it]
Training, avg loss: 1.226:  86%|████████▌ | 60/70 [1:16:08<12:46, 76.66s/it]

ACC@1: 61.09%,
ACC@5: 84.48%,
Macro-R: 46.23%,
Macro-P: 43.11%,
Macro-F1: 43.37%,
val_loss 1.538


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.80it/s]6.66s/it]
Training, avg loss: 1.163:  87%|████████▋ | 61/70 [1:17:25<11:29, 76.65s/it]

ACC@1: 64.7%,
ACC@5: 86.4%,
Macro-R: 49.27%,
Macro-P: 47.11%,
Macro-F1: 47.01%,
val_loss 1.465


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.65s/it]
Training, avg loss: 1.153:  89%|████████▊ | 62/70 [1:18:41<10:12, 76.59s/it]

ACC@1: 65.01%,
ACC@5: 86.59%,
Macro-R: 49.54%,
Macro-P: 47.52%,
Macro-F1: 47.37%,
val_loss 1.461


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.59s/it]
Training, avg loss: 1.150:  90%|█████████ | 63/70 [1:19:58<08:56, 76.67s/it]

ACC@1: 65.18%,
ACC@5: 86.59%,
Macro-R: 49.01%,
Macro-P: 47.01%,
Macro-F1: 46.95%,
val_loss 1.46


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.83it/s]6.67s/it]
Training, avg loss: 1.149:  91%|█████████▏| 64/70 [1:21:14<07:39, 76.51s/it]

ACC@1: 64.7%,
ACC@5: 86.47%,
Macro-R: 48.8%,
Macro-P: 46.68%,
Macro-F1: 46.55%,
val_loss 1.461


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.51s/it]
Training, avg loss: 1.148:  93%|█████████▎| 65/70 [1:22:30<06:21, 76.40s/it]

ACC@1: 64.16%,
ACC@5: 86.53%,
Macro-R: 48.41%,
Macro-P: 46.66%,
Macro-F1: 46.38%,
val_loss 1.463


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.40s/it]
Training, avg loss: 1.146:  94%|█████████▍| 66/70 [1:23:47<05:05, 76.39s/it]

ACC@1: 64.53%,
ACC@5: 86.53%,
Macro-R: 48.72%,
Macro-P: 46.59%,
Macro-F1: 46.58%,
val_loss 1.46


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.82it/s]6.39s/it]
Training, avg loss: 1.143:  96%|█████████▌| 67/70 [1:25:03<03:48, 76.33s/it]

ACC@1: 64.76%,
ACC@5: 86.84%,
Macro-R: 49.24%,
Macro-P: 46.46%,
Macro-F1: 46.67%,
val_loss 1.456


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.33s/it]
Training, avg loss: 1.143:  97%|█████████▋| 68/70 [1:26:19<02:32, 76.36s/it]

ACC@1: 65.12%,
ACC@5: 86.77%,
Macro-R: 48.95%,
Macro-P: 47.0%,
Macro-F1: 46.86%,
val_loss 1.456


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.36s/it]
Training, avg loss: 1.140:  99%|█████████▊| 69/70 [1:27:36<01:16, 76.37s/it]

ACC@1: 64.47%,
ACC@5: 86.96%,
Macro-R: 49.26%,
Macro-P: 47.07%,
Macro-F1: 47.01%,
val_loss 1.456


Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]6.37s/it]
Training, avg loss: 1.139: 100%|██████████| 70/70 [1:28:52<00:00, 76.18s/it]

ACC@1: 65.43%,
ACC@5: 86.77%,
Macro-R: 50.0%,
Macro-P: 47.48%,
Macro-F1: 47.56%,
val_loss 1.455





In [8]:
metrics, _ = test_user_model(model=trajfm, dataloader=test_dataloader, device = device)
for key, value in metrics.items():
    print(f"{key}: {round(value * 100, 2)}%,")

Testing/Validating: 100%|██████████| 52/52 [00:28<00:00,  1.81it/s]

ACC@1: 65.26%,
ACC@5: 85.82%,
Macro-R: 50.67%,
Macro-P: 48.78%,
Macro-F1: 48.33%,





In [9]:
df = pd.DataFrame([{
    "Model": f"{SAVE_NAME}",
    **{key: round(value * 100, 2) for key, value in metrics.items()}
}])

csv_path = "logs/test.csv"
if os.path.exists(csv_path):
    df.to_csv(csv_path, mode='a', header=False, index=False)
else:
    df.to_csv(csv_path, index=False)