In [1]:
import os
import sys
sys.path.append(os.path.abspath('../'))
from os import makedirs
from os.path import join, basename
import numpy as np
import torch
import random
from arguments import define_new_main_parser
import json

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

from dataset.dataset import Dataset
from models.modules import TabFormerBertLM, TabFormerBertForClassification, TabFormerBertModel, TabStaticFormerBert, \
    TabStaticFormerBertLM, TabStaticFormerBertClassification
from misc.utils import ordered_split_dataset, compute_cls_metrics
from dataset.datacollator import *
from main import main

import logging

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def setup_logging(output_dir="logs", log_file_name='output.log'):
    """Configure the root logger to write to a log file and to the console.

    Records of level DEBUG and above are sent both to
    ``<output_dir>/logs/<log_file_name>`` and to a ``logging.StreamHandler``.
    Handlers already attached to the root logger are detached *and closed*
    first, so calling this function again (e.g. re-running the notebook cell)
    neither duplicates log lines nor leaks the previously opened log file.

    Args:
        output_dir: Base directory; a ``logs`` sub-directory is created in it.
        log_file_name: Name of the log file inside that sub-directory.

    Returns:
        The configured root ``logging.Logger``.
    """
    log_dir = join(output_dir, "logs")
    # makedirs creates missing parents, so output_dir itself is covered too.
    makedirs(log_dir, exist_ok=True)
    log_file = join(log_dir, log_file_name)

    logger = logging.getLogger()

    # Iterate over a copy because removeHandler mutates logger.handlers.
    # Closing matters for FileHandler: merely clearing the list would keep the
    # old file descriptor open for the lifetime of the kernel.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    fhandler = logging.FileHandler(log_file)
    fhandler.setLevel(logging.DEBUG)

    chandler = logging.StreamHandler()
    chandler.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fhandler.setFormatter(formatter)
    chandler.setFormatter(formatter)

    logger.addHandler(fhandler)
    logger.addHandler(chandler)
    logger.setLevel(logging.DEBUG)

    return logger

# Install the notebook-wide root logger; with this base directory the log file
# ends up at logs/logs/output.log (setup_logging appends its own "logs" folder).
logger = setup_logging("logs")
logger.info("Logging setup completed.")

2025-02-07 18:19:05,241 - root - INFO - Logging setup completed.


In [3]:
# Feature-group toggles. Each enabled group contributes one suffix to the
# dataset file names built in the configuration cell.
include_user_features = True
include_time_features = True
include_market_features = False
include_exo_features = False

# Suffixes are concatenated in this fixed order: user, market, time, exoLagged.
feature_extension = "".join(
    suffix
    for enabled, suffix in (
        (include_user_features, "_user"),
        (include_market_features, "_market"),
        (include_time_features, "_time"),
        (include_exo_features, "_exoLagged"),
    )
    if enabled
)

In [4]:
# --- Data location and experiment identity ---------------------------------
data = "/data/IDEA_DeFi_Research/Data/AML/LI_Small/preprocessed"
dt = "aml"            # forwarded as --data_type
exp_name = "debug"
output_dir = data + "/output/" + exp_name
vocab_dir = data + "/vocab"

# --- Input files (names depend on the enabled feature groups) ---------------
fname = "transactions" + feature_extension + "_train"
val_fname = "transactions" + feature_extension + "_val"
test_fname = "transactions" + feature_extension + "_test"
fextension = False    # when truthy, --fextension is passed and the vocab path changes
nrows = 2000000       # forwarded as --nrows
external_val = False  # when truthy, --external_val is passed

# --- Model / sequence windowing ---------------------------------------------
field_hs = 64         # hidden state dimension of the fields in the transformer (default: 768)
seq_len = 25          # length for transaction sliding window
stride = 5            # stride for transaction sliding window
time_pos_type = "regular_position"

# --- Training schedule -------------------------------------------------------
bs = 32               # used for both --train_batch_size and --eval_batch_size
num_train_epochs = 5
save_steps = 5000
eval_steps = 5000
checkpoint = None     # --checkpoint is only passed when this is not None

