In [8]:
# edited by Dongyu Zhang
from os import makedirs
from os.path import join, basename
import logging
import numpy as np
import torch
import random
from args import define_new_main_parser
import json

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

from dataset.aave import AaveDataset
from dataset.aave_time_static import AaveWithTimePosAndStaticSplitDataset
from dataset.aave_time_pos import AaveWithTimePosDataset
from dataset.aave_static import AaveWithStaticSplitDataset
from models.modules import TabFormerBertLM, TabFormerBertForClassification, TabFormerBertModel, TabStaticFormerBert, \
    TabStaticFormerBertLM, TabStaticFormerBertClassification
from misc.utils import ordered_split_dataset, compute_cls_metrics
from dataset.datacollator import *
from main_aave import main

# Configure the root logger: INFO level, timestamped records that include the
# emitting logger's name. (basicConfig does nothing if the root logger already
# has handlers.)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

# Notebook-level logger, available under both `logger` and `log`.
log = logger = logging.getLogger(__name__)

import os
# NOTE(review): presumably switches off Weights & Biases reporting inside the
# transformers Trainer — confirm against the installed transformers version.
os.environ["WANDB_DISABLED"] = "true"

In [13]:
# --- Run configuration -------------------------------------------------------
# Every value below is interpolated into the command-line string built in the
# next cell, except `preload_fextension`, which no visible cell reads.

# Data location and dataset flavour.
data = ""                                # sub-folder inserted into ./data/<data>/
dt = "Aave"                              # forwarded as --data_type
time_pos_type = "regular_position"       # forwarded as --time_pos_type
fextension = False                       # falsy: no --fextension, plain vocab_ob path
preload_fextension = "preload-test"      # NOTE(review): unused in the visible cells

# File-name stems of the three data splits (--data_fname / --data_val_fname /
# --data_test_fname).
fname, val_fname, test_fname = (
    f"transactionsAave_{split}" for split in ("train", "val", "test")
)

# Batching and checkpoint/evaluation cadence.
bs, nb = 32, 10                          # --train/eval_batch_size, --nbatches
save_steps = eval_steps = 200            # --save_steps, --eval_steps

# Model and output locations.
pretrained_dir = "output_aave/final-model"
output_dir = "output"
checkpoint = None                        # None: --checkpoint is not passed

# Resampling and validation source.
resample_method = "downsample"           # None would omit --resample_method
resample_ratio = 10
resample_seed = 100
external_val = False                     # True would add --external_val

# Seed forwarded as --seed.
seed = 9

In [14]:
# Assemble the command line handed to the argument parser in the next cell,
# which splits this string on whitespace. Options are spaced five characters
# apart (four for the trailing optional ones); the exact spacing is irrelevant
# once the string is split.
option_gap = " " * 5

# Options that are always present, one "flag [value]" entry each.
fixed_options = [
    "--do_train",
    "--do_eval",
    "--export_task",
    "--long_and_sort",
    "--pad_seq_first",
    "--get_rids",
    "--field_ce",
    "--lm_type bert",
    "--field_hs 64",
    f"--data_type {dt}",
    "--stride 5",
    f"--data_root ./data/{data}/",
    f"--train_batch_size {bs}",
    f"--eval_batch_size {bs}",
    f"--save_steps {save_steps}",
    f"--eval_steps {eval_steps}",
    f"--nbatches {nb}",
    f"--data_fname {fname}",
    f"--data_val_fname {val_fname}",
    f"--data_test_fname {test_fname}",
    "--vocab_cached",
    "--user_level_cached",
    f"--pretrained_dir {pretrained_dir}",
    f"--output_dir {output_dir}",
    f"--time_pos_type {time_pos_type}",
    f"--resample_ratio {resample_ratio}",
    f"--resample_seed {resample_seed}",
    f"--seed {seed}",
]
arg_str = option_gap.join(fixed_options) + option_gap

# External vocabulary path; a file extension, when given, is also passed on
# its own and appended to the vocabulary file name.
vocab_path = f"./data/{data}/vocab_ob"
if fextension:
    arg_str += option_gap.join([
        f"--fextension {fextension}",
        f"--external_vocab_path {vocab_path}_{fextension}",
    ])
else:
    arg_str += f"--external_vocab_path {vocab_path}"

# Options that are only emitted when the matching setting is active.
optional_options = []
if resample_method is not None:
    optional_options.append(f"--resample_method {resample_method}")
if external_val:
    optional_options.append("--external_val")
if checkpoint is not None:
    optional_options.append(f"--checkpoint {checkpoint}")
arg_str += "".join(" " * 4 + option for option in optional_options)

In [15]:
# Build the project's argument parser, restricted to the Aave dataset variants,
# then parse the whitespace-separated tokens of the assembled argument string.
aave_data_types = ["Aave", "Aave_time_pos", "Aave_time_static", "Aave_static"]
parser = define_new_main_parser(data_type_choices=aave_data_types)
cli_tokens = arg_str.split()
opts = parser.parse_args(cli_tokens)

