In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import utils
from models.trajfm import TrajFM
from data import TrajFMDataset, fetch_task_padder, X_COL, Y_COL, coord_transform_GPS_UTM
import warnings
from pipeline import train_user_model, test_user_model
import torch.multiprocessing as mp
warnings.filterwarnings('ignore')

In [2]:
# --- Output locations (each overridable through an environment variable) ---
SETTINGS_CACHE_DIR = os.environ.get('SETTINGS_CACHE_DIR', os.path.join('settings', 'cache'))
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', 'saved_model')
LOG_SAVE_DIR = os.environ.get('LOG_SAVE_DIR', 'logs')
PRED_SAVE_DIR = os.environ.get('PRED_SAVE_DIR', 'predictions')

# --- Reproducibility: seed NumPy and PyTorch (CPU + all CUDA devices) ---
SEED = 42
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)
# Deterministic cuDNN kernels, no auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- Process environment ---
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'PYDEVD_DISABLE_FILE_VALIDATION': '1',
})

# Only choose a multiprocessing start method if none has been fixed yet,
# so re-running this cell in a live kernel does not raise.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Scaled-dot-product attention backends: keep only the plain math kernel.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# Per-run key from utils.get_datetime_key(); used below to name the cached
# copy of this run's settings (presumably timestamp-based - confirm in utils).
datetime_key = utils.get_datetime_key()

# The settings file holds a list of configurations; this notebook uses the first.
with open(os.path.join('settings', 'local_test.json'), 'r') as settings_file:
    setting = json.load(settings_file)[0]

# Snapshot the chosen configuration next to the run key.
utils.create_if_noexists(SETTINGS_CACHE_DIR)
with open(os.path.join(SETTINGS_CACHE_DIR, f'{datetime_key}.json'), 'w') as cache_file:
    json.dump(setting, cache_file)

print("device:", device)

device: cuda


In [3]:
SAVE_NAME = setting["save_name"]

# --- Training trajectories ---
train_traj_path = setting['dataset']['train_traj_df']
train_traj_df = pd.read_hdf(train_traj_path, key='trips')
print("dataset:", train_traj_path)

# Dataset statistics, reused by the summary cell and the model/padder config.
user_count = len(train_traj_df['user_id'].unique())
traj_count = len(train_traj_df['traj_id'].unique())
traj_len = len(train_traj_df['seq_i'])      # number of rows (trajectory points)
tao = train_traj_df['delta_t'].mean()       # mean of the `delta_t` column
# The padder needs the number of distinct users.
setting["finetune"]["padder"]["params"]["num_users"] = user_count

# Divisor applied to centred UTM coordinates (see the POI normalisation below).
scale = 4000

# Resolve the UTM region. Previously it was only assigned for Geolife paths,
# so any other dataset failed later with a NameError. Geolife keeps its
# built-in value; other datasets may supply `dataset.UTM_region` in the
# settings file, otherwise we fail early with an explicit message.
if "geolife" in train_traj_path:
    UTM_region = 50  # presumably the UTM zone covering the Geolife data - confirm
else:
    UTM_region = setting['dataset'].get('UTM_region')
    if UTM_region is None:
        raise ValueError(
            f"No UTM region known for dataset '{train_traj_path}'. "
            "Add 'UTM_region' to the 'dataset' section of the settings file."
        )
dataset = TrajFMDataset(traj_df=train_traj_df, UTM_region=UTM_region, scale = scale)

# --- Points of interest: pre-computed embeddings and coordinates ---
poi_df = pd.read_hdf(setting['dataset']['poi_df'], key='pois')
poi_embed = torch.from_numpy(np.load(setting['dataset']['poi_embed'])).float().to(device)

# Project POI coordinates to UTM, centre them on the dataset's spatial middle
# and divide by `scale`, using the same region/scale passed to the dataset.
poi_coors = poi_df[[X_COL, Y_COL]].to_numpy()
poi_coors = (coord_transform_GPS_UTM(poi_coors, UTM_region) - dataset.spatial_middle_coord) / scale
poi_coors = torch.tensor(poi_coors).float().to(device)

