In [None]:
import pandas as pd
from transformers import XLMRobertaModel, XLMRobertaTokenizer
import torch
from torch import Tensor

# One checkpoint identifier shared by both loaders, so the tokenizer and the
# encoder are always taken from the same pretrained weights.
CHECKPOINT = "xlm-roberta-base"

# The pad token is passed explicitly as "<pad>".
tokenizer = XLMRobertaTokenizer.from_pretrained(CHECKPOINT, pad_token="<pad>")
model = XLMRobertaModel.from_pretrained(CHECKPOINT)

In [None]:
# Locations of the training CSV and the final test-pairs CSV, relative to the notebook.
TRAIN_CSV, TEST_CSV = "../../data/train/train.csv", "../../data/test/final_test_pairs.csv"

# Read both files into DataFrames.
train_data, test_data = pd.read_csv(TRAIN_CSV), pd.read_csv(TEST_CSV)

# Show which columns the training frame provides.
print(train_data.columns)

In [None]:
def pad_embedding(text: Tensor, target_length: int) -> Tensor:
    """
    Pad a batch along its sequence dimension (dim 1) with ones up to ``target_length``.

    The input is copied first, so the caller's tensor is never modified. Works for
    2-D id tensors ``(batch, seq)`` as well as higher-rank tensors such as
    ``(batch, seq, hidden)``; in every case only dim 1 is padded.

    Args:
        text (Tensor): The input tensor (or nested list) with at least two
            dimensions; dim 1 is treated as the sequence dimension.
        target_length (int): The desired size of dim 1 after padding. If the
            input is already longer than this, the surplus positions at the end
            of dim 1 are cropped (negative padding).

    Returns:
        Tensor: A new tensor whose dim 1 has size ``target_length``.

    Raises:
        ValueError: If the input has fewer than two dimensions.
    """
    # torch.tensor(<Tensor>) emits a copy-construct warning; detach().clone() is
    # the equivalent spelling PyTorch recommends. Non-tensor input (e.g. nested
    # lists) is still converted with torch.tensor.
    if isinstance(text, Tensor):
        text = text.detach().clone()
    else:
        text = torch.tensor(text)

    if text.dim() < 2:
        raise ValueError(
            f"pad_embedding expects at least 2 dimensions (batch, seq, ...), got {text.dim()}"
        )

    missing = target_length - text.shape[1]

    # F.pad takes (left, right) pairs starting from the LAST dimension, so every
    # trailing dimension after the sequence axis needs an explicit (0, 0) pair;
    # otherwise a (batch, seq, hidden) tensor would be padded along `hidden`.
    pad_spec = (0, 0) * (text.dim() - 2) + (0, missing)

    # add ones to the end of the sequence (1 is the padding token)
    return torch.nn.functional.pad(text, pad_spec, 'constant', 1)

