In [None]:
import gym 
from gym import spaces 
import numpy as np 
import torch 
from torch.utils.data import Dataset, DataLoader, Subset 
from transformers import AutoTokenizer, AutoModelForSequenceClassification 
from datasets import load_dataset 
import random 
from stable_baselines3 import PPO 
from stable_baselines3.common.vec_env import DummyVecEnv 
from stable_baselines3.common.evaluation import evaluate_policy 

# Seed every RNG used in this script (Python, NumPy, PyTorch) with one shared value
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fetch the SST-2 dataset through the `datasets` loader
dataset = load_dataset("sst2")

# Tokenizer for the distilbert-base-uncased checkpoint
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

class SST2Dataset(Dataset):
    """Torch dataset that tokenizes one split of the loaded SST-2 data on access.

    Rows are read from the module-level ``dataset`` and encoded with the
    module-level ``tokenizer``.
    """

    def __init__(self, split):
        """Select ``split`` (e.g. ``"train"`` or ``"validation"``) from ``dataset``."""
        self.data = dataset[split]

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize and return the example at position ``idx``.

        Returns:
            A dict with three tensors:
            ``input_ids`` and ``attention_mask`` (padded/truncated to 128
            tokens, batch axis removed) and ``labels`` (scalar label tensor).
        """
        # Index the split once instead of once per field.
        row = self.data[idx]
        encoding = tokenizer(
            row["sentence"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        return {
            # Drop only the leading batch axis produced by return_tensors="pt";
            # a dimensionless squeeze() would also collapse any other size-1 axis.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(row["label"]),
        }

# Wrap both splits, then draw small random subsets of each
full_train_dataset = SST2Dataset("train")
full_eval_dataset = SST2Dataset("validation")

# Train indices are drawn before eval indices; keep this order so the
# RNG state consumed by each draw stays the same.
train_indices = random.sample(range(len(full_train_dataset)), k=96)
eval_indices = random.sample(range(len(full_eval_dataset)), k=32)

train_dataset = Subset(full_train_dataset, indices=train_indices)
eval_dataset = Subset(full_eval_dataset, indices=eval_indices)

class SST2Environment(gym.Env):
    """Gym environment that presents dataset sentences one per step for binary classification.

    Each observation is a sentence embedding taken from the DistilBERT encoder;
    each action is a predicted label (0 or 1). The reward is +1.0 for a correct
    prediction and -1.0 otherwise. One episode is a single pass over ``dataset``.
    """

    def __init__(self, dataset):
        """Build the environment around ``dataset``.

        Args:
            dataset: Indexable collection whose items are dicts with
                ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.

        Raises:
            ValueError: If ``dataset`` is empty (there would be nothing to observe).
        """
        super().__init__()
        if len(dataset) == 0:
            raise ValueError("SST2Environment requires a non-empty dataset")
        self.dataset = dataset
        self.current_index = 0

        self.model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
        # The model is only used for inference here; make sure dropout is off.
        self.model.eval()

        self.action_space = spaces.Discrete(2)  # Binary classification
        # Encoder outputs are signed and unbounded, so the space must not be
        # restricted to [0, 1]. The width is read from the model config rather
        # than hard-coded so it always matches the loaded checkpoint.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.model.config.dim,),
            dtype=np.float32,
        )

    def reset(self):
        """Rewind to the first example and return its observation."""
        self.current_index = 0
        return self._get_observation()

    def step(self, action):
        """Score ``action`` against the current example's label and advance.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` belongs
            to the next example (wrapping to the first one after the last),
            ``done`` is True when the example just scored was the last one,
            and ``info`` is an empty dict.
        """
        label = self.dataset[self.current_index]["labels"].item()
        reward = 1.0 if action == label else -1.0
        done = (self.current_index == len(self.dataset) - 1)
        self.current_index = (self.current_index + 1) % len(self.dataset)
        return self._get_observation(), reward, done, {}

    def _get_observation(self):
        """Return the embedding of the current example as a 1-D NumPy array.

        The embedding is the mean of the encoder's last hidden states over the
        real (non-padding) token positions only.
        """
        # Fetch the item once: every lookup re-runs the dataset's __getitem__.
        sample = self.dataset[self.current_index]
        input_ids = sample["input_ids"].unsqueeze(0)
        attention_mask = sample["attention_mask"].unsqueeze(0)
        with torch.no_grad():
            outputs = self.model.distilbert(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state
            # Weight each position by the attention mask so padding tokens do
            # not dilute the average; clamp guards against an all-zero mask.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return pooled.squeeze(0).numpy()

# Create and wrap the training environment.
# The raw environment gets its own name so the factory lambda does not close
# over `env`, which is rebound to the vectorized wrapper on the next line.
train_env = SST2Environment(train_dataset)
env = DummyVecEnv([lambda: train_env])

# Create the PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=1000)

# Evaluate the model on the held-out subset (same closure precaution as above)
raw_eval_env = SST2Environment(eval_dataset)
eval_env = DummyVecEnv([lambda: raw_eval_env])
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)

print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Optional: Save the model
model.save("ppo_sst2")

print("Training and evaluation completed!")