# Build the learnable model.
trajfm = TrajFM(poi_embed=poi_embed, 
                poi_coors=poi_coors, 
                UTM_region=UTM_region,
                spatial_middle_coord = dataset.spatial_middle_coord, 
                scale = scale, 
                **setting['trajfm'],
                user = user_count).to(device)

dataset: ./dataset/geolife_U56_TrajAll_L1000.h5


In [4]:
# Preprocessing steps recorded alongside the run (descriptive labels only;
# nothing in this notebook applies them).
filtering_steps = [
    "25th to 75th quartile based on traj_len",
    "traj_len > 30 points",
    "delta_t > 1800s",
    "traj/user > 35 traj",
    "resampled traj to 1000 points max",
    "user_number and seq_i recalculated",
    "stratified",
]

# Human-readable dataset summary; also passed to the training routine.
data_summary = {
    "users": user_count,
    "total_traj": traj_count,
    "total_points": traj_len,
    # `tao` is the mean of `delta_t`; dividing by 3600 presumably converts seconds to hours.
    "avg_traj_len": f"{round(tao/3600, 2)} hours",
    "Data Filtering": filtering_steps,
}

for field, entry in data_summary.items():
    print(field, ":", entry)

users : 56
total_traj : 7602
total_points : 5450141
avg_traj_len : 1.62 hours
Data Filtering : ['25th to 75th quartile based on traj_len', 'traj_len > 30 points', 'delta_t > 1800s', 'traj/user > 35 traj', 'resampled traj to 1000 points max', 'user_number and seq_i recalculated', 'stratified']


In [5]:
# Two-stage split with a fixed seed: hold out 40% of the data first, then cut
# that hold-out in half to obtain the validation and test sets.
train_dataset, val_test_dataset = utils.stratify_dataset(
    dataset=dataset,
    test_size=0.4,
    random_seed=SEED,
)
val_dataset, test_dataset = utils.stratify_dataset(
    dataset=val_test_dataset,
    test_size=0.5,
    random_seed=SEED,
)

In [6]:
# Optional sanity check (disabled): show the sizes of the train/val/test splits.
# print(len(train_dataset), len(val_dataset), len(test_dataset))

In [7]:
# The padder's name doubles as the downstream-task label used in file names.
padder_cfg = setting['finetune']['padder']
downstreamtask = padder_cfg['name']
padder = fetch_task_padder(padder_name=padder_cfg['name'], padder_params=padder_cfg['params'])

# All three splits share the same collate function and loader options.
loader_kwargs = setting['finetune']['dataloader']
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, collate_fn=padder, **loader_kwargs)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [8]:
# Single source of truth for the checkpoint location. The existence check
# used to look in a hard-coded "saved_model/" directory while loading and
# saving used MODEL_CACHE_DIR, so the two disagreed whenever MODEL_CACHE_DIR
# was overridden.
file_path = os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}')

# Warm-start from an earlier checkpoint of the same run/task if one exists.
if os.path.exists(file_path):
    print(f"Loading model {file_path}")
    trajfm.load_state_dict(torch.load(file_path, map_location=device))
else:
    print("Model not found, starting new")

# Fine-tune; returns the per-epoch log and the state dict to persist.
train_log, saved_model_state_dict = train_user_model(model=trajfm, 
                                                    train_dataloader=train_dataloader, 
                                                    val_dataloader=val_dataloader,
                                                    device = device, 
                                                    **setting['finetune']['config'],
                                                    data_summary = data_summary)

if setting['finetune'].get('save', False):
    # save model
    utils.create_if_noexists(MODEL_CACHE_DIR)
    torch.save(saved_model_state_dict, file_path)
    
    # save log: append to the run's CSV, writing the header only on first creation
    log_dir = os.path.join(LOG_SAVE_DIR, SAVE_NAME)
    utils.create_if_noexists(log_dir)
    log_path = os.path.join(log_dir, f'{SAVE_NAME}_{downstreamtask}.csv')
    file_exists = os.path.exists(log_path)
    train_log.to_csv(log_path, mode='a', header=not file_exists, index=False)

