In [11]:
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import polars as pl
# 
from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL, 
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_SUBTITLE_COL,
    DEFAULT_LABELS_COL,
    DEFAULT_TITLE_COL, 
    DEFAULT_USER_COL, 
)
#
from ebrec.utils._behaviors import (
    create_binary_labels_column, 
    sampling_strategy_wu2019,
    add_known_user_column,
    add_prediction_scores,
    truncate_history, 
)
from ebrec.utils._articles import convert_text2encoding_with_transformers
from ebrec.utils._polars import concat_str_columns, slice_join_dataframes
from ebrec.utils._articles import create_article_id_to_value_mapping
from ebrec.utils._nlp import get_transformers_word_embeddings
#
from ebrec.models.newsrec.dataloader import NRMSDataLoader
from ebrec.models.newsrec.model_config import hparams_nrms
from ebrec.models.newsrec import NRMSModel

import numpy as np
import torch
from ebrec.models.fastformer.dataloader import FastformerDataset
from ebrec.models.fastformer.dataloader import train as train_fastformer
from ebrec.models.fastformer.dataloader import evaluate as evaluate_fastformer
from torch.utils.data import DataLoader
from ebrec.models.fastformer.fastformer import Fastformer
from ebrec.models.fastformer.config import FastFormerConfig

In [12]:
# Root of the EB-NeRD demo data: holds the train/ and validation/ splits and articles.parquet.
path = Path("/home/data/dataset/origin/ebnerd_demo/")
# Name of the derived column that ebnerd_from_path fills with the length of the in-view list.
N_SAMPLES = "n"
# Columns kept from every behaviors split before sampling / labelling.
COLUMNS = [
    DEFAULT_USER_COL,
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    N_SAMPLES,
]
HISTORY_SIZE = 30  # length each user's article-id history is truncated to
BATCH_SIZE = 100  # batch size handed to FastformerDataset

In [13]:
def ebnerd_from_path(path: Path, history_size: int = 30) -> pl.DataFrame:
    """
    Load one EB-NeRD split directory as an eager polars DataFrame.

    Steps:
      1. Lazily read ``history.parquet``, keep the user and history-article-id columns,
         and pass it through ``truncate_history`` (``history_size``, ``padding_value=0``).
      2. Read ``behaviors.parquet``, add the ``N_SAMPLES`` column holding the length of
         each row's in-view article list, and collect it.
      3. Left-join the collected history onto the behaviors on the user column using
         ``slice_join_dataframes``.

    Args:
        path: Split directory containing ``behaviors.parquet`` and ``history.parquet``.
        history_size: Value forwarded to ``truncate_history`` for the history column.

    Returns:
        The behaviors rows with the user's truncated history attached.
    """
    history_lazy = pl.scan_parquet(path / "history.parquet").select(
        DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL
    )
    history_truncated = truncate_history(
        history_lazy,
        column=DEFAULT_HISTORY_ARTICLE_ID_COL,
        history_size=history_size,
        padding_value=0,
    )

    n_inview = pl.col(DEFAULT_INVIEW_ARTICLES_COL).list.len().alias(N_SAMPLES)
    behaviors = (
        pl.scan_parquet(path / "behaviors.parquet")
        .with_columns(n_inview)
        .collect()
    )

    # Behaviors are materialised first, then the history, matching the original order.
    return slice_join_dataframes(
        behaviors,
        df2=history_truncated.collect(),
        on=DEFAULT_USER_COL,
        how="left",
    )

In [14]:
# Training split: restrict to COLUMNS, apply sampling_strategy_wu2019 (npratio=4, shuffled,
# drawn with replacement, seed=123 for reproducibility), then add the binary labels column.
df_train = (
    ebnerd_from_path(path / "train", history_size=HISTORY_SIZE)
    .select(COLUMNS)
    .pipe(
        sampling_strategy_wu2019,
        npratio=4,
        shuffle=True,
        with_replacement=True,
        seed=123,
    )
    .pipe(create_binary_labels_column)
)

In [15]:
# Validation split: same columns as training but no sampling step — every in-view
# article is kept — followed by the binary labels column.
df_validation = (
    ebnerd_from_path(path / "validation", history_size=HISTORY_SIZE)
    .select(COLUMNS)
    .pipe(create_binary_labels_column)
)
# Number of in-view candidates per validation impression (presumably used to regroup
# flat prediction scores per impression later — confirm).
label_lengths = (
    df_validation.get_column(DEFAULT_INVIEW_ARTICLES_COL).list.len().to_list()
)

