In [None]:
!pip install transformers
!pip install plotly
!pip install cufflinks

In [None]:
import gym
import numpy as np
from gym import spaces
from transformers import BertTokenizerFast
import pandas as pd
import torch
from transformers import BertTokenizerFast, DistilBertForSequenceClassification
from torch.distributions import Categorical
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
class LabelingEnv(gym.Env):
  """Episodic environment that presents one sentence per step for binary labelling.

  At each step the agent chooses a label (action 0 or 1) for the current
  sentence and receives +1 if it equals the gold label, -1 otherwise. The
  episode ends after the last sentence has been labelled.

  Observations are dicts ``{'input_ids', 'attention_mask'}`` of ``(1, max_length)``
  PyTorch tensors (one sentence, padded/truncated to ``max_length``).

  Note: ``step`` returns a 3-tuple ``(next_state, reward, done)`` rather than the
  classic gym 4-tuple, because the training loop in this notebook unpacks three values.

  Args:
    instances: sequence of sentences (str).
    labels: sequence of gold labels, same length as ``instances``.
    max_length: token length every observation is padded/truncated to.
    tokenizer: optional pre-loaded tokenizer; when omitted,
      'bert-base-uncased' is loaded (the original behavior).
  """

  def __init__(self, instances, labels, max_length=128, tokenizer=None):
    super(LabelingEnv, self).__init__()
    if len(instances) != len(labels):
      raise ValueError(f'got {len(instances)} instances but {len(labels)} labels')
    self.instances = instances
    self.labels = labels
    self.max_length = max_length
    self.current_instance = 0
    # Passing a tokenizer in avoids re-loading it for every environment (e.g. once per fold).
    self.tokenizer = tokenizer if tokenizer is not None else BertTokenizerFast.from_pretrained('bert-base-uncased')

    # Two actions: the binary label predicted for the current sentence.
    self.action_space = spaces.Discrete(2)
    # Describe the dict of int token tensors that reset()/step() actually return.
    self.observation_space = spaces.Dict({
        'input_ids': spaces.Box(low=0, high=len(self.tokenizer) - 1, shape=(1, max_length), dtype=np.int64),
        'attention_mask': spaces.Box(low=0, high=1, shape=(1, max_length), dtype=np.int64),
    })

  def _encode(self, index):
    """Tokenize sentence ``index`` into an observation dict (padded/truncated to max_length)."""
    encoded = self.tokenizer([self.instances[index]], return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length, return_token_type_ids=False)
    return {'input_ids': encoded['input_ids'], 'attention_mask': encoded['attention_mask']}

  def step(self, action):
    """Score ``action`` against the current gold label and advance to the next sentence.

    Returns:
      (next_state, reward, done): next_state is None once the last sentence
      has been labelled; reward is +1 for a correct label and -1 otherwise.

    Raises:
      RuntimeError: if called after the episode has finished (call reset() first).
    """
    if self.current_instance >= len(self.instances):
      raise RuntimeError('step() called after the episode finished; call reset() first')
    reward = 1 if action == self.labels[self.current_instance] else -1
    self.current_instance += 1
    done = self.current_instance == len(self.instances)
    next_state = None if done else self._encode(self.current_instance)
    return next_state, reward, done

  def reset(self):
    """Rewind to the first sentence and return its observation.

    Raises:
      ValueError: if the environment holds no sentences.
    """
    if len(self.instances) == 0:
      raise ValueError('LabelingEnv has no instances to label')
    self.current_instance = 0
    return self._encode(self.current_instance)

In [None]:
# Make Google Drive visible to the Colab runtime; the dataset and saved model paths used below live under it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

  self._read_thread.setDaemon(True)


Mounted at /content/drive


In [None]:
# Load a DistilBERT sequence classifier (2 labels: Slang, No Slang) with its
# transformer body frozen, so only the classification head receives updates.
# NOTE(review): the K-fold cell re-creates `model`, `optimizer` and `env` for every
# fold; the objects built here are only used if that cell is skipped.
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
for param in model.base_model.parameters():
    param.requires_grad = False