Model not found, starting new


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meuj01[0m ([33mSP_001[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The run id is U56_TrajAll_L1000_v1.2s


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]
Training, avg loss: 3.609:   2%|▏         | 1/60 [01:19<1:18:35, 79.93s/it]

ACC@1: 18.95%,
ACC@5: 40.95%,
Macro-R: 7.09%,
Macro-P: 1.94%,
Macro-F1: 2.91%,
val_loss 3.46


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].93s/it]
Training, avg loss: 3.395:   3%|▎         | 2/60 [02:38<1:16:31, 79.16s/it]

ACC@1: 19.21%,
ACC@5: 40.43%,
Macro-R: 6.66%,
Macro-P: 1.84%,
Macro-F1: 2.76%,
val_loss 3.379


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].16s/it]
Training, avg loss: 3.290:   5%|▌         | 3/60 [03:57<1:14:58, 78.93s/it]

ACC@1: 17.77%,
ACC@5: 41.54%,
Macro-R: 5.8%,
Macro-P: 1.62%,
Macro-F1: 2.35%,
val_loss 3.383


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].93s/it]
Training, avg loss: 3.226:   7%|▋         | 4/60 [05:15<1:13:28, 78.72s/it]

ACC@1: 19.27%,
ACC@5: 41.54%,
Macro-R: 7.52%,
Macro-P: 3.11%,
Macro-F1: 3.78%,
val_loss 3.291


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.75it/s].72s/it]
Training, avg loss: 3.158:   8%|▊         | 5/60 [06:34<1:12:04, 78.62s/it]

ACC@1: 21.94%,
ACC@5: 43.23%,
Macro-R: 8.89%,
Macro-P: 3.6%,
Macro-F1: 4.51%,
val_loss 3.188


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.75it/s].62s/it]
Training, avg loss: 3.104:  10%|█         | 6/60 [07:51<1:10:30, 78.34s/it]

ACC@1: 20.96%,
ACC@5: 43.75%,
Macro-R: 8.95%,
Macro-P: 4.42%,
Macro-F1: 5.21%,
val_loss 3.183


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s].34s/it]
Training, avg loss: 3.070:  12%|█▏        | 7/60 [09:11<1:09:29, 78.67s/it]

ACC@1: 22.79%,
ACC@5: 46.74%,
Macro-R: 10.2%,
Macro-P: 5.25%,
Macro-F1: 6.03%,
val_loss 3.092


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s].67s/it]
Training, avg loss: 3.026:  13%|█▎        | 8/60 [10:30<1:08:24, 78.93s/it]

ACC@1: 22.79%,
ACC@5: 48.7%,
Macro-R: 9.39%,
Macro-P: 4.39%,
Macro-F1: 5.39%,
val_loss 3.077


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s].93s/it]
Training, avg loss: 2.971:  15%|█▌        | 9/60 [11:50<1:07:20, 79.23s/it]

ACC@1: 21.48%,
ACC@5: 48.37%,
Macro-R: 9.13%,
Macro-P: 5.15%,
Macro-F1: 5.75%,
val_loss 3.074


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s].23s/it]
Training, avg loss: 2.933:  17%|█▋        | 10/60 [13:10<1:06:09, 79.39s/it]

ACC@1: 19.73%,
ACC@5: 47.66%,
Macro-R: 7.61%,
Macro-P: 4.32%,
Macro-F1: 4.89%,
val_loss 3.088


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.39s/it]
Training, avg loss: 2.861:  18%|█▊        | 11/60 [14:30<1:04:55, 79.50s/it]

ACC@1: 24.28%,
ACC@5: 49.87%,
Macro-R: 11.67%,
Macro-P: 7.47%,
Macro-F1: 8.26%,
val_loss 2.988


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.50s/it]
Training, avg loss: 2.775:  20%|██        | 12/60 [15:49<1:03:41, 79.62s/it]

ACC@1: 25.2%,
ACC@5: 49.93%,
Macro-R: 12.74%,
Macro-P: 9.5%,
Macro-F1: 9.36%,
val_loss 2.934


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.62s/it]
Training, avg loss: 2.669:  22%|██▏       | 13/60 [17:09<1:02:25, 79.68s/it]