In [16]:
# Validate the parsed options, set up the run's log file, then launch the job.
# Validation comes first so that a bad configuration fails before any
# directory or log file is created.

# True when either the classification or the export task was requested.
opts.cls_exp_task = opts.cls_task or opts.export_task

# Each dataset flavour only works with specific positional-encoding types.
if opts.data_type in ["Aave_time_pos", "Aave_time_static"]:
    allowed_time_pos_types = ['time_aware_sin_cos_position']
elif opts.data_type in ["Aave", "Aave_static"]:
    allowed_time_pos_types = ['sin_cos_position', 'regular_position']
else:
    allowed_time_pos_types = None  # no constraint known for other data types

if allowed_time_pos_types is not None and opts.time_pos_type not in allowed_time_pos_types:
    # Explicit raise (not assert) so the check survives `python -O` and tells
    # the user what to change.
    raise ValueError(
        f"--time_pos_type {opts.time_pos_type!r} is not valid for --data_type "
        f"{opts.data_type!r}; expected one of {allowed_time_pos_types}.")

if opts.mlm and opts.lm_type == "gpt2":
    raise Exception(
        "Error: GPT2 doesn't need '--mlm' option. Please re-run with this flag removed.")

if (not opts.mlm) and (not opts.cls_exp_task) and opts.lm_type == "bert":
    raise Exception(
        "Error: Bert needs either '--mlm', '--cls_task' or '--export_task' option. Please re-run with this flag "
        "included.")

# Create <output_dir>/logs; makedirs also creates <output_dir> itself.
opts.log_dir = join(opts.output_dir, "logs")
makedirs(opts.log_dir, exist_ok=True)

log_path = join(opts.log_dir, 'output.log')
log_path_abs = os.path.abspath(log_path)  # FileHandler.baseFilename is absolute
root_logger = logging.getLogger()

# Re-running this cell must not stack handlers: drop (and close) any file
# handler from an earlier run that already targets this log file.
for handler_owner in (logger, root_logger):
    for stale_handler in list(handler_owner.handlers):
        if (isinstance(stale_handler, logging.FileHandler)
                and stale_handler.baseFilename == log_path_abs):
            handler_owner.removeHandler(stale_handler)
            stale_handler.close()

# Attach to the ROOT logger: the records of interest are emitted by loggers in
# other modules (e.g. `main_aave`, `dataset.*`), which propagate to the root
# but never to this notebook's own `__main__` logger.
file_handler = logging.FileHandler(log_path, 'w', 'utf-8')
root_logger.addHandler(file_handler)

main(opts)

11/20/2024 16:02:01 - INFO - dataset.aave_basic -   cached encoded data is read from transactionsAave_train.encoded.csv
11/20/2024 16:02:01 - INFO - dataset.aave_basic -   read data : (10000, 42)
11/20/2024 16:02:01 - INFO - dataset.aave -   preparing user level data...
100%|██████████| 18/18 [00:00<00:00, 52.38it/s]
11/20/2024 16:02:01 - INFO - dataset.aave -   creating transaction samples with vocab
100%|██████████| 18/18 [00:00<00:00, 111.50it/s]
11/20/2024 16:02:02 - INFO - dataset.aave -   ncols: 38
11/20/2024 16:02:02 - INFO - dataset.aave -   no of samples 2011
11/20/2024 16:02:02 - INFO - dataset.aave_basic -   saving vocab at output/vocab.nb
11/20/2024 16:02:02 - INFO - dataset.aave_basic -   saving vocab object at output/vocab_ob
11/20/2024 16:02:02 - INFO - main_aave -   vocab size: 335
11/20/2024 16:02:02 - INFO - dataset.aave_basic -   cached encoded data is read from transactionsAave_test.encoded.csv
11/20/2024 16:02:02 - INFO - dataset.aave_basic -   read data : (1000, 4

11/20/2024 16:02:20 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:02:20 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_0_embeddings


11/20/2024 16:02:41 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:02:41 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_1_embeddings


11/20/2024 16:03:00 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:03:00 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_2_embeddings


11/20/2024 16:03:18 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:03:18 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_3_embeddings


11/20/2024 16:03:37 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:03:37 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_4_embeddings


11/20/2024 16:03:57 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:03:57 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_5_embeddings


11/20/2024 16:04:16 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:04:16 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_6_embeddings


11/20/2024 16:04:35 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:04:35 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_7_embeddings


11/20/2024 16:04:54 - INFO - main_aave -   row embeds shape: (97, 10, 2432)
11/20/2024 16:04:54 - INFO - main_aave -   seq embeds shape: (97, 10, 2432)


saved file output/batch_8_embeddings


11/20/2024 16:05:15 - INFO - main_aave -   row embeds shape: (98, 10, 2432)
11/20/2024 16:05:15 - INFO - main_aave -   seq embeds shape: (98, 10, 2432)


saved file output/batch_9_embeddings