In [16]:
# Local Hugging Face checkpoint directory used for both the tokenizer and the encoder.
TRANSFORMER_MODEL_NAME = "/home/data/models/bert-base-multilingual-cased"
# Article text columns that are concatenated before tokenization, in this order.
TEXT_COLUMNS_TO_USE = [DEFAULT_SUBTITLE_COL, DEFAULT_TITLE_COL]
MAX_TITLE_LENGTH = 30  # max_length passed to the tokenization helper

# Load the pretrained encoder and its tokenizer from the same checkpoint.
transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

# Pull the transformer's word-embedding matrix (via get_transformers_word_embeddings),
# intended for initialising a downstream model's word embeddings.
# NOTE(review): no later cell in this notebook passes word2vec_embedding to the model.
word2vec_embedding = get_transformers_word_embeddings(transformer_model)

In [17]:
# Article text -> token ids:
#   1. concat_str_columns over TEXT_COLUMNS_TO_USE; cat_cal is the name of the resulting
#      text column it returns (presumably — confirm against the helper).
#   2. Tokenize that column with the transformer tokenizer, capped at MAX_TITLE_LENGTH.
#   3. Build an article-id -> token-ids lookup for the datasets.
df_articles = pl.read_parquet(path / "articles.parquet")
df_articles, cat_cal = concat_str_columns(df_articles, columns=TEXT_COLUMNS_TO_USE)
df_articles, token_col_title = convert_text2encoding_with_transformers(
    df_articles,
    transformer_tokenizer,
    cat_cal,
    max_length=MAX_TITLE_LENGTH,
)
article_mapping = create_article_id_to_value_mapping(
    df=df_articles,
    value_col=token_col_title,
)

In [18]:
# Training loader. batch_size is given to FastformerDataset itself (presumably it yields
# whole batches), so the DataLoader wrapper is left at its defaults.
# shuffle=True: with shuffle=False the training rows were always presented in the same
# fixed order, which biases SGD toward whatever ordering behaviors.parquet has.
# Assumes FastformerDataset's `shuffle` randomises row order — confirm. The validation
# loader keeps shuffle=False so evaluation order stays fixed.
train_dataloader = DataLoader(
    FastformerDataset(
        behaviors=df_train,
        history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
        article_dict=article_mapping,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
)

In [19]:
# Evaluation loader over the validation split. shuffle=False so impressions are served
# in a fixed order; batching is configured on the dataset (batch_size=BATCH_SIZE).
validation_dataset = FastformerDataset(
    behaviors=df_validation,
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    article_dict=article_mapping,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
test_dataloader = DataLoader(validation_dataset)

In [20]:
# Output locations for this model's logs and saved state_dict.
MODEL_NAME = "FastFormer"
LOG_DIR = "/home/data/models/{}/log".format(MODEL_NAME)
MODEL_WEIGHTS = "/home/data/models/{}/weights".format(MODEL_NAME)

In [21]:
# Build the Fastformer model from the default FastFormerConfig.
# NOTE(review): word2vec_embedding (from the transformer cell) is not passed here, so the
# model presumably uses its own embedding table; confirm its vocabulary size covers the
# multilingual-BERT token ids stored in article_mapping.
model_config = FastFormerConfig()
model = Fastformer(model_config)

In [23]:
num_epochs = 5
# Adam with base learning rate 1e-4, cosine-annealed with T_max=num_epochs.
# NOTE(review): T_max=num_epochs assumes train_fastformer calls scheduler.step() once per
# epoch — confirm; if it steps per batch, T_max should scale with len(train_dataloader).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

In [24]:
# Train for num_epochs, validating on test_dataloader; checkpoints presumably go to
# MODEL_WEIGHTS, selected by the "auc" monitor metric.
# NOTE(review): criterion=None is forwarded unchanged — confirm train_fastformer falls
# back to a default loss, otherwise training fails once the first batch is scored.
train_fastformer(
    model=model,
    train_dataloader=train_dataloader,
    criterion=None,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    val_dataloader=test_dataloader,
    state_dict_path=MODEL_WEIGHTS,
    monitor_metric="auc",
)

Epoch [1/5]:   0%|                                      | 0/249 [00:10<?, ?it/s]


KeyboardInterrupt: 