ACC@1: 31.05%,
ACC@5: 56.84%,
Macro-R: 17.19%,
Macro-P: 12.28%,
Macro-F1: 13.01%,
val_loss 2.698


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.68s/it]
Training, avg loss: 2.569:  23%|██▎       | 14/60 [18:29<1:01:07, 79.74s/it]

ACC@1: 31.58%,
ACC@5: 57.03%,
Macro-R: 16.63%,
Macro-P: 13.32%,
Macro-F1: 13.64%,
val_loss 2.674


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.74s/it]
Training, avg loss: 2.465:  25%|██▌       | 15/60 [19:49<59:49, 79.77s/it]  

ACC@1: 34.7%,
ACC@5: 61.2%,
Macro-R: 19.78%,
Macro-P: 16.01%,
Macro-F1: 16.22%,
val_loss 2.553


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]77s/it]
Training, avg loss: 2.390:  27%|██▋       | 16/60 [21:09<58:35, 79.89s/it]

ACC@1: 36.39%,
ACC@5: 65.17%,
Macro-R: 21.32%,
Macro-P: 17.32%,
Macro-F1: 17.77%,
val_loss 2.411


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]89s/it]
Training, avg loss: 2.298:  28%|██▊       | 17/60 [22:29<57:16, 79.92s/it]

ACC@1: 35.35%,
ACC@5: 62.11%,
Macro-R: 20.34%,
Macro-P: 16.48%,
Macro-F1: 16.74%,
val_loss 2.468


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]92s/it]
Training, avg loss: 2.226:  30%|███       | 18/60 [23:49<55:57, 79.94s/it]

ACC@1: 38.02%,
ACC@5: 66.15%,
Macro-R: 22.65%,
Macro-P: 18.41%,
Macro-F1: 18.71%,
val_loss 2.364


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]94s/it]
Training, avg loss: 2.185:  32%|███▏      | 19/60 [25:09<54:36, 79.92s/it]

ACC@1: 39.19%,
ACC@5: 68.49%,
Macro-R: 22.91%,
Macro-P: 19.28%,
Macro-F1: 19.5%,
val_loss 2.279


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]92s/it]
Training, avg loss: 2.151:  33%|███▎      | 20/60 [26:29<53:15, 79.88s/it]

ACC@1: 39.78%,
ACC@5: 67.71%,
Macro-R: 23.62%,
Macro-P: 19.36%,
Macro-F1: 19.92%,
val_loss 2.264


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]88s/it]
Training, avg loss: 2.065:  35%|███▌      | 21/60 [27:49<51:57, 79.92s/it]

ACC@1: 40.3%,
ACC@5: 69.27%,
Macro-R: 24.53%,
Macro-P: 20.43%,
Macro-F1: 20.8%,
val_loss 2.25


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]92s/it]
Training, avg loss: 2.024:  37%|███▋      | 22/60 [29:09<50:36, 79.91s/it]

ACC@1: 41.21%,
ACC@5: 67.45%,
Macro-R: 24.47%,
Macro-P: 21.51%,
Macro-F1: 21.57%,
val_loss 2.269


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]91s/it]
Training, avg loss: 1.988:  38%|███▊      | 23/60 [30:29<49:17, 79.93s/it]

ACC@1: 39.97%,
ACC@5: 71.03%,
Macro-R: 23.33%,
Macro-P: 20.21%,
Macro-F1: 20.39%,
val_loss 2.184


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]93s/it]
Training, avg loss: 1.945:  40%|████      | 24/60 [31:49<47:57, 79.93s/it]

ACC@1: 40.56%,
ACC@5: 70.05%,
Macro-R: 25.43%,
Macro-P: 21.81%,
Macro-F1: 21.77%,
val_loss 2.213


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]93s/it]
Training, avg loss: 1.919:  42%|████▏     | 25/60 [33:09<46:37, 79.94s/it]

ACC@1: 42.19%,
ACC@5: 71.55%,
Macro-R: 26.43%,
Macro-P: 22.69%,
Macro-F1: 23.03%,
val_loss 2.129


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]94s/it]
Training, avg loss: 1.902:  43%|████▎     | 26/60 [34:29<45:18, 79.95s/it]

