In [1]:
import json
import os
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

import utils
from models.trajfm import TrajFM
from data import TrajFMDataset, fetch_task_padder, X_COL, Y_COL, coord_transform_GPS_UTM
from pipeline import train_user_model, test_user_model
warnings.filterwarnings('ignore')

In [2]:
# Output locations; each can be overridden through an environment variable of the same name.
SETTINGS_CACHE_DIR = os.environ.get('SETTINGS_CACHE_DIR', os.path.join('settings', 'cache'))
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', 'saved_model')
LOG_SAVE_DIR = os.environ.get('LOG_SAVE_DIR', 'logs')
PRED_SAVE_DIR = os.environ.get('PRED_SAVE_DIR', 'predictions')
# Experiment settings file; the first entry of its top-level list is used.
SETTINGS_FILE = os.environ.get('SETTINGS_FILE', os.path.join('settings', 'local_test.json'))

# GPU selection only takes effect before CUDA is initialised, so do it before any torch.cuda call.
# setdefault keeps a value exported by the user (consistent with the overridable dirs above).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'

# Seed every RNG source that may be used downstream, and force deterministic cuDNN kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set_start_method raises if a start method was already chosen, so only set it once per kernel.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Restrict scaled-dot-product attention to the plain "math" backend (flash and memory-efficient
# kernels disabled). NOTE(review): presumably for determinism/compatibility — confirm.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# Run key (presumably derived from the current date/time); used below to name the settings snapshot.
datetime_key = utils.get_datetime_key()

with open(SETTINGS_FILE, 'r') as fp:
    setting = json.load(fp)[0]

# Snapshot the settings used for this run so results can be traced back to their configuration.
utils.create_if_noexists(SETTINGS_CACHE_DIR)
with open(os.path.join(SETTINGS_CACHE_DIR, f'{datetime_key}.json'), 'w') as fp:
    json.dump(setting, fp)

print("device:", device)

device: cuda


In [3]:
SAVE_NAME = setting["save_name"]

train_traj_path = setting['dataset']['train_traj_df']
train_traj_df = pd.read_hdf(train_traj_path, key='trips')
print("dataset:", train_traj_path)

# Dataset statistics, reused by the data-summary cell.
user_count = train_traj_df['user_id'].nunique()
traj_count = train_traj_df['traj_id'].nunique()
# Total number of rows (points) across all trajectories — not a per-trajectory length.
traj_len = len(train_traj_df['seq_i'])
# Mean of `delta_t` over all rows. NOTE(review): if `delta_t` is a per-trajectory duration repeated
# on every point, this weights long trajectories more than a per-trajectory mean would — confirm.
tao = train_traj_df['delta_t'].mean()
# Both the padder and the model (via `user=`) are sized by the number of distinct users.
setting["finetune"]["padder"]["params"]["num_users"] = user_count

# Divisor applied to centred projected coordinates; configurable, defaults to the previous constant.
scale = setting['dataset'].get('scale', 4000)

# UTM zone used for GPS -> UTM projection. Previously only assigned for GeoLife paths, which left
# `UTM_region` undefined (NameError further down) for any other dataset.
UTM_region = setting['dataset'].get('UTM_region')
if UTM_region is None:
    if "geolife" in train_traj_path:
        UTM_region = 50  # default used for the GeoLife datasets
    else:
        raise ValueError(
            f"Cannot infer the UTM zone for dataset {train_traj_path!r}; "
            "set setting['dataset']['UTM_region'] explicitly."
        )

dataset = TrajFMDataset(traj_df=train_traj_df, UTM_region=UTM_region, scale=scale)

poi_df = pd.read_hdf(setting['dataset']['poi_df'], key='pois')
poi_embed = torch.from_numpy(np.load(setting['dataset']['poi_embed'])).float().to(device)

# Project POI coordinates to UTM, centre them on the dataset's spatial midpoint and divide by
# `scale` (presumably mirroring TrajFMDataset's own trajectory normalisation — confirm).
poi_coors = poi_df[[X_COL, Y_COL]].to_numpy()
poi_coors = (coord_transform_GPS_UTM(poi_coors, UTM_region) - dataset.spatial_middle_coord) / scale
poi_coors = torch.tensor(poi_coors).float().to(device)

