In [1]:
import sys
import os
from logging import getLogger, ERROR
import time 

# Working directory of the notebook and its parent directory.
notebook_dir = os.getcwd()
parent_dir = os.path.dirname(notebook_dir)

# Make packages living in the parent directory (e.g. `model`) importable; the
# membership check keeps sys.path from growing when this cell is re-run.
needs_path_entry = parent_dir not in sys.path
if needs_path_entry:
    sys.path.append(parent_dir)

# Only surface errors from Hugging Face's model-loading module.
hf_logger = getLogger("transformers.modeling_utils")
hf_logger.setLevel(ERROR)

In [2]:
from model.qgpt2_models import QGPT2ClassificationModel
import torch
from pandas import read_csv
from datasets import Dataset
from sklearn.metrics import f1_score
from sklearn.model_selection import GridSearchCV
import numpy
from transformers import GPT2Model, GPT2Tokenizer

# Checkpoint directory shared by both models and the number of sentiment classes.
SAVED_MODEL_DIR = "./saved_model"
N_CLASSES = 3

# GPT-2 classification model from the local `model.qgpt2_models` package,
# loaded with n_bits=8 and use_cache=False.
fhe_model = QGPT2ClassificationModel.from_pretrained(
    SAVED_MODEL_DIR,
    n_bits=8,
    use_cache=False,
    num_labels=N_CLASSES,
)

# Plain GPT-2 backbone from the same checkpoint; its hidden states are the
# features used later in this notebook.
gpt2_model = GPT2Model.from_pretrained(SAVED_MODEL_DIR, num_labels=N_CLASSES)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def tokenize_function(examples):
    """Tokenize the "text" entries of a batch, padding/truncating each to 128 tokens.

    Uses the module-level ``tokenizer``.
    """
    encode_kwargs = dict(padding="max_length", truncation=True, max_length=128)
    return tokenizer(examples["text"], **encode_kwargs)

# Reuse the EOS token as the padding token: first on the tokenizer, then on the
# config of each model (GPT-2 backbone first, then the FHE model).
tokenizer.pad_token = tokenizer.eos_token
for _model in (gpt2_model, fhe_model):
    _model.config.pad_token_id = _model.config.eos_token_id

In [4]:
# Integer class id for each sentiment string (order defines the class ids).
SENTIMENT_TO_LABEL = {"negative": 0, "neutral": 1, "positive": 2}

df = read_csv("../data/Tweets.csv")
# `map` (rather than list-based `replace`) avoids pandas' deprecated implicit
# downcasting and turns any unexpected label into NaN, which we reject below
# instead of silently keeping a string label in an integer column.
df["airline_sentiment"] = df["airline_sentiment"].map(SENTIMENT_TO_LABEL)
assert df["airline_sentiment"].notna().all(), "unexpected airline_sentiment value"
df["airline_sentiment"] = df["airline_sentiment"].astype("int64")

# Keep only the text and its label, named the way downstream code expects.
dataset = Dataset.from_pandas(df)
dataset = dataset.select_columns(["text", "airline_sentiment"])
dataset = dataset.rename_column("airline_sentiment", "label")

# Fixed-seed 90/10 train/evaluation split.
ds_dict = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = ds_dict["train"]
eval_ds = ds_dict["test"]