In [5]:
# Command line handed to the project's argument parser. Options are separated
# by runs of spaces; the widths carry no meaning because the string is later
# tokenised with str.split().
arg_str = (" " * 5).join([
    # Boolean switches
    "--do_train",
    "--mlm",
    "--pad_seq_first",
    "--get_rids",
    "--field_ce",
    # Model and windowing
    "--lm_type bert",
    f"--field_hs {field_hs}",
    f"--data_type {dt}",
    f"--seq_len {seq_len}",
    f"--stride {stride}",
    # Training schedule and I/O
    f"--num_train_epochs {num_train_epochs}",
    f"--data_root {data}/",
    f"--train_batch_size {bs}",
    f"--eval_batch_size {bs}",
    f"--save_steps {save_steps}",
    f"--eval_steps {eval_steps}",
    f"--data_fname {fname}",
    f"--data_val_fname {val_fname}",
    f"--data_test_fname {test_fname}",
    f"--output_dir {output_dir}",
    f"--time_pos_type {time_pos_type}",
    f"--vocab_dir {vocab_dir}",
    f"--nrows {nrows}",
    # Cache switches
    "--vocab_cached",
    "--encoder_cached",
    "--cached",
]) + " " * 5

# External vocabulary: an extension-specific file when fextension is set,
# otherwise the default file under <data>/vocab.
arg_str += (
    (" " * 5).join((
        f"--fextension {fextension}",
        f"--external_vocab_path {data}/vocab_ob_{fextension}",
    ))
    if fextension
    else f"--external_vocab_path {data}/vocab/vocab_ob"
)

# Optional trailing options.
if external_val:
    arg_str += " " * 4 + "--external_val"
if checkpoint is not None:
    arg_str += " " * 4 + f"--checkpoint {checkpoint}"

In [6]:
# Parse the assembled option string exactly as if it had been typed on a
# command line; whitespace splitting turns it into an argv-style token list.
parser = define_new_main_parser()
opts = parser.parse_args(args=arg_str.split())

In [None]:
# Combined flag: truthy when either a classification or an export task was requested.
opts.cls_exp_task = opts.cls_task or opts.export_task

# Validate the option combination before touching the filesystem, so a bad
# invocation does not leave empty output directories behind.
if opts.lm_type == "bert" and not (opts.mlm or opts.cls_exp_task):
    raise ValueError(
        "Error: Bert needs either '--mlm', '--cls_task' or '--export_task' option. Please re-run with this flag "
        "included.")

# Creating the log directory also creates opts.output_dir (its parent).
opts.log_dir = join(opts.output_dir, "logs")
makedirs(opts.log_dir, exist_ok=True)

# Hand the parsed options to the project's entry point.
main(opts)

2025-02-07 18:19:05,289 - dataset.basic - INFO - cached encoded data is read from transactions_user_time_train.encoded.csv
2025-02-07 18:19:08,028 - dataset.basic - INFO - read data : (2000000, 48)
2025-02-07 18:19:08,030 - dataset.basic - INFO - using cached vocab from /data/IDEA_DeFi_Research/Data/eCommerce/Cosmetics/preprocessed/vocab/vocab_ob
2025-02-07 18:19:08,100 - dataset.dataset - INFO - preparing user level data...
100%|██████████| 146723/146723 [01:06<00:00, 2215.09it/s]
2025-02-07 18:20:18,469 - dataset.dataset - INFO - creating transaction samples with vocab
100%|██████████| 146723/146723 [00:56<00:00, 2598.56it/s]
2025-02-07 18:21:14,941 - dataset.dataset - INFO - ncols: 45
2025-02-07 18:21:14,941 - dataset.dataset - INFO - no of samples 488990
2025-02-07 18:21:15,493 - main - INFO - vocab size: 645
2025-02-07 18:21:15,497 - main - INFO - dataset size: 488990
2025-02-07 18:21:15,540 - dataset.basic - INFO - cached encoded data is read from transactions_user_time_train.enc

Step,Training Loss,Validation Loss
5000,43.7601,44.326366