ACC@1: 44.99%,
ACC@5: 73.31%,
Macro-R: 27.32%,
Macro-P: 23.75%,
Macro-F1: 24.13%,
val_loss 2.088


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]95s/it]
Training, avg loss: 1.854:  45%|████▌     | 27/60 [35:49<43:58, 79.95s/it]

ACC@1: 41.99%,
ACC@5: 71.35%,
Macro-R: 26.18%,
Macro-P: 23.56%,
Macro-F1: 23.42%,
val_loss 2.197


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]95s/it]
Training, avg loss: 1.811:  47%|████▋     | 28/60 [37:08<42:35, 79.86s/it]

ACC@1: 42.97%,
ACC@5: 72.14%,
Macro-R: 27.06%,
Macro-P: 24.53%,
Macro-F1: 24.23%,
val_loss 2.103


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]86s/it]
Training, avg loss: 1.759:  48%|████▊     | 29/60 [38:28<41:14, 79.83s/it]

ACC@1: 42.06%,
ACC@5: 71.35%,
Macro-R: 24.92%,
Macro-P: 20.84%,
Macro-F1: 21.46%,
val_loss 2.178


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]83s/it]
Training, avg loss: 1.757:  50%|█████     | 30/60 [39:48<39:53, 79.77s/it]

ACC@1: 47.07%,
ACC@5: 75.13%,
Macro-R: 30.44%,
Macro-P: 27.14%,
Macro-F1: 27.21%,
val_loss 1.973


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]77s/it]
Training, avg loss: 1.704:  52%|█████▏    | 31/60 [41:07<38:32, 79.73s/it]

ACC@1: 46.29%,
ACC@5: 75.65%,
Macro-R: 29.35%,
Macro-P: 26.25%,
Macro-F1: 26.25%,
val_loss 2.005


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]73s/it]
Training, avg loss: 1.673:  53%|█████▎    | 32/60 [42:27<37:12, 79.74s/it]

ACC@1: 44.79%,
ACC@5: 74.28%,
Macro-R: 27.43%,
Macro-P: 25.7%,
Macro-F1: 25.14%,
val_loss 2.036


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]74s/it]
Training, avg loss: 1.637:  55%|█████▌    | 33/60 [43:47<35:52, 79.73s/it]

ACC@1: 44.66%,
ACC@5: 74.15%,
Macro-R: 27.57%,
Macro-P: 25.34%,
Macro-F1: 25.14%,
val_loss 2.091


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]73s/it]
Training, avg loss: 1.607:  57%|█████▋    | 34/60 [45:06<34:32, 79.72s/it]

ACC@1: 45.57%,
ACC@5: 76.17%,
Macro-R: 28.98%,
Macro-P: 26.31%,
Macro-F1: 26.2%,
val_loss 2.021


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]72s/it]
Training, avg loss: 1.574:  58%|█████▊    | 35/60 [46:26<33:13, 79.74s/it]

ACC@1: 46.68%,
ACC@5: 75.78%,
Macro-R: 29.76%,
Macro-P: 28.38%,
Macro-F1: 27.66%,
val_loss 1.984


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]74s/it]
Training, avg loss: 1.550:  60%|██████    | 36/60 [47:46<31:52, 79.68s/it]

ACC@1: 48.7%,
ACC@5: 76.5%,
Macro-R: 31.76%,
Macro-P: 29.48%,
Macro-F1: 29.3%,
val_loss 1.939


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]68s/it]
Training, avg loss: 1.518:  62%|██████▏   | 37/60 [49:05<30:31, 79.63s/it]

ACC@1: 47.79%,
ACC@5: 78.45%,
Macro-R: 30.99%,
Macro-P: 28.38%,
Macro-F1: 28.07%,
val_loss 1.922


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]63s/it]
Training, avg loss: 1.472:  63%|██████▎   | 38/60 [50:25<29:13, 79.69s/it]

ACC@1: 48.63%,
ACC@5: 77.54%,
Macro-R: 31.62%,
Macro-P: 30.11%,
Macro-F1: 29.39%,
val_loss 1.921


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]69s/it]
Training, avg loss: 1.454:  65%|██████▌   | 39/60 [51:45<27:55, 79.77s/it]

