In [None]:
%load_ext autoreload
%autoreload 2
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from transformers import BertTokenizerFast

from src.config import config, MODEL_CONFIG
# NOTE(review): `embed_inputs` is redefined further down in this notebook, which
# shadows this import. The `@timing` decorator used there is not imported here;
# import it before running that cell.
from src.model.data_loading import embed_inputs

In [None]:
def embed_input(text, tokenizer, max_length: int = 512):
    """Encode a text into fixed-length BERT input ids and an attention mask.

    Args:
        text: Text passed straight to ``tokenizer``.
        tokenizer: Hugging Face tokenizer (this notebook uses ``BertTokenizerFast``).
        max_length: Length every encoding is truncated / padded to. Defaults to
            512 because BERT can only take inputs of at most 512 tokens.

    Returns:
        ``(input_ids, attention_masks)`` as PyTorch tensors.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        # Longer texts are cut and shorter ones padded, so every encoding has
        # exactly `max_length` positions.
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_attention_mask=True,
        # "pt" makes the tokenizer return PyTorch tensors instead of lists.
        return_tensors="pt",
    )
    # The tokenizer call returns a mapping:
    #   input_ids      -> the text mapped to token ids
    #   attention_mask -> indicates whether a position is a real token or padding
    input_ids = encoding["input_ids"]
    attention_masks = encoding["attention_mask"]
    return input_ids, attention_masks


# NOTE(review): `timing` is not imported in the import cell; it must be defined
# (or imported) before this cell runs, otherwise this raises NameError.
@timing
def embed_inputs(texts: list, tokenizer) -> tuple[Tensor, Tensor]:
    """Tokenize many texts on a thread pool and stack the encodings row-wise.

    Args:
        texts: Iterable of texts; each one is encoded with ``embed_input``.
        tokenizer: Tokenizer forwarded to ``embed_input``.

    Returns:
        ``(input_ids, attention_masks)``: the per-text tensors concatenated
        along dim 0, so row ``i`` corresponds to ``texts[i]``.

    Raises:
        ValueError: if ``texts`` is empty (there is nothing to concatenate).
    """
    texts = list(texts)
    if not texts:
        raise ValueError("embed_inputs() requires at least one text")

    encode = partial(embed_input, tokenizer=tokenizer)
    # The context manager shuts the pool down and joins its worker threads
    # once every text has been encoded.
    # NOTE(review): all workers share one tokenizer instance. Hugging Face fast
    # tokenizers have been reported to raise "RuntimeError: Already borrowed"
    # under concurrent calls. If that appears, pass the whole list to a single
    # tokenizer call instead.
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
        # Executor.map yields results in input order.
        encodings = list(pool.map(encode, texts))

    ids_per_text, masks_per_text = zip(*encodings)
    input_ids: Tensor = torch.cat(ids_per_text, dim=0)
    attention_masks: Tensor = torch.cat(masks_per_text, dim=0)
    return input_ids, attention_masks

In [None]:
def get_text_and_labels(dat: pd.DataFrame,
                        text_col: Optional[str] = None,
                        label_col: Optional[str] = None) -> tuple[List, List]:
    """Extract the text and label columns of ``dat`` as plain Python lists.

    Args:
        dat: Frame holding one row per example.
        text_col: Column with the input texts; falls back to
            ``MODEL_CONFIG.input_col_name`` when None or empty.
        label_col: Column with the targets; falls back to
            ``MODEL_CONFIG.target_col_name`` when None or empty.

    Returns:
        ``(texts, labels)`` as lists, in the row order of ``dat``.

    Raises:
        KeyError: if either column is missing from ``dat``.
    """
    if not text_col:
        text_col = MODEL_CONFIG.input_col_name
    if not label_col:
        label_col = MODEL_CONFIG.target_col_name
    texts = dat.loc[:, text_col].tolist()
    labels = dat.loc[:, label_col].tolist()
    return texts, labels

In [None]:
# Sequence length every encoding is padded/truncated to; keep in sync with
# the `max_length` used by `embed_input` (BERT's 512-token limit).
max_encoding_length = 512
tokenizer = BertTokenizerFast.from_pretrained(
    MODEL_CONFIG.transformer_hugface_id,
)

In [None]:
dataset = pd.read_parquet(config.data.benzinga.cleaned)
# Dummy label column (character count of `parsed_body`); it only serves as the
# `label_col` when extracting the texts below.
dataset["text_length"] = dataset["parsed_body"].map(len)

In [None]:
# Only the texts are needed here; `labels` is the dummy text-length column.
texts, labels = get_text_and_labels(dat=dataset,
                                    text_col="parsed_body",
                                    label_col="text_length")
input_ids, masks = embed_inputs(texts, tokenizer)

In [None]:
# Assert that every encoded row has exactly `max_encoding_length` tokens. The
# length is configured by `max_length=512` (with padding="max_length" and
# truncation=True) inside `embed_input`.
row_lengths = {len(row) for row in input_ids}
assert row_lengths == {max_encoding_length}, f"unexpected row lengths: {row_lengths}"
max(row_lengths)

func:'embed_inputs' took: 14.4095 sec



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [None]:
# Store one encoding per row (a list-valued cell), aligned on dataset.index.
# The tensors are converted to numpy first: pyarrow (used by `to_parquet`)
# cannot infer an Arrow type for torch.Tensor cell values. 1-D numpy arrays,
# by contrast, are written as list columns. Keeping new names avoids
# clobbering the `input_ids` / `masks` tensors, so re-running this cell works.
input_id_col = pd.Series(index=dataset.index, data=list(input_ids.numpy()))
mask_col = pd.Series(index=dataset.index, data=list(masks.numpy()))
dataset["input_id"] = input_id_col
dataset["mask"] = mask_col
# NOTE(review): this overwrites the cleaned input file in place; consider a
# separate output path so the cleaned data stays recoverable.
dataset.to_parquet(config.data.benzinga.cleaned)


In [None]:
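# NOTE(review): disabled draft of a dense layout with one row per example and
# columns [index | input_ids (max_encoding_length) | masks (max_encoding_length)].
# encoding_matrix = np.ndarray(shape=(len(input_ids), 2*max_encoding_length+1))
# encoding_matrix[:, 0] = dataset.index
# encoding_matrix[:, 1:(max_encoding_length+1)] = input_ids
# encoding_matrix[:, (max_encoding_length+1):] = masks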

In [None]:
# np.save(file=encoding_matrix_path, arr=encoding_matrix)

In [None]:
# NOTE(review): disabled draft that reads back the [index | input_ids | masks]
# matrix; np.load takes no `arr` argument.
# def get_encoding(encoding_matrix_path: str):
#     encoding_matrix = np.load(file=encoding_matrix_path)
#     index = encoding_matrix[:, 0]
#     input_ids = encoding_matrix[:, 1:(max_encoding_length+1)]
#     masks = encoding_matrix[:, (max_encoding_length+1):]
#     return index, input_ids, masks