<a href="https://colab.research.google.com/github/HamdanXI/nlp_adventure/blob/main/bert-base-uncased-paradetox-with-labels-with-RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
from datasets import load_dataset

# Load the dataset by its repository id; later cells read the "train" split.
DATASET_ID = "HamdanXI/paradetox_with_labels"
dataset = load_dataset(DATASET_ID)

In [30]:
# Tokenization
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch.

    Every sequence is truncated and padded to exactly 128 tokens so all
    rows end up with the same length.
    """
    texts = examples['text']
    return tokenizer(
        texts,
        max_length=128,
        truncation=True,
        padding='max_length',
    )


# batched=True hands tokenize_function a batch of rows per call.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [31]:
# Splitting dataset
from datasets import DatasetDict

# Hold out 10% of the tokenized "train" split as a test set.
split = tokenized_datasets["train"].train_test_split(test_size=0.1)
dataset_split = DatasetDict({part: split[part] for part in ("train", "test")})

In [32]:
import gym
from gym import spaces

class TextClassificationEnv(gym.Env):
    """Gym environment that walks the "train" split one example at a time.

    An episode starts at the first example and ends once every example has
    been visited. Each step compares the action with the current example's
    ``label`` and pays +1 on a match, -1 otherwise.

    Observations are ``{"text": ...}`` dicts holding the raw example text.
    NOTE: ``observation_space`` is declared as a 768-dimensional Box, which
    does not describe the dict observations actually returned.
    """

    def __init__(self, dataset):
        """Store the ``"train"`` split of ``dataset`` and declare the spaces."""
        super().__init__()
        self.dataset = dataset["train"]
        self.current_index = 0
        # Two discrete actions: toxic or neutral.
        self.action_space = spaces.Discrete(2)
        # Placeholder sized like BERT's hidden state dimension.
        self.observation_space = spaces.Box(low=0, high=1, shape=(768,))

    def reset(self):
        """Rewind to the first example and return its observation."""
        self.current_index = 0
        first_example = self.dataset[self.current_index]
        return {"text": first_example['text']}

    def step(self, action):
        """Score ``action`` against the current label and advance one example.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is ``None`` once the
            last example has been consumed and ``info`` is an empty dict.
        """
        expected = self.dataset[self.current_index]['label']
        if action == expected:
            reward = 1
        else:
            reward = -1

        self.current_index += 1
        finished = self.current_index >= len(self.dataset)
        if finished:
            return None, reward, finished, {}

        next_obs = {"text": self.dataset[self.current_index]['text']}
        return next_obs, reward, finished, {}

    def render(self, mode='human'):
        """Rendering is not implemented; this is a no-op."""
        pass

In [33]:
import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, AdamW

class BERTAgent:
    """Policy-gradient agent whose policy is a BERT sequence classifier.

    The classifier's softmax output is treated as a distribution over
    actions: ``predict`` samples an action from it and ``optimize`` takes
    one reward-weighted log-likelihood gradient step for a single
    (state, action, reward) triple.

    States are dicts with a ``"text"`` entry, as produced by the environment.
    """

    # Sequences are padded/truncated to this many tokens before the forward pass.
    MAX_LENGTH = 128

    def __init__(self, model_path, tokenizer, device="cuda"):
        """Load the classifier and set up its optimizer.

        Args:
            model_path: Identifier or path passed to
                ``BertForSequenceClassification.from_pretrained``.
            tokenizer: Callable tokenizer used to encode ``state['text']``.
            device: Device for the model and inputs. When a CUDA device is
                requested but CUDA is unavailable, the agent falls back to
                ``"cpu"`` instead of failing while moving the model.
        """
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            device = "cpu"
        self.model = BertForSequenceClassification.from_pretrained(model_path).to(device)
        self.tokenizer = tokenizer
        self.device = device
        # NOTE(review): transformers' AdamW is deprecated upstream in favour of
        # torch.optim.AdamW; consider migrating (defaults differ, so confirm).
        self.optimizer = AdamW(self.model.parameters(), lr=1e-5)
        # The model stays in training mode for predict() as well as optimize().
        self.model.train()

    def _encode(self, state):
        """Tokenize ``state['text']`` into tensors placed on the agent's device."""
        return self.tokenizer(
            state['text'],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.MAX_LENGTH,
        ).to(self.device)

    def predict(self, state):
        """Sample an action (class index) from the policy for ``state``.

        No gradients are recorded. The action is drawn from the softmax
        distribution rather than taken as the argmax.
        """
        with torch.no_grad():
            logits = self.model(**self._encode(state)).logits
            probs = F.softmax(logits, dim=1)
            return torch.multinomial(probs, 1).item()

    def optimize(self, state, action, reward):
        """Take one policy-gradient step for a single transition.

        The loss is ``-log pi(action | state) * reward``, so a positive
        reward raises the probability of ``action`` and a negative reward
        lowers it.
        """
        logits = self.model(**self._encode(state)).logits
        loss = -F.log_softmax(logits, dim=1)[0, action] * reward  # Policy gradient loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [34]:
from transformers import BertTokenizer

# The environment and the agent's tokenizer are independent of each other.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
env = TextClassificationEnv(dataset_split)
bert_agent = BERTAgent(
    model_path="HamdanXI/bert-base-uncased-paradetox_with_labels",
    tokenizer=tokenizer,
)

In [None]:
num_episodes = 1000

# Each episode is one full pass of the environment; the agent is updated
# after every single step using the state it acted on.
for episode in range(num_episodes):
    state, done = env.reset(), False
    while True:
        action = bert_agent.predict(state)
        next_state, reward, done, _ = env.step(action)
        bert_agent.optimize(state, action, reward)
        state = next_state
        if done:
            break