ACC@1: 43.75%,
ACC@5: 74.54%,
Macro-R: 25.05%,
Macro-P: 23.17%,
Macro-F1: 22.86%,
val_loss 2.149


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]77s/it]
Training, avg loss: 1.430:  67%|██████▋   | 40/60 [53:05<26:36, 79.81s/it]

ACC@1: 50.13%,
ACC@5: 77.73%,
Macro-R: 32.17%,
Macro-P: 29.62%,
Macro-F1: 29.63%,
val_loss 1.874


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]81s/it]
Training, avg loss: 1.365:  68%|██████▊   | 41/60 [54:25<25:17, 79.86s/it]

ACC@1: 46.29%,
ACC@5: 77.08%,
Macro-R: 28.46%,
Macro-P: 26.42%,
Macro-F1: 25.99%,
val_loss 2.03


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]86s/it]
Training, avg loss: 1.361:  70%|███████   | 42/60 [55:45<23:58, 79.90s/it]

ACC@1: 51.04%,
ACC@5: 79.69%,
Macro-R: 32.93%,
Macro-P: 32.16%,
Macro-F1: 31.26%,
val_loss 1.864


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]90s/it]
Training, avg loss: 1.324:  72%|███████▏  | 43/60 [57:05<22:37, 79.86s/it]

ACC@1: 52.73%,
ACC@5: 79.49%,
Macro-R: 36.61%,
Macro-P: 34.05%,
Macro-F1: 33.88%,
val_loss 1.822


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]86s/it]
Training, avg loss: 1.299:  73%|███████▎  | 44/60 [58:25<21:18, 79.93s/it]

ACC@1: 52.02%,
ACC@5: 79.82%,
Macro-R: 34.74%,
Macro-P: 32.8%,
Macro-F1: 32.34%,
val_loss 1.819


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]93s/it]
Training, avg loss: 1.258:  75%|███████▌  | 45/60 [59:45<19:58, 79.90s/it]

ACC@1: 55.08%,
ACC@5: 79.62%,
Macro-R: 36.94%,
Macro-P: 35.47%,
Macro-F1: 34.72%,
val_loss 1.786


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.90s/it]
Training, avg loss: 1.235:  77%|███████▋  | 46/60 [1:01:04<18:38, 79.88s/it]

ACC@1: 51.63%,
ACC@5: 79.49%,
Macro-R: 33.82%,
Macro-P: 32.48%,
Macro-F1: 31.72%,
val_loss 1.823


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.88s/it]
Training, avg loss: 1.213:  78%|███████▊  | 47/60 [1:02:24<17:18, 79.89s/it]

ACC@1: 53.65%,
ACC@5: 80.66%,
Macro-R: 35.67%,
Macro-P: 35.02%,
Macro-F1: 33.96%,
val_loss 1.82


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]9.89s/it]
Training, avg loss: 1.174:  80%|████████  | 48/60 [1:03:44<15:58, 79.90s/it]

ACC@1: 54.04%,
ACC@5: 81.25%,
Macro-R: 37.58%,
Macro-P: 35.43%,
Macro-F1: 35.11%,
val_loss 1.79


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.90s/it]
Training, avg loss: 1.157:  82%|████████▏ | 49/60 [1:05:04<14:39, 79.92s/it]

ACC@1: 53.32%,
ACC@5: 80.79%,
Macro-R: 36.04%,
Macro-P: 33.82%,
Macro-F1: 33.55%,
val_loss 1.856


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.92s/it]
Training, avg loss: 1.099:  83%|████████▎ | 50/60 [1:06:24<13:18, 79.89s/it]

ACC@1: 55.86%,
ACC@5: 80.73%,
Macro-R: 37.74%,
Macro-P: 36.5%,
Macro-F1: 35.69%,
val_loss 1.754


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.89s/it]
Training, avg loss: 0.986:  85%|████████▌ | 51/60 [1:07:44<11:59, 79.94s/it]

ACC@1: 56.77%,
ACC@5: 83.01%,
Macro-R: 38.12%,
Macro-P: 36.91%,
Macro-F1: 36.25%,
val_loss 1.688


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.68it/s]9.94s/it]
Training, avg loss: 0.957:  87%|████████▋ | 52/60 [1:09:04<10:39, 79.96s/it]