# Build the learnable model.
trajfm = TrajFM(poi_embed=poi_embed,
                poi_coors=poi_coors,
                UTM_region=UTM_region,
                spatial_middle_coord=dataset.spatial_middle_coord,
                scale=scale,
                **setting['trajfm'],
                user=user_count).to(device)

dataset: ./dataset/geolife_U56_TrajAll_L1000.h5


In [4]:
# Human-readable description of the training data; printed below and passed to the trainer.
# "avg_traj_len" assumes `tao` is in seconds (it is divided by 3600 and labelled hours).
# The "Data Filtering" entries are descriptive only — presumably applied when the dataset file
# was built; nothing in this notebook enforces them.
data_summary = dict([
    ("users", user_count),
    ("total_traj", traj_count),
    ("total_points", traj_len),
    ("avg_traj_len", f"{round(tao/3600, 2)} hours"),
    ("Data Filtering", [
        "25th to 75th quartile based on traj_len",
        "traj_len > 30 points",
        "delta_t > 1800s",
        "traj/user > 35 traj",
        "resampled traj to 1000 points max",
        "user_number and seq_i recalculated",
        "stratified",
    ]),
])

for summary_key, summary_value in data_summary.items():
    print(f"{summary_key} : {summary_value}")

users : 56
total_traj : 7602
total_points : 5450141
avg_traj_len : 1.62 hours
Data Filtering : ['25th to 75th quartile based on traj_len', 'traj_len > 30 points', 'delta_t > 1800s', 'traj/user > 35 traj', 'resampled traj to 1000 points max', 'user_number and seq_i recalculated', 'stratified']


In [5]:
# Two-stage split with a fixed seed: hold out 40% of the data, then halve that hold-out into
# validation and test — 60/20/20 overall, assuming `test_size` is the fraction returned second.
train_dataset, val_test_dataset = utils.stratify_dataset(
    random_seed=SEED,
    test_size=0.4,
    dataset=dataset,
)
val_dataset, test_dataset = utils.stratify_dataset(
    random_seed=SEED,
    test_size=0.5,
    dataset=val_test_dataset,
)

In [6]:
# print(len(train_dataset), len(val_dataset), len(test_dataset))

In [7]:
# The padder name doubles as the downstream-task tag used in checkpoint and log file names.
downstreamtask = setting['finetune']['padder']['name']
padder = fetch_task_padder(
    padder_name=downstreamtask,
    padder_params=setting['finetune']['padder']['params'],
)

# All three splits share the same collate function and DataLoader options from the settings.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, collate_fn=padder, **setting['finetune']['dataloader'])
    for split in (train_dataset, val_dataset, test_dataset)
)

In [8]:
# Checkpoint path for this run/task. The existence check previously used a hard-coded
# "saved_model/" prefix while loading/saving used MODEL_CACHE_DIR, so overriding the env var
# broke resuming; one path is now used everywhere.
file_path = os.path.join(MODEL_CACHE_DIR, f'{SAVE_NAME}.{downstreamtask}')
if os.path.exists(file_path):
    print(f"Loading model {file_path}")
    # The file holds a state_dict (see torch.save below), so refuse arbitrary pickled objects.
    trajfm.load_state_dict(torch.load(file_path, map_location=device, weights_only=True))
else:
    print("Model not found, starting new")

train_log, saved_model_state_dict = train_user_model(model=trajfm,
                                                    train_dataloader=train_dataloader,
                                                    val_dataloader=val_dataloader,
                                                    device=device,
                                                    **setting['finetune']['config'],
                                                    data_summary=data_summary)

if setting['finetune'].get('save', False):
    # Save the state_dict returned by the trainer.
    utils.create_if_noexists(MODEL_CACHE_DIR)
    torch.save(saved_model_state_dict, file_path)

    # Append the training log across runs; write the CSV header only when the file is new.
    log_dir = os.path.join(LOG_SAVE_DIR, SAVE_NAME)
    utils.create_if_noexists(log_dir)
    log_path = os.path.join(log_dir, f'{SAVE_NAME}_{downstreamtask}.csv')
    train_log.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)