In [5]:
def get_hidden_states(
    inputs: list,
    transformer_model,
    tokenizer: "GPT2Tokenizer",
    device: str = "cuda",
):
    """Represent each text by its token-averaged last hidden layer.

    Each text is tokenized and passed through ``transformer_model`` on its own.
    The last entry of the output's ``hidden_states`` is averaged over dim 1
    (the tokens axis), and the per-text results are stacked along axis 0.

    Args:
        inputs: Texts to embed.
        transformer_model: Model called as ``model(ids, output_hidden_states=True)``
            whose output exposes ``.hidden_states``; it is moved to ``device``.
        tokenizer: Tokenizer providing ``encode(text, return_tensors="pt")``
            (assumed to return a batch-of-one id tensor — TODO confirm).
        device: Torch device to run on. Defaults to ``"cuda"``, which requires
            a CUDA-enabled torch build and GPU.

    Returns:
        numpy array with one row per input text (presumably
        ``(len(inputs), hidden_size)``).

    Raises:
        ValueError: If ``inputs`` is empty.
    """
    transformer_model = transformer_model.to(device)
    output_hidden_states_list = []

    # Pure inference: no_grad avoids building an autograd graph (and keeping
    # activations for backward) on every one of the forward passes.
    with torch.no_grad():
        for text in inputs:
            tokenized_x = tokenizer.encode(text, return_tensors="pt")
            # Only keep the last hidden layer state.
            last_hidden = transformer_model(
                tokenized_x.to(device), output_hidden_states=True
            ).hidden_states[-1]
            # Average over the tokens axis to get a representation at the text level.
            text_repr = last_hidden.mean(dim=1)
            output_hidden_states_list.append(text_repr.detach().cpu().numpy())

    # numpy.concatenate would fail on an empty list with an opaque message.
    if not output_hidden_states_list:
        raise ValueError("get_hidden_states: `inputs` is empty, nothing to embed.")
    return numpy.concatenate(output_hidden_states_list, axis=0)

# Run feature extraction on the GPU when one is available, otherwise on the CPU;
# relying on get_hidden_states' "cuda" default fails on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

hidden_states = get_hidden_states(train_ds["text"], gpt2_model, tokenizer, device=DEVICE)
x_test_states = get_hidden_states(eval_ds["text"], gpt2_model, tokenizer, device=DEVICE)

In [16]:
# Save the train and evaluation feature matrices as comma-separated text files.
for csv_name, feature_matrix in (
    ("train_hidden_states.csv", hidden_states),
    ("test_hidden_states.csv", x_test_states),
):
    numpy.savetxt(csv_name, feature_matrix, delimiter=",")

In [6]:
# Report the outcome of the hyper-parameter search.
# FIXME(review): `grid_search` is not defined anywhere in this notebook, so this
# cell and the evaluation cell that uses `grid_search.best_estimator_` fail on
# Restart & Run All. GridSearchCV is imported in the imports cell but never used,
# which suggests the cell that built and fitted the search was deleted. The
# reported hyper-parameters (`max_depth`, `n_bits`, `n_estimators`) suggest a
# quantized tree-ensemble classifier over the extracted hidden states — TODO
# restore that cell.
if "grid_search" not in globals():
    raise NameError(
        "`grid_search` is not defined: restore the cell that builds and fits the "
        "hyper-parameter search before running this one."
    )

print(f"Best score: {grid_search.best_score_}")
print(f"Best hyper-parameters: {grid_search.best_params_}")

Best score: 0.6052671523982999
Best hyper-parameters: {'max_depth': 1, 'n_bits': 2, 'n_estimators': 50}


: 

In [7]:
# Extract the best estimator found by the hyper-parameter search.
best_model = grid_search.best_estimator_

# Compile on the training features: the compile input-set should be
# representative training data, not the evaluation split, so the evaluation
# data does not influence the compiled model.
best_model.compile(hidden_states)

# Time FHE inference (fhe="execute") on the evaluation features.
start = time.perf_counter()
y_proba = best_model.predict_proba(x_test_states, fhe="execute")
end = time.perf_counter()
y_test = eval_ds["label"]

# Predicted class = most probable class for each sample.
y_pred = numpy.argmax(y_proba, axis=1)

# Macro F1: unweighted mean of the per-class F1 scores.
f1 = f1_score(y_test, y_pred, average="macro")

# Class names in label-id order (0, 1, 2). Passing `labels` pins f1s[i] to
# class i even if some class is absent from both y_test and y_pred.
class_names = ("negative", "neutral", "positive")
f1s = f1_score(y_test, y_pred, labels=list(range(len(class_names))), average=None)

print(f"Run time: {end - start:.4f} seconds")
print(f"Macro F1: {f1:.4f}")
for class_name, class_f1 in zip(class_names, f1s):
    print(f"F1 score for {class_name} class: {class_f1:.4f}")