ACC@1: 56.77%,
ACC@5: 83.14%,
Macro-R: 39.0%,
Macro-P: 38.54%,
Macro-F1: 37.43%,
val_loss 1.698


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.96s/it]
Training, avg loss: 0.951:  88%|████████▊ | 53/60 [1:10:24<09:19, 79.97s/it]

ACC@1: 58.07%,
ACC@5: 82.03%,
Macro-R: 41.26%,
Macro-P: 40.16%,
Macro-F1: 39.27%,
val_loss 1.705


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.97s/it]
Training, avg loss: 0.918:  90%|█████████ | 54/60 [1:11:44<07:59, 79.95s/it]

ACC@1: 58.92%,
ACC@5: 83.79%,
Macro-R: 40.99%,
Macro-P: 40.09%,
Macro-F1: 39.28%,
val_loss 1.658


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]9.95s/it]
Training, avg loss: 0.907:  92%|█████████▏| 55/60 [1:13:04<06:39, 79.94s/it]

ACC@1: 58.59%,
ACC@5: 83.4%,
Macro-R: 40.86%,
Macro-P: 40.45%,
Macro-F1: 39.39%,
val_loss 1.685


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.94s/it]
Training, avg loss: 0.885:  93%|█████████▎| 56/60 [1:14:24<05:19, 79.88s/it]

ACC@1: 57.81%,
ACC@5: 83.66%,
Macro-R: 40.66%,
Macro-P: 39.7%,
Macro-F1: 38.8%,
val_loss 1.697


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.88s/it]
Training, avg loss: 0.876:  95%|█████████▌| 57/60 [1:15:43<03:59, 79.82s/it]

ACC@1: 58.66%,
ACC@5: 83.53%,
Macro-R: 41.53%,
Macro-P: 40.46%,
Macro-F1: 39.76%,
val_loss 1.692


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.82s/it]
Training, avg loss: 0.853:  97%|█████████▋| 58/60 [1:17:03<02:39, 79.80s/it]

ACC@1: 58.66%,
ACC@5: 83.79%,
Macro-R: 41.25%,
Macro-P: 40.16%,
Macro-F1: 39.61%,
val_loss 1.69


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.69it/s]9.80s/it]
Training, avg loss: 0.841:  98%|█████████▊| 59/60 [1:18:23<01:19, 79.76s/it]

ACC@1: 58.4%,
ACC@5: 83.66%,
Macro-R: 41.67%,
Macro-P: 39.91%,
Macro-F1: 39.52%,
val_loss 1.67


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]9.76s/it]
Training, avg loss: 0.823: 100%|██████████| 60/60 [1:19:42<00:00, 79.71s/it]

ACC@1: 58.14%,
ACC@5: 82.68%,
Macro-R: 42.26%,
Macro-P: 40.44%,
Macro-F1: 39.95%,
val_loss 1.707





In [9]:
# Evaluate on the held-out test split; only the metrics dict is used here.
# NOTE(review): this scores `trajfm` as training left it. Whether those are the
# weights in `saved_model_state_dict` depends on train_user_model - confirm.
metrics, _ = test_user_model(
    model=trajfm,
    dataloader=test_dataloader,
    device=device,
)

# Metrics are fractions; report them as percentages with two decimals.
for metric_name, score in metrics.items():
    print(f"{metric_name}: {round(score * 100, 2)}%,")

Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]

ACC@1: 58.05%,
ACC@5: 82.58%,
Macro-R: 42.14%,
Macro-P: 40.75%,
Macro-F1: 40.14%,





In [10]:
# One-row table: the run name plus every test metric as a percentage.
df = pd.DataFrame([{
    "Model": f"{SAVE_NAME}",
    **{key: round(value * 100, 2) for key, value in metrics.items()}
}])

# Append to the shared results file under LOG_SAVE_DIR (default "logs").
# The directory is created first: it may not exist yet when the training
# cell ran with saving disabled, and to_csv does not create parent directories.
utils.create_if_noexists(LOG_SAVE_DIR)
csv_path = os.path.join(LOG_SAVE_DIR, "test.csv")
# Write the header only when the file is being created.
df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)