Model not found, starting new


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meuj01[0m ([33mSP_001[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The run id is U56_TrajAll_L1000_v3.7s


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]
Training, avg loss: 3.650:   2%|▏         | 1/60 [01:20<1:18:48, 80.14s/it]

ACC@1: 18.42%,
ACC@5: 39.13%,
Macro-R: 6.92%,
Macro-P: 1.87%,
Macro-F1: 2.81%,
val_loss 3.503


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s].14s/it]
Training, avg loss: 3.433:   3%|▎         | 2/60 [02:38<1:16:42, 79.36s/it]

ACC@1: 17.71%,
ACC@5: 40.1%,
Macro-R: 5.24%,
Macro-P: 1.09%,
Macro-F1: 1.63%,
val_loss 3.43


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].36s/it]
Training, avg loss: 3.323:   5%|▌         | 3/60 [03:57<1:15:03, 79.02s/it]

ACC@1: 17.12%,
ACC@5: 38.8%,
Macro-R: 5.16%,
Macro-P: 1.11%,
Macro-F1: 1.71%,
val_loss 3.494


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].02s/it]
Training, avg loss: 3.251:   7%|▋         | 4/60 [05:16<1:13:37, 78.88s/it]

ACC@1: 21.55%,
ACC@5: 43.62%,
Macro-R: 8.66%,
Macro-P: 4.04%,
Macro-F1: 4.85%,
val_loss 3.257


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].88s/it]
Training, avg loss: 3.193:   8%|▊         | 5/60 [06:35<1:12:18, 78.89s/it]

ACC@1: 19.99%,
ACC@5: 44.6%,
Macro-R: 7.96%,
Macro-P: 3.89%,
Macro-F1: 4.69%,
val_loss 3.239


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s].89s/it]
Training, avg loss: 3.121:  10%|█         | 6/60 [07:53<1:10:59, 78.87s/it]

ACC@1: 18.62%,
ACC@5: 41.8%,
Macro-R: 7.9%,
Macro-P: 3.86%,
Macro-F1: 4.88%,
val_loss 3.345


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s].87s/it]
Training, avg loss: 3.089:  12%|█▏        | 7/60 [09:12<1:09:37, 78.82s/it]

ACC@1: 20.38%,
ACC@5: 45.64%,
Macro-R: 7.73%,
Macro-P: 3.26%,
Macro-F1: 4.15%,
val_loss 3.186


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s].82s/it]
Training, avg loss: 3.064:  13%|█▎        | 8/60 [10:31<1:08:16, 78.77s/it]

ACC@1: 19.21%,
ACC@5: 45.83%,
Macro-R: 6.5%,
Macro-P: 2.89%,
Macro-F1: 3.6%,
val_loss 3.186


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s].77s/it]
Training, avg loss: 3.008:  15%|█▌        | 9/60 [11:50<1:07:03, 78.90s/it]

ACC@1: 20.31%,
ACC@5: 44.27%,
Macro-R: 7.89%,
Macro-P: 3.84%,
Macro-F1: 4.6%,
val_loss 3.193


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.71it/s].90s/it]
Training, avg loss: 2.978:  17%|█▋        | 10/60 [13:09<1:05:47, 78.95s/it]

ACC@1: 18.23%,
ACC@5: 38.67%,
Macro-R: 6.99%,
Macro-P: 3.22%,
Macro-F1: 3.91%,
val_loss 3.344


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]8.95s/it]
Training, avg loss: 2.948:  18%|█▊        | 11/60 [14:28<1:04:25, 78.90s/it]

ACC@1: 20.64%,
ACC@5: 43.55%,
Macro-R: 8.49%,
Macro-P: 4.34%,
Macro-F1: 5.14%,
val_loss 3.169


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]8.90s/it]
Training, avg loss: 2.922:  20%|██        | 12/60 [15:46<1:02:59, 78.74s/it]

ACC@1: 20.83%,
ACC@5: 43.75%,
Macro-R: 8.43%,
Macro-P: 4.34%,
Macro-F1: 5.13%,
val_loss 3.186


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]8.74s/it]
Training, avg loss: 2.879:  22%|██▏       | 13/60 [17:05<1:01:41, 78.77s/it]