# Loaded with the BERT tokenizer class; DistilBERT's forward() takes no token_type_ids,
# so call sites must pass return_token_type_ids=False.
tokenizer = BertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Set up the optimizer over the parameters that can actually change (the classification head).
optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=1e-5)

TRAIN_CSV = '/content/drive/MyDrive/BERT Models/Dataset/unbiasedDataTrain.csv'
df = (
    pd.read_csv(TRAIN_CSV)
    # Rows without a sentence cannot be tokenized; rows without a label would always be scored -1.
    .dropna(subset=['sentence', 'label'])
    .drop_duplicates(subset=['sentence'])
    .astype({'label': int})
    .reset_index(drop=True)
)

instances = df['sentence'].tolist()
labels = df['label'].tolist()
# The reward (action == label), the 2-way head and binary precision/recall all assume 0/1 labels.
assert set(labels) <= {0, 1}, f'expected 0/1 labels, found {sorted(set(labels))}'

# Custom environment over the full (deduplicated) training set.
env = LabelingEnv(instances, labels)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.we

In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 5-fold cross-validation of REINFORCE training of the DistilBERT classification head:
# at each step a label is sampled from the policy, LabelingEnv returns +1/-1,
# and the head is updated with loss = -log pi(action | sentence) * reward.
N_SPLITS = 5
N_EPOCHS = 100         # full passes over the training split per fold
LEARNING_RATE = 1e-5
EVAL_BATCH_SIZE = 32
SEED = 42

kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def _new_frozen_model():
    """Return a fresh DistilBERT classifier on `device` with its transformer body frozen."""
    fold_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')  # 2 labels: Slang, No Slang
    for param in fold_model.base_model.parameters():
        param.requires_grad = False
    return fold_model.to(device)


def _run_episode(policy, episode_env, opt, progress):
    """Play one pass over `episode_env`, applying a REINFORCE update after every step.

    Returns the list of per-step rewards (+1 / -1).
    """
    rewards = []
    state = episode_env.reset()
    done = False
    while not done:  # state is only None once done is True
        state = {k: v.to(device) for k, v in state.items()}
        logits = policy(**state).logits
        # Categorical(logits=...) normalizes internally, which is numerically
        # safer than softmax followed by log in log_prob.
        dist = Categorical(logits=logits[0])
        action = dist.sample()

        state, reward, done = episode_env.step(action.item())
        rewards.append(reward)

        loss = -dist.log_prob(action) * reward
        loss.backward()
        opt.step()
        opt.zero_grad()
        progress.update(1)
    return rewards


def _predict(policy, texts):
    """Greedy (argmax) labels for `texts`, computed in batches with `policy` in eval mode."""
    policy.eval()
    predictions = []
    with torch.no_grad():
        for start in range(0, len(texts), EVAL_BATCH_SIZE):
            batch = texts[start:start + EVAL_BATCH_SIZE]
            encoded = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True, max_length=128, return_token_type_ids=False)
            encoded = {k: v.to(device) for k, v in encoded.items()}
            predictions.extend(policy(**encoded).logits.argmax(dim=-1).tolist())
    return predictions


all_rewards = []   # one list of per-epoch total rewards per completed fold
fold_metrics = []

