In [65]:
import torch
from transformers import RobertaTokenizer, RobertaModel
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification, XLMRobertaModel

In [66]:
# Load the pretrained XLM-R tokenizer and encoder used for all embeddings below.
# Pick the GPU when available; the model and every encoded input batch must
# live on the same device for the forward pass to work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
model = XLMRobertaModel.from_pretrained("xlm-roberta-base").to(device)
model.eval()  # inference only: make sure dropout is disabled so embeddings are deterministic

def tokenize_and_encode_batch(texts, batch_size=32, device=None):
    """Tokenize ``texts`` in fixed-size chunks and return one encoding per chunk.

    Each chunk is truncated to the tokenizer's maximum length and padded only
    to the longest sequence within that chunk (``padding=True``).

    Args:
        texts: list of strings to encode.
        batch_size: number of texts per chunk; must be a positive integer.
        device: where to place the resulting tensors. Defaults to the device
            the global ``model`` lives on, so the encodings can be fed
            straight into it.

    Returns:
        A list of tokenizer outputs (PyTorch tensors), one per chunk, in input
        order; empty when ``texts`` is empty.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if device is None:
        # Follow the model instead of relying on a free-floating global.
        device = model.device
    n_batches = (len(texts) + batch_size - 1) // batch_size  # ceiling division
    encoded_batches = []
    for batch_idx, start in enumerate(range(0, len(texts), batch_size), start=1):
        print(f"Encoding batch {batch_idx}/{n_batches}")
        batch_texts = texts[start:start + batch_size]
        encoded_batch = tokenizer(batch_texts, truncation=True, padding=True, return_tensors="pt").to(device)
        encoded_batches.append(encoded_batch)
    return encoded_batches

def get_mean_hidden_states_batch(encoded_batches):
    """Embed pre-tokenized batches by mean-pooling the model's last hidden state.

    Padding positions are excluded from the average via each batch's
    ``attention_mask``, so a text's embedding does not depend on how long the
    other texts in its batch were.

    Args:
        encoded_batches: list of tokenizer outputs (e.g. from
            ``tokenize_and_encode_batch``); each must contain ``attention_mask``.

    Returns:
        ``np.ndarray`` with one row per input text (batches stacked in order).
        An empty ``(0, hidden_size)`` array when ``encoded_batches`` is empty.
    """
    embeddings = []
    for i, inputs in enumerate(encoded_batches):
        print(f"Processing batch {i+1}/{len(encoded_batches)}")
        inputs = inputs.to(model.device)  # no-op when already on the model's device
        with torch.no_grad():
            outputs = model(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        # Broadcast the 0/1 mask over the hidden dimension and average only real tokens.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        embeddings.append((summed / counts).cpu().numpy())  # move to CPU and store
    if not embeddings:
        # np.vstack([]) raises; return a correctly-shaped empty array instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.vstack(embeddings)

def get_embeddings(df, batch_size=32, columns=("premise", "hypothesis")):
    """Build one feature row per DataFrame row from pooled sentence embeddings.

    Each column in ``columns`` is embedded independently and the per-column
    vectors are concatenated side by side. With the default columns the result
    is ``[premise_embedding | hypothesis_embedding]``.

    Args:
        df: DataFrame containing every column named in ``columns``.
        batch_size: texts per tokenizer/model batch.
        columns: text columns to embed, in output order.

    Returns:
        ``np.ndarray`` with ``len(df)`` rows; its width is the sum of the
        per-column embedding widths.

    Raises:
        ValueError: if ``columns`` is empty.
    """
    if not columns:
        raise ValueError("columns must name at least one text column")
    per_column = []
    for column in columns:
        # Tokenize and embed one column at a time so only that column's
        # encoded batches are held in memory at once.
        encoded = tokenize_and_encode_batch(df[column].tolist(), batch_size)
        per_column.append(get_mean_hidden_states_batch(encoded))
    return np.concatenate(per_column, axis=1)

In [None]:
# Raw training split; later cells read its premise, hypothesis and label columns.
train_csv_path = "data/train.csv"
data = pd.read_csv(train_csv_path)


In [18]:
# Embed the training split and extract its labels.
df = pd.DataFrame(data)
y = df["label"]  # read labels first so a missing column fails before the slow embedding pass
X = get_embeddings(df)
assert X.shape[0] == len(y), "embedding rows must align one-to-one with labels"
X.shape, y.shape

Encoding batch 1/379
Encoding batch 2/379
Encoding batch 3/379
Encoding batch 4/379
Encoding batch 5/379
Encoding batch 6/379
Encoding batch 7/379
Encoding batch 8/379
Encoding batch 9/379
Encoding batch 10/379
Encoding batch 11/379
Encoding batch 12/379
Encoding batch 13/379
Encoding batch 14/379
Encoding batch 15/379
Encoding batch 16/379
Encoding batch 17/379
Encoding batch 18/379
Encoding batch 19/379
Encoding batch 20/379
Encoding batch 21/379
Encoding batch 22/379
Encoding batch 23/379
Encoding batch 24/379
Encoding batch 25/379
Encoding batch 26/379
Encoding batch 27/379
Encoding batch 28/379
Encoding batch 29/379
Encoding batch 30/379
Encoding batch 31/379
Encoding batch 32/379
Encoding batch 33/379
Encoding batch 34/379
Encoding batch 35/379
Encoding batch 36/379
Encoding batch 37/379
Encoding batch 38/379
Encoding batch 39/379
Encoding batch 40/379
Encoding batch 41/379
Encoding batch 42/379
Encoding batch 43/379
Encoding batch 44/379
Encoding batch 45/379
Encoding batch 46/3

((12120, 1536), (12120,))

In [22]:
# Cache the training embeddings to disk, one row per example, no index column.
train_features = pd.DataFrame(X)
train_features.to_csv("data/embeddings_train_x.csv", index=False)

In [69]:
# Embed the test split with the same pipeline and cache it to disk.
# NOTE(review): the label line was commented out here, so test.csv is presumably
# unlabeled — confirm. The training `y` is intentionally not shown: it does not
# describe this split.
data = pd.read_csv("data/test.csv")
df = pd.DataFrame(data)
X = get_embeddings(df)
assert X.shape[0] == len(df), "expected one embedding row per test example"
pd.DataFrame(X).to_csv("data/embeddings_test_x.csv", index=False)
X.shape


Encoding batch 1/163
Encoding batch 2/163
Encoding batch 3/163
Encoding batch 4/163
Encoding batch 5/163
Encoding batch 6/163
Encoding batch 7/163
Encoding batch 8/163
Encoding batch 9/163
Encoding batch 10/163
Encoding batch 11/163
Encoding batch 12/163
Encoding batch 13/163
Encoding batch 14/163
Encoding batch 15/163
Encoding batch 16/163
Encoding batch 17/163
Encoding batch 18/163
Encoding batch 19/163
Encoding batch 20/163
Encoding batch 21/163
Encoding batch 22/163
Encoding batch 23/163
Encoding batch 24/163
Encoding batch 25/163
Encoding batch 26/163
Encoding batch 27/163
Encoding batch 28/163
Encoding batch 29/163
Encoding batch 30/163
Encoding batch 31/163
Encoding batch 32/163
Encoding batch 33/163
Encoding batch 34/163
Encoding batch 35/163
Encoding batch 36/163
Encoding batch 37/163
Encoding batch 38/163
Encoding batch 39/163
Encoding batch 40/163
Encoding batch 41/163
Encoding batch 42/163
Encoding batch 43/163
Encoding batch 44/163
Encoding batch 45/163
Encoding batch 46/1