ACC@1: 21.48%,
ACC@5: 41.73%,
Macro-R: 9.64%,
Macro-P: 4.6%,
Macro-F1: 5.54%,
val_loss 3.231


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]8.77s/it]
Training, avg loss: 2.862:  23%|██▎       | 14/60 [18:24<1:00:27, 78.86s/it]

ACC@1: 20.44%,
ACC@5: 40.43%,
Macro-R: 8.46%,
Macro-P: 4.57%,
Macro-F1: 5.41%,
val_loss 3.258


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]8.86s/it]
Training, avg loss: 2.842:  25%|██▌       | 15/60 [19:43<59:03, 78.75s/it]  

ACC@1: 20.57%,
ACC@5: 40.76%,
Macro-R: 9.56%,
Macro-P: 5.4%,
Macro-F1: 5.96%,
val_loss 3.302


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]75s/it]
Training, avg loss: 2.833:  27%|██▋       | 16/60 [21:01<57:42, 78.69s/it]

ACC@1: 21.22%,
ACC@5: 43.88%,
Macro-R: 9.35%,
Macro-P: 5.61%,
Macro-F1: 6.31%,
val_loss 3.17


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]69s/it]
Training, avg loss: 2.782:  28%|██▊       | 17/60 [22:20<56:24, 78.71s/it]

ACC@1: 21.94%,
ACC@5: 43.75%,
Macro-R: 10.18%,
Macro-P: 6.1%,
Macro-F1: 6.98%,
val_loss 3.152


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]71s/it]
Training, avg loss: 2.804:  30%|███       | 18/60 [23:39<55:10, 78.83s/it]

ACC@1: 19.99%,
ACC@5: 39.91%,
Macro-R: 8.66%,
Macro-P: 5.08%,
Macro-F1: 5.82%,
val_loss 3.348


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.71it/s]83s/it]
Training, avg loss: 2.783:  32%|███▏      | 19/60 [24:58<53:54, 78.89s/it]

ACC@1: 21.09%,
ACC@5: 44.21%,
Macro-R: 9.46%,
Macro-P: 5.7%,
Macro-F1: 6.44%,
val_loss 3.196


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]89s/it]
Training, avg loss: 2.743:  33%|███▎      | 20/60 [26:17<52:36, 78.90s/it]

ACC@1: 22.01%,
ACC@5: 45.44%,
Macro-R: 9.87%,
Macro-P: 6.12%,
Macro-F1: 6.84%,
val_loss 3.114


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]90s/it]
Training, avg loss: 2.702:  35%|███▌      | 21/60 [27:36<51:14, 78.85s/it]

ACC@1: 22.98%,
ACC@5: 45.31%,
Macro-R: 11.65%,
Macro-P: 6.82%,
Macro-F1: 7.99%,
val_loss 3.081


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]85s/it]
Training, avg loss: 2.670:  37%|███▋      | 22/60 [28:55<49:56, 78.85s/it]

ACC@1: 22.2%,
ACC@5: 44.14%,
Macro-R: 11.09%,
Macro-P: 6.99%,
Macro-F1: 7.92%,
val_loss 3.134


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]85s/it]
Training, avg loss: 2.658:  38%|███▊      | 23/60 [30:14<48:38, 78.88s/it]

ACC@1: 23.18%,
ACC@5: 44.47%,
Macro-R: 10.94%,
Macro-P: 6.56%,
Macro-F1: 7.6%,
val_loss 3.133


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]88s/it]
Training, avg loss: 2.617:  40%|████      | 24/60 [31:32<47:17, 78.82s/it]

ACC@1: 22.46%,
ACC@5: 43.16%,
Macro-R: 10.94%,
Macro-P: 6.34%,
Macro-F1: 7.38%,
val_loss 3.148


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]82s/it]
Training, avg loss: 2.624:  42%|████▏     | 25/60 [32:51<45:57, 78.79s/it]

ACC@1: 21.35%,
ACC@5: 44.14%,
Macro-R: 9.93%,
Macro-P: 6.5%,
Macro-F1: 7.1%,
val_loss 3.148


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]79s/it]
Training, avg loss: 2.580:  43%|████▎     | 26/60 [34:10<44:43, 78.93s/it]