for fold, (train_index, test_index) in enumerate(kf.split(instances)):
    print(f'Starting Fold {fold+1}...')
    train_instances = [instances[i] for i in train_index]
    train_labels = [labels[i] for i in train_index]
    test_instances = [instances[i] for i in test_index]
    test_labels = [labels[i] for i in test_index]

    # Seed per fold so the freshly initialized head and the sampled actions are reproducible.
    torch.manual_seed(SEED + fold)
    model = _new_frozen_model()
    optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=LEARNING_RATE)

    # Environment over this fold's training split.
    env = LabelingEnv(train_instances, train_labels)

    model.train()
    fold_rewards = []
    for epoch in tqdm(range(N_EPOCHS), desc='Epochs'):
        with tqdm(total=len(env.instances), desc=f'Epoch {epoch + 1}', leave=False) as pbar:
            epoch_rewards = _run_episode(model, env, optimizer, pbar)
        fold_rewards.append(np.sum(epoch_rewards))
        # tqdm.write keeps the log line from garbling the active progress bars.
        tqdm.write(f'Epoch {epoch + 1}: Total rewards {np.sum(epoch_rewards)}')
    # Record the learning curve before validation so it is kept even if validation fails.
    all_rewards.append(fold_rewards)

    print(f'Validating on Fold {fold+1}...')
    preds = _predict(model, test_instances)
    accuracy = accuracy_score(test_labels, preds)
    # zero_division=0: a fold where the model never predicts the positive class reports 0 instead of warning.
    precision = precision_score(test_labels, preds, zero_division=0)
    recall = recall_score(test_labels, preds, zero_division=0)
    f1 = f1_score(test_labels, preds, zero_division=0)
    fold_metrics.append({'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1})

    print(f'Validation results for Fold {fold+1}: Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1-score: {f1}\n')

metrics_df = pd.DataFrame(fold_metrics)
print('Mean validation metrics over folds:', metrics_df.mean().round(4).to_dict())


Starting Fold 1...


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.we


Epoch 1: Total rewards 34



Epoch 2:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 2:   2%|▏         | 12/640 [00:00<00:05, 112.13it/s][A
Epoch 2:   4%|▍         | 24/640 [00:00<00:05, 109.13it/s][A
Epoch 2:   6%|▌         | 36/640 [00:00<00:05, 110.35it/s][A
Epoch 2:   8%|▊         | 48/640 [00:00<00:05, 111.70it/s][A
Epoch 2:   9%|▉         | 60/640 [00:00<00:05, 110.99it/s][A
Epoch 2:  11%|█▏        | 72/640 [00:00<00:05, 113.01it/s][A
Epoch 2:  13%|█▎        | 84/640 [00:00<00:04, 113.70it/s][A
Epoch 2:  15%|█▌        | 96/640 [00:00<00:04, 113.91it/s][A
Epoch 2:  17%|█▋        | 108/640 [00:00<00:04, 115.56it/s][A
Epoch 2:  19%|█▉        | 120/640 [00:01<00:04, 112.67it/s][A
Epoch 2:  21%|██        | 132/640 [00:01<00:04, 113.99it/s][A
Epoch 2:  22%|██▎       | 144/640 [00:01<00:04, 113.72it/s][A
Epoch 2:  24%|██▍       | 156/640 [00:01<00:04, 113.89it/s][A
Epoch 2:  26%|██▋       | 168/640 [00:01<00:04, 114.05it/s][A
Epoch 2:  28%|██▊       | 180/640 [00:01<00:04, 114.23it/s][A
Epoch 2:  


Epoch 2: Total rewards -12



Epoch 3:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 3:   2%|▏         | 13/640 [00:00<00:04, 127.39it/s][A
Epoch 3:   4%|▍         | 26/640 [00:00<00:05, 120.14it/s][A
Epoch 3:   6%|▌         | 39/640 [00:00<00:05, 114.58it/s][A
Epoch 3:   8%|▊         | 51/640 [00:00<00:05, 114.27it/s][A
Epoch 3:  10%|▉         | 63/640 [00:00<00:05, 114.52it/s][A
Epoch 3:  12%|█▏        | 75/640 [00:00<00:04, 113.45it/s][A
Epoch 3:  14%|█▎        | 87/640 [00:00<00:04, 111.01it/s][A
Epoch 3:  15%|█▌        | 99/640 [00:00<00:04, 108.60it/s][A
Epoch 3:  17%|█▋        | 111/640 [00:00<00:04, 111.84it/s][A
Epoch 3:  19%|█▉        | 123/640 [00:01<00:04, 113.37it/s][A
Epoch 3:  21%|██        | 135/640 [00:01<00:04, 112.88it/s][A
Epoch 3:  23%|██▎       | 147/640 [00:01<00:04, 110.01it/s][A
Epoch 3:  25%|██▍       | 159/640 [00:01<00:04, 109.07it/s][A
Epoch 3:  27%|██▋       | 171/640 [00:01<00:04, 111.14it/s][A
Epoch 3:  29%|██▊       | 183/640 [00:01<00:04, 112.16it/s][A
Epoch 3:  


Epoch 3: Total rewards -28



Epoch 4:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 4:   2%|▏         | 13/640 [00:00<00:05, 124.63it/s][A
Epoch 4:   4%|▍         | 26/640 [00:00<00:05, 120.12it/s][A
Epoch 4:   6%|▌         | 39/640 [00:00<00:04, 120.52it/s][A
Epoch 4:   8%|▊         | 52/640 [00:00<00:04, 117.66it/s][A
Epoch 4:  10%|█         | 64/640 [00:00<00:04, 117.80it/s][A
Epoch 4:  12%|█▏        | 76/640 [00:00<00:04, 117.09it/s][A
Epoch 4:  14%|█▍        | 88/640 [00:00<00:04, 117.20it/s][A
Epoch 4:  16%|█▌        | 101/640 [00:00<00:04, 118.99it/s][A
Epoch 4:  18%|█▊        | 113/640 [00:00<00:04, 114.76it/s][A
Epoch 4:  20%|█▉        | 125/640 [00:01<00:04, 115.06it/s][A
Epoch 4:  21%|██▏       | 137/640 [00:01<00:04, 115.96it/s][A
Epoch 4:  23%|██▎       | 149/640 [00:01<00:04, 116.17it/s][A
Epoch 4:  25%|██▌       | 161/640 [00:01<00:04, 116.57it/s][A
Epoch 4:  27%|██▋       | 173/640 [00:01<00:04, 115.09it/s][A
Epoch 4:  29%|██▉       | 185/640 [00:01<00:03, 116.07it/s][A
Epoch 4: 


Epoch 4: Total rewards -34



Epoch 5:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 5:   1%|▏         | 8/640 [00:00<00:08, 78.59it/s][A
Epoch 5:   2%|▎         | 16/640 [00:00<00:08, 77.98it/s][A
Epoch 5:   4%|▍         | 24/640 [00:00<00:08, 74.98it/s][A
Epoch 5:   5%|▌         | 35/640 [00:00<00:06, 86.47it/s][A
Epoch 5:   7%|▋         | 47/640 [00:00<00:06, 97.82it/s][A
Epoch 5:   9%|▉         | 59/640 [00:00<00:05, 103.44it/s][A
Epoch 5:  11%|█         | 71/640 [00:00<00:05, 107.60it/s][A
Epoch 5:  13%|█▎        | 83/640 [00:00<00:05, 109.81it/s][A
Epoch 5:  15%|█▍        | 95/640 [00:00<00:04, 111.10it/s][A
Epoch 5:  17%|█▋        | 108/640 [00:01<00:04, 114.56it/s][A
Epoch 5:  19%|█▉        | 120/640 [00:01<00:04, 115.47it/s][A
Epoch 5:  21%|██        | 132/640 [00:01<00:04, 112.50it/s][A
Epoch 5:  22%|██▎       | 144/640 [00:01<00:04, 114.12it/s][A
Epoch 5:  24%|██▍       | 156/640 [00:01<00:04, 115.46it/s][A
Epoch 5:  26%|██▋       | 169/640 [00:01<00:03, 118.08it/s][A
Epoch 5:  28%|██▊


Epoch 5: Total rewards 34



Epoch 6:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 6:   2%|▏         | 12/640 [00:00<00:05, 119.04it/s][A
Epoch 6:   4%|▍         | 24/640 [00:00<00:05, 112.87it/s][A
Epoch 6:   6%|▌         | 36/640 [00:00<00:05, 114.85it/s][A
Epoch 6:   8%|▊         | 48/640 [00:00<00:05, 115.43it/s][A
Epoch 6:   9%|▉         | 60/640 [00:00<00:05, 115.58it/s][A
Epoch 6:  11%|█▏        | 72/640 [00:00<00:05, 112.97it/s][A
Epoch 6:  13%|█▎        | 84/640 [00:00<00:04, 113.58it/s][A
Epoch 6:  15%|█▌        | 97/640 [00:00<00:04, 116.52it/s][A
Epoch 6:  17%|█▋        | 109/640 [00:00<00:04, 116.53it/s][A
Epoch 6:  19%|█▉        | 121/640 [00:01<00:04, 115.69it/s][A
Epoch 6:  21%|██        | 133/640 [00:01<00:04, 115.65it/s][A
Epoch 6:  23%|██▎       | 145/640 [00:01<00:04, 115.55it/s][A
Epoch 6:  25%|██▍       | 157/640 [00:01<00:04, 115.91it/s][A
Epoch 6:  26%|██▋       | 169/640 [00:01<00:04, 115.07it/s][A
Epoch 6:  28%|██▊       | 181/640 [00:01<00:04, 112.65it/s][A
Epoch 6:  


Epoch 6: Total rewards -28



Epoch 7:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 7:   1%|▏         | 9/640 [00:00<00:07, 89.71it/s][A
Epoch 7:   3%|▎         | 18/640 [00:00<00:07, 86.21it/s][A
Epoch 7:   4%|▍         | 27/640 [00:00<00:06, 87.72it/s][A
Epoch 7:   6%|▌         | 37/640 [00:00<00:06, 89.15it/s][A
Epoch 7:   7%|▋         | 47/640 [00:00<00:06, 91.05it/s][A
Epoch 7:   9%|▉         | 57/640 [00:00<00:06, 90.14it/s][A
Epoch 7:  10%|█         | 67/640 [00:00<00:06, 91.02it/s][A
Epoch 7:  12%|█▏        | 77/640 [00:00<00:06, 89.79it/s][A
Epoch 7:  14%|█▎        | 87/640 [00:00<00:06, 90.13it/s][A
Epoch 7:  15%|█▌        | 97/640 [00:01<00:06, 89.02it/s][A
Epoch 7:  17%|█▋        | 106/640 [00:01<00:06, 84.44it/s][A
Epoch 7:  18%|█▊        | 115/640 [00:01<00:06, 82.36it/s][A
Epoch 7:  19%|█▉        | 124/640 [00:01<00:06, 82.04it/s][A
Epoch 7:  21%|██        | 133/640 [00:01<00:06, 83.70it/s][A
Epoch 7:  22%|██▏       | 142/640 [00:01<00:05, 85.42it/s][A
Epoch 7:  24%|██▍       | 15


Epoch 7: Total rewards 14



Epoch 8:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 8:   2%|▏         | 12/640 [00:00<00:05, 118.30it/s][A
Epoch 8:   4%|▍         | 24/640 [00:00<00:05, 117.60it/s][A
Epoch 8:   6%|▌         | 36/640 [00:00<00:05, 116.39it/s][A
Epoch 8:   8%|▊         | 48/640 [00:00<00:05, 114.11it/s][A
Epoch 8:   9%|▉         | 60/640 [00:00<00:05, 113.36it/s][A
Epoch 8:  11%|█▏        | 72/640 [00:00<00:04, 114.77it/s][A
Epoch 8:  13%|█▎        | 84/640 [00:00<00:04, 114.60it/s][A
Epoch 8:  15%|█▌        | 97/640 [00:00<00:04, 117.14it/s][A
Epoch 8:  17%|█▋        | 109/640 [00:00<00:04, 116.69it/s][A
Epoch 8:  19%|█▉        | 121/640 [00:01<00:04, 115.84it/s][A
Epoch 8:  21%|██        | 134/640 [00:01<00:04, 117.69it/s][A
Epoch 8:  23%|██▎       | 147/640 [00:01<00:04, 118.98it/s][A
Epoch 8:  25%|██▍       | 159/640 [00:01<00:04, 115.22it/s][A
Epoch 8:  27%|██▋       | 171/640 [00:01<00:04, 112.53it/s][A
Epoch 8:  29%|██▊       | 183/640 [00:01<00:04, 114.02it/s][A
Epoch 8:  


Epoch 8: Total rewards 72



Epoch 9:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 9:   2%|▏         | 12/640 [00:00<00:05, 116.69it/s][A
Epoch 9:   4%|▍         | 24/640 [00:00<00:05, 116.37it/s][A
Epoch 9:   6%|▌         | 36/640 [00:00<00:05, 114.90it/s][A
Epoch 9:   8%|▊         | 48/640 [00:00<00:05, 114.11it/s][A
Epoch 9:   9%|▉         | 60/640 [00:00<00:05, 115.98it/s][A
Epoch 9:  11%|█▏        | 72/640 [00:00<00:04, 114.27it/s][A
Epoch 9:  13%|█▎        | 84/640 [00:00<00:04, 114.47it/s][A
Epoch 9:  15%|█▌        | 96/640 [00:00<00:04, 114.37it/s][A
Epoch 9:  17%|█▋        | 108/640 [00:00<00:04, 110.65it/s][A
Epoch 9:  19%|█▉        | 120/640 [00:01<00:04, 112.40it/s][A
Epoch 9:  21%|██        | 132/640 [00:01<00:04, 113.99it/s][A
Epoch 9:  23%|██▎       | 145/640 [00:01<00:04, 116.49it/s][A
Epoch 9:  25%|██▍       | 157/640 [00:01<00:04, 114.95it/s][A
Epoch 9:  26%|██▋       | 169/640 [00:01<00:04, 114.63it/s][A
Epoch 9:  28%|██▊       | 181/640 [00:01<00:04, 114.07it/s][A
Epoch 9:  


Epoch 9: Total rewards 28



Epoch 10:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 10:   2%|▏         | 12/640 [00:00<00:05, 112.44it/s][A
Epoch 10:   4%|▍         | 24/640 [00:00<00:05, 111.45it/s][A
Epoch 10:   6%|▌         | 36/640 [00:00<00:05, 110.67it/s][A
Epoch 10:   8%|▊         | 48/640 [00:00<00:05, 110.28it/s][A
Epoch 10:   9%|▉         | 60/640 [00:00<00:05, 112.42it/s][A
Epoch 10:  11%|█▏        | 72/640 [00:00<00:05, 113.29it/s][A
Epoch 10:  13%|█▎        | 84/640 [00:00<00:04, 111.46it/s][A
Epoch 10:  15%|█▌        | 96/640 [00:00<00:04, 113.65it/s][A
Epoch 10:  17%|█▋        | 108/640 [00:00<00:04, 108.40it/s][A
Epoch 10:  19%|█▊        | 119/640 [00:01<00:04, 106.52it/s][A
Epoch 10:  20%|██        | 131/640 [00:01<00:04, 107.95it/s][A
Epoch 10:  22%|██▏       | 143/640 [00:01<00:04, 108.88it/s][A
Epoch 10:  24%|██▍       | 155/640 [00:01<00:04, 109.56it/s][A
Epoch 10:  26%|██▌       | 166/640 [00:01<00:04, 108.42it/s][A
Epoch 10:  28%|██▊       | 177/640 [00:01<00:04, 108.37it/


Epoch 10: Total rewards 34



Epoch 11:   0%|          | 0/640 [00:00<?, ?it/s][A
Epoch 11:   2%|▏         | 12/640 [00:00<00:05, 113.17it/s][A
Epoch 11:   4%|▍         | 24/640 [00:00<00:05, 109.45it/s][A
Epoch 11:   5%|▌         | 35/640 [00:00<00:05, 109.41it/s][A
Epoch 11:   7%|▋         | 47/640 [00:00<00:05, 110.53it/s][A
Epoch 11:   9%|▉         | 59/640 [00:00<00:05, 110.28it/s][A
Epoch 11:  11%|█         | 71/640 [00:00<00:05, 111.61it/s][A
Epoch 11:  13%|█▎        | 83/640 [00:00<00:04, 113.14it/s][A
Epoch 11:  15%|█▍        | 95/640 [00:00<00:04, 111.48it/s][A
Epoch 11:  17%|█▋        | 107/640 [00:00<00:04, 108.54it/s][A
Epoch 11:  18%|█▊        | 118/640 [00:01<00:04, 107.82it/s][A
Epoch 11:  20%|██        | 129/640 [00:01<00:04, 107.36it/s][A
Epoch 11:  22%|██▏       | 140/640 [00:01<00:04, 106.14it/s][A
Epoch 11:  24%|██▎       | 151/640 [00:01<00:04, 104.75it/s][A
Epoch 11:  25%|██▌       | 162/640 [00:01<00:04, 101.87it/s][A
Epoch 11:  27%|██▋       | 173/640 [00:01<00:04, 103.38it/

KeyboardInterrupt: ignored

In [None]:
import plotly.graph_objects as go
import cufflinks as cf  # NOTE(review): not used in this cell
import pandas as pd

ROLLING_WINDOW = 10  # number of epochs averaged into each plotted point

# all_rewards holds one list of per-epoch total rewards per *completed* fold.
all_rewards_df = pd.DataFrame(all_rewards).T  # rows = epochs, one column per fold

moving_avg_rewards = all_rewards_df.rolling(window=ROLLING_WINDOW).mean()

fig = go.Figure()
if all_rewards_df.empty:
    print('No completed folds to plot yet.')
# Iterate over the folds actually present: training can be interrupted before all folds finish.
for fold in moving_avg_rewards.columns:
    fig.add_trace(go.Scatter(x=list(moving_avg_rewards.index + 1),  # 1-based, matching the "Epoch N" logs
                             y=moving_avg_rewards[fold],
                             mode='lines',
                             name=f'Fold {fold+1}'))

fig.update_layout(title='Moving Average Rewards per Epoch for each fold',
                   xaxis_title='Epoch',
                   yaxis_title='Moving Average Rewards')

fig.show()

In [None]:
import pickle

# Persist the current model, its tokenizer and the full training corpus under one Drive folder.
# NOTE(review): `model` is whatever the K-fold cell last assigned (the last fold's model,
# trained on that fold's training split only, and possibly interrupted mid-training).
SAVE_DIR = '/content/drive/MyDrive/BERT Models/BERT RL'

model.save_pretrained(f'{SAVE_DIR}/model')
tokenizer.save_pretrained(f'{SAVE_DIR}/tokenizer')

for artifact_name, artifact in (('instances', instances), ('labels', labels)):
    with open(f'{SAVE_DIR}/{artifact_name}', 'wb') as f:
        pickle.dump(artifact, f)


In [None]:
# Evaluate the saved model on the held-out test CSV.
# Imports are repeated so this cell runs on a fresh kernel (after mounting Drive)
# without depending on the training cells above.
from transformers import BertTokenizerFast, DistilBertForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
import pandas as pd
import torch

model_dir = '/content/drive/MyDrive/BERT Models/BERT RL/model'
tokenizer_dir = '/content/drive/MyDrive/BERT Models/BERT RL/tokenizer'
TEST_BATCH_SIZE = 32

model = DistilBertForSequenceClassification.from_pretrained(model_dir)
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_dir)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

test_df = (
    pd.read_csv('/content/drive/MyDrive/BERT Models/Dataset/unbiasedDataTest.csv')
    # Rows without a sentence cannot be tokenized; rows without a label cannot be scored.
    .dropna(subset=['sentence', 'label'])
    .astype({'label': int})
)
test_instances = test_df['sentence'].tolist()
test_labels = test_df['label'].tolist()

model.eval()
preds = []
with torch.no_grad():
    for start in range(0, len(test_instances), TEST_BATCH_SIZE):
        batch = test_instances[start:start + TEST_BATCH_SIZE]
        # BertTokenizerFast emits token_type_ids by default, but DistilBERT's forward()
        # does not accept them, so they must be suppressed (as in the training cell).
        encoded = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True, max_length=128, return_token_type_ids=False)
        encoded = {k: v.to(device) for k, v in encoded.items()}
        preds.extend(model(**encoded).logits.argmax(dim=-1).tolist())
assert len(preds) == len(test_labels)

accuracy = accuracy_score(test_labels, preds)
# zero_division=0: report 0 rather than warn if the positive class is never predicted.
precision = precision_score(test_labels, preds, zero_division=0)
recall = recall_score(test_labels, preds, zero_division=0)
f1 = f1_score(test_labels, preds, zero_division=0)

print(f'Test results: Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1-score: {f1}\n')