def tokenize_and_shorten_sentence(
    text: str,
    max_length: int = 256,
    head_tokens: int = 200,
    tail_tokens: int = 55,
) -> Tensor:
    """
    Tokenize the input text and bound its length by head+tail truncation or right padding.

    The text is tokenized once without padding or truncation. If it has more than
    ``max_length`` ids, the first ``head_tokens`` ids and the last ``tail_tokens``
    ids are concatenated; otherwise the ids are right-padded with the tokenizer's
    pad id up to ``max_length``.

    Note that with the defaults a truncated sequence has 200 + 55 = 255 ids, one
    fewer than a padded sequence (256). This is kept on purpose: the caller in this
    notebook joins two results plus a separator, and the model is limited to 512
    positions.

    Args:
        text (str): The input text.
        max_length (int): Length threshold for truncation and target length for
            padding. Defaults to 256.
        head_tokens (int): Number of ids kept from the start of an over-long
            sequence. Defaults to 200.
        tail_tokens (int): Number of ids kept from the end of an over-long
            sequence. Defaults to 55.

    Returns:
        Tensor: A ``(1, L)`` tensor of token ids, where ``L`` is ``max_length`` for
        padded input and ``head_tokens + tail_tokens`` for truncated input.

    Raises:
        ValueError: If ``head_tokens`` or ``tail_tokens`` is negative, or if together
            they exceed ``max_length``.
    """
    if head_tokens < 0 or tail_tokens < 0 or head_tokens + tail_tokens > max_length:
        raise ValueError(
            "head_tokens and tail_tokens must be non-negative and sum to at most max_length"
        )

    # padding=False and truncation=False to get the exact length of the text
    input_ids = tokenizer(text, return_tensors="pt", padding=False, truncation=False)["input_ids"]
    seq_len = input_ids.shape[1]

    if seq_len > max_length:
        # TODO: decide where to truncate the text, meaning how many tokens to keep from the head and how many from the tail
        # note: the model has a max length of 512 tokens and we need to keep the [CLS] ? and [SEP] tokens
        head = input_ids[:, :head_tokens]
        # Explicit start index instead of a negative slice so tail_tokens == 0
        # yields an empty tail rather than the whole sequence.
        tail = input_ids[:, seq_len - tail_tokens:]
        return torch.cat((head, tail), dim=1)

    # Pad the ids we already have instead of tokenizing the same text a second
    # time. Assumes the tokenizer pads on the right with pad_token_id, which is
    # what padding="max_length" did here — TODO confirm padding_side is "right".
    return torch.nn.functional.pad(
        input_ids, (0, max_length - seq_len), "constant", tokenizer.pad_token_id
    )

In [None]:
NUM = 1 #note: len(train_data)

# Never process more rows than the frame holds; otherwise the surplus rows of
# the torch.empty buffer below would stay uninitialised.
num_pairs = min(NUM, len(train_data))

# Read the embedding width from the loaded model instead of hard-coding 768.
hidden_size = model.config.hidden_size

#train_embeddings = torch.empty((NUM, 512, 768)) # note: if we want to keep the whole sequence's embeddings
train_embeddings = torch.empty((num_pairs, hidden_size)) # note: take into account the [CLS] token only as it represents the whole sentence

# Separator id shaped (1, 1) so it can be concatenated between the two id tensors.
sep_token = torch.tensor([[tokenizer.sep_token_id]])

for i in range(num_pairs):
    # Positional access, matching the positional row count above, so a
    # non-default DataFrame index cannot pick the wrong row or raise KeyError.
    text1 = train_data["text1"].iloc[i]
    text2 = train_data["text2"].iloc[i]
    text1_tokenized = tokenize_and_shorten_sentence(text1)
    text2_tokenized = tokenize_and_shorten_sentence(text2)

    input_ids = torch.cat((text1_tokenized, sep_token, text2_tokenized), dim=1)

    # Mask out padding: the shortened inputs contain pad ids (also in the middle
    # of the joined sequence), and an all-ones mask would let the model attend
    # to them.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    text = {"input_ids": input_ids, "attention_mask": attention_mask}

    # NOTE(review): each half was tokenized on its own, so the joined sequence
    # presumably carries the special tokens of both halves plus the extra
    # separator — confirm this layout is what the downstream model should see.
    with torch.no_grad():
        outputs = model(**text)

    # Per-token hidden states, shape (1, seq, hidden). The earlier
    # pad_embedding(embeddings, 512) call was dropped: its result was never used
    # and it padded/cropped the hidden axis rather than the sequence axis.
    embeddings = outputs.last_hidden_state

    # Use the hidden state of the first token as the sentence-pair embedding.
    # pooler_output would add a dense + tanh head on top of it, which is
    # presumably not pretrained for this checkpoint — TODO confirm against the
    # "newly initialized weights" warning printed when the model is loaded.
    cls_token_embedding = embeddings[:, 0, :]
    train_embeddings[i] = cls_token_embedding.squeeze(0)