ACC@1: 22.59%,
ACC@5: 46.35%,
Macro-R: 11.2%,
Macro-P: 8.42%,
Macro-F1: 8.72%,
val_loss 3.08


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]93s/it]
Training, avg loss: 2.551:  45%|████▌     | 27/60 [35:29<43:25, 78.94s/it]

ACC@1: 21.74%,
ACC@5: 42.19%,
Macro-R: 10.76%,
Macro-P: 7.31%,
Macro-F1: 7.84%,
val_loss 3.213


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]94s/it]
Training, avg loss: 2.529:  47%|████▋     | 28/60 [36:48<42:05, 78.93s/it]

ACC@1: 21.03%,
ACC@5: 43.42%,
Macro-R: 10.39%,
Macro-P: 8.08%,
Macro-F1: 8.0%,
val_loss 3.167


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]93s/it]
Training, avg loss: 2.472:  48%|████▊     | 29/60 [38:07<40:48, 78.99s/it]

ACC@1: 22.2%,
ACC@5: 45.25%,
Macro-R: 11.45%,
Macro-P: 8.47%,
Macro-F1: 8.66%,
val_loss 3.143


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]99s/it]
Training, avg loss: 2.436:  50%|█████     | 30/60 [39:26<39:31, 79.05s/it]

ACC@1: 22.79%,
ACC@5: 47.53%,
Macro-R: 11.12%,
Macro-P: 7.92%,
Macro-F1: 8.27%,
val_loss 3.058


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]05s/it]
Training, avg loss: 2.374:  52%|█████▏    | 31/60 [40:45<38:12, 79.06s/it]

ACC@1: 24.93%,
ACC@5: 51.5%,
Macro-R: 12.33%,
Macro-P: 8.61%,
Macro-F1: 9.12%,
val_loss 2.994


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]06s/it]
Training, avg loss: 2.309:  53%|█████▎    | 32/60 [42:04<36:51, 78.98s/it]

ACC@1: 25.78%,
ACC@5: 50.39%,
Macro-R: 14.76%,
Macro-P: 11.84%,
Macro-F1: 11.65%,
val_loss 2.94


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]98s/it]
Training, avg loss: 2.243:  55%|█████▌    | 33/60 [43:23<35:31, 78.94s/it]

ACC@1: 34.24%,
ACC@5: 59.9%,
Macro-R: 19.07%,
Macro-P: 15.85%,
Macro-F1: 16.03%,
val_loss 2.575


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]94s/it]
Training, avg loss: 2.170:  57%|█████▋    | 34/60 [44:42<34:11, 78.89s/it]

ACC@1: 34.7%,
ACC@5: 60.03%,
Macro-R: 20.59%,
Macro-P: 17.1%,
Macro-F1: 17.06%,
val_loss 2.569


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]89s/it]
Training, avg loss: 2.080:  58%|█████▊    | 35/60 [46:01<32:50, 78.83s/it]

ACC@1: 33.53%,
ACC@5: 60.81%,
Macro-R: 18.17%,
Macro-P: 14.77%,
Macro-F1: 15.07%,
val_loss 2.59


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]83s/it]
Training, avg loss: 2.016:  60%|██████    | 36/60 [47:19<31:31, 78.82s/it]

ACC@1: 35.74%,
ACC@5: 61.46%,
Macro-R: 21.98%,
Macro-P: 18.19%,
Macro-F1: 18.33%,
val_loss 2.512


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]82s/it]
Training, avg loss: 1.938:  62%|██████▏   | 37/60 [48:38<30:14, 78.87s/it]

ACC@1: 36.85%,
ACC@5: 62.04%,
Macro-R: 20.98%,
Macro-P: 17.13%,
Macro-F1: 17.52%,
val_loss 2.516


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]87s/it]
Training, avg loss: 1.876:  63%|██████▎   | 38/60 [49:57<28:53, 78.79s/it]

ACC@1: 37.63%,
ACC@5: 64.84%,
Macro-R: 23.42%,
Macro-P: 20.16%,
Macro-F1: 19.79%,
val_loss 2.417


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]79s/it]
Training, avg loss: 1.817:  65%|██████▌   | 39/60 [51:16<27:35, 78.82s/it]

ACC@1: 36.65%,
ACC@5: 63.74%,
Macro-R: 21.94%,
Macro-P: 18.73%,
Macro-F1: 18.45%,
val_loss 2.484


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]82s/it]
Training, avg loss: 1.769:  67%|██████▋   | 40/60 [52:35<26:17, 78.86s/it]

ACC@1: 41.67%,
ACC@5: 69.6%,
Macro-R: 24.38%,
Macro-P: 21.65%,
Macro-F1: 21.79%,
val_loss 2.214


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]86s/it]
Training, avg loss: 1.686:  68%|██████▊   | 41/60 [53:54<25:00, 78.95s/it]

ACC@1: 40.49%,
ACC@5: 67.64%,
Macro-R: 24.55%,
Macro-P: 20.67%,
Macro-F1: 21.03%,
val_loss 2.275


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]95s/it]
Training, avg loss: 1.658:  70%|███████   | 42/60 [55:13<23:40, 78.94s/it]

ACC@1: 44.14%,
ACC@5: 73.31%,
Macro-R: 28.32%,
Macro-P: 25.81%,
Macro-F1: 25.44%,
val_loss 2.122


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]94s/it]
Training, avg loss: 1.563:  72%|███████▏  | 43/60 [56:32<22:20, 78.84s/it]

ACC@1: 49.35%,
ACC@5: 78.71%,
Macro-R: 33.26%,
Macro-P: 29.46%,
Macro-F1: 29.65%,
val_loss 1.871


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.74it/s]84s/it]
Training, avg loss: 1.493:  73%|███████▎  | 44/60 [57:50<21:00, 78.79s/it]

ACC@1: 44.21%,
ACC@5: 73.37%,
Macro-R: 28.09%,
Macro-P: 24.82%,
Macro-F1: 24.99%,
val_loss 2.094


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]79s/it]
Training, avg loss: 1.414:  75%|███████▌  | 45/60 [59:09<19:42, 78.82s/it]

ACC@1: 47.27%,
ACC@5: 76.56%,
Macro-R: 31.11%,
Macro-P: 27.05%,
Macro-F1: 27.29%,
val_loss 1.987


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]8.82s/it]
Training, avg loss: 1.406:  77%|███████▋  | 46/60 [1:00:28<18:23, 78.79s/it]

ACC@1: 54.04%,
ACC@5: 80.86%,
Macro-R: 38.98%,
Macro-P: 37.08%,
Macro-F1: 36.4%,
val_loss 1.767


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.75it/s]8.79s/it]
Training, avg loss: 1.341:  78%|███████▊  | 47/60 [1:01:46<17:02, 78.69s/it]

ACC@1: 48.83%,
ACC@5: 79.23%,
Macro-R: 32.88%,
Macro-P: 28.96%,
Macro-F1: 28.98%,
val_loss 1.898


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]8.69s/it]
Training, avg loss: 1.260:  80%|████████  | 48/60 [1:03:05<15:44, 78.73s/it]

ACC@1: 52.67%,
ACC@5: 80.21%,
Macro-R: 37.05%,
Macro-P: 34.01%,
Macro-F1: 34.01%,
val_loss 1.8


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]8.73s/it]
Training, avg loss: 1.206:  82%|████████▏ | 49/60 [1:04:24<14:27, 78.82s/it]

ACC@1: 57.03%,
ACC@5: 82.55%,
Macro-R: 40.43%,
Macro-P: 38.41%,
Macro-F1: 37.93%,
val_loss 1.68


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]8.82s/it]
Training, avg loss: 1.146:  83%|████████▎ | 50/60 [1:05:44<13:10, 79.00s/it]

ACC@1: 56.25%,
ACC@5: 83.4%,
Macro-R: 40.19%,
Macro-P: 37.57%,
Macro-F1: 37.27%,
val_loss 1.685


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]9.00s/it]
Training, avg loss: 1.113:  85%|████████▌ | 51/60 [1:07:02<11:50, 78.95s/it]

ACC@1: 60.68%,
ACC@5: 85.16%,
Macro-R: 44.35%,
Macro-P: 42.71%,
Macro-F1: 42.13%,
val_loss 1.537


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]8.95s/it]
Training, avg loss: 1.048:  87%|████████▋ | 52/60 [1:08:21<10:31, 78.90s/it]

ACC@1: 55.53%,
ACC@5: 81.9%,
Macro-R: 39.39%,
Macro-P: 37.88%,
Macro-F1: 37.14%,
val_loss 1.723


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]8.90s/it]
Training, avg loss: 0.991:  88%|████████▊ | 53/60 [1:09:40<09:12, 78.98s/it]

ACC@1: 60.55%,
ACC@5: 84.57%,
Macro-R: 44.57%,
Macro-P: 42.68%,
Macro-F1: 42.17%,
val_loss 1.497


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]8.98s/it]
Training, avg loss: 0.948:  90%|█████████ | 54/60 [1:11:00<07:54, 79.06s/it]

ACC@1: 61.85%,
ACC@5: 85.81%,
Macro-R: 46.38%,
Macro-P: 44.97%,
Macro-F1: 44.13%,
val_loss 1.481


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.71it/s]9.06s/it]
Training, avg loss: 0.918:  92%|█████████▏| 55/60 [1:12:19<06:35, 79.08s/it]

ACC@1: 63.15%,
ACC@5: 85.61%,
Macro-R: 47.83%,
Macro-P: 46.13%,
Macro-F1: 45.59%,
val_loss 1.453


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]9.08s/it]
Training, avg loss: 0.860:  93%|█████████▎| 56/60 [1:13:38<05:16, 79.01s/it]

ACC@1: 63.93%,
ACC@5: 85.87%,
Macro-R: 49.85%,
Macro-P: 47.58%,
Macro-F1: 47.18%,
val_loss 1.43


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]9.01s/it]
Training, avg loss: 0.825:  95%|█████████▌| 57/60 [1:14:56<03:56, 79.00s/it]

ACC@1: 60.29%,
ACC@5: 84.96%,
Macro-R: 44.44%,
Macro-P: 43.42%,
Macro-F1: 42.56%,
val_loss 1.596


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.73it/s]9.00s/it]
Training, avg loss: 0.781:  97%|█████████▋| 58/60 [1:16:16<02:38, 79.01s/it]

ACC@1: 61.46%,
ACC@5: 85.74%,
Macro-R: 46.8%,
Macro-P: 45.28%,
Macro-F1: 44.53%,
val_loss 1.561


Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]9.01s/it]
Training, avg loss: 0.749:  98%|█████████▊| 59/60 [1:17:35<01:19, 79.04s/it]

ACC@1: 63.54%,
ACC@5: 86.07%,
Macro-R: 49.39%,
Macro-P: 46.75%,
Macro-F1: 46.82%,
val_loss 1.449


Testing/Validating: 100%|██████████| 48/48 [00:28<00:00,  1.71it/s]9.04s/it]
Training, avg loss: 0.726: 100%|██████████| 60/60 [1:18:54<00:00, 78.91s/it]

ACC@1: 63.74%,
ACC@5: 86.59%,
Macro-R: 49.13%,
Macro-P: 47.46%,
Macro-F1: 46.95%,
val_loss 1.468





In [9]:
# Evaluate the fine-tuned model on the held-out test split; metrics are fractions, shown as percent.
metrics, _ = test_user_model(
    device=device,
    dataloader=test_dataloader,
    model=trajfm,
)
for metric_name, metric_value in metrics.items():
    as_percent = round(metric_value * 100, 2)
    print(f"{metric_name}: {as_percent}%,")

Testing/Validating: 100%|██████████| 48/48 [00:27<00:00,  1.72it/s]

ACC@1: 63.6%,
ACC@5: 86.21%,
Macro-R: 49.36%,
Macro-P: 48.47%,
Macro-F1: 47.57%,





In [10]:
# One row per run: model name plus each test metric in percent (2 decimals).
df = pd.DataFrame([{
    "Model": str(SAVE_NAME),
    **{key: round(value * 100, 2) for key, value in metrics.items()},
}])

# Shared results table under LOG_SAVE_DIR (previously hard-coded to "logs/", which ignored the
# env override and crashed when the directory had not been created by the training cell).
utils.create_if_noexists(LOG_SAVE_DIR)
csv_path = os.path.join(LOG_SAVE_DIR, 'test.csv')
# Append to the existing table; write the header only when the file is being created.
df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)