# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
import random
import collections


class TextGenerationEnvironment:
    """Token-by-token text-building environment with a gym-like interface.

    Every call to ``step`` decodes one token id, appends it to the running
    text and scores the whole text with ``reward_fn``.

    Args:
        model_name_or_path: Identifier handed to ``from_pretrained`` for both
            the tokenizer and the language model used by ``generate_text``.
        reward_fn: Callable that receives the text generated so far and
            returns its reward.
        max_length: Maximum number of steps per episode. It is also the
            default ``max_length`` for ``generate_text`` and the number of
            trailing characters exposed as the observation.
    """

    # Placeholder observation used whenever the running text is empty, so
    # that consumers never have to encode an empty string.
    _EMPTY_OBSERVATION = ' '

    def __init__(self, model_name_or_path, reward_fn, max_length=20):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
        self.max_length = max_length
        self.current_text = ""
        self.current_length = 0
        self.reward_fn = reward_fn

    def generate_text(self, input_text, max_length=None):
        """Continue ``input_text`` with ``self.model`` and return the decoded text.

        Args:
            input_text: Prompt to encode and pass to ``generate``.
            max_length: ``max_length`` forwarded to ``generate``; defaults to
                the environment's ``max_length`` when ``None``.

        Returns:
            The decoded first generated sequence, special tokens removed.
        """
        if max_length is None:
            max_length = self.max_length
        input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
        output = self.model.generate(input_ids=input_ids, max_length=max_length)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    def get_tokenizer(self):
        """Return the tokenizer used by this environment."""
        return self.tokenizer

    def reset(self):
        """Start a new episode and return the initial observation.

        The running text and step counter are cleared. The returned
        observation is the same non-empty placeholder ``step`` falls back
        to, because an empty string encodes to zero tokens and cannot be
        fed to a model.
        """
        self.current_text = ""
        self.current_length = 0
        return self._EMPTY_OBSERVATION

    def step(self, action):
        """Append the token ``action`` to the text and score the result.

        Args:
            action: Token id to decode and append to the running text.

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``obs`` is the last
            ``max_length`` characters of the text (never empty), ``reward``
            is ``reward_fn`` applied to the full text, ``done`` signals the
            end of the episode and ``info`` is an empty dict.
        """
        self.current_text += self.tokenizer.decode([action])

        # The observation is a character window over the text, not a token window.
        obs = self.current_text[-self.max_length:] or self._EMPTY_OBSERVATION

        reward = self.reward_fn(self.current_text)
        self.current_length += 1
        done = self.current_length >= self.max_length

        # Also stop once the text holds more distinct tokens than 80% of
        # max_length (adjust the threshold according to your needs).
        unique_tokens = len(set(self.tokenizer.encode(self.current_text)))
        if unique_tokens > self.max_length * 0.8:
            done = True

        return obs, reward, done, {}


class ModifiedGPT(nn.Module):
    """GPT-2 backbone with an extra linear head that scores discrete actions.

    Args:
        model_name_or_path: Identifier handed to ``from_pretrained`` for the
            GPT-2 configuration and weights.
        num_actions: Size of the action head's output, i.e. the number of
            actions the returned distribution covers.
    """

    def __init__(self, model_name_or_path, num_actions=512):
        super().__init__()
        config = GPT2Config.from_pretrained(model_name_or_path)
        # Keep the configuration reachable from the module (the original
        # forward pass read self.config without ever setting it).
        self.config = config
        self.gpt = GPT2LMHeadModel.from_pretrained(model_name_or_path, config=config)
        self.action_layer = nn.Linear(config.n_embd, num_actions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Compute per-position action probabilities.

        Args:
            input_ids: Token ids fed to the GPT-2 backbone.
            attention_mask: Optional attention mask for ``input_ids``.
            labels: Optional target action indices, one per position of
                ``input_ids``; each must be a valid index into the action
                head's output.

        Returns:
            A tuple ``(action_probs, *rest)`` where ``action_probs`` is the
            softmax over the action head's output and ``rest`` is whatever
            the backbone returned after its hidden states. When ``labels``
            is given, the cross-entropy loss is prepended:
            ``(loss, action_probs, *rest)``.
        """
        # The action head is sized for n_embd-wide hidden states, so it must
        # read the transformer backbone's output. The LM-head model's first
        # output is vocabulary logits, which have a different width.
        backbone_outputs = self.gpt.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        hidden_states = backbone_outputs[0]

        action_logits = self.action_layer(hidden_states)
        action_probs = self.softmax(action_logits)

        outputs = (action_probs,) + tuple(backbone_outputs[1:])

        if labels is not None:
            # CrossEntropyLoss applies log-softmax itself, so it takes the raw
            # logits, flattened to (positions, num_actions).
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                action_logits.reshape(-1, action_logits.size(-1)),
                labels.reshape(-1),
            )
            outputs = (loss,) + outputs

        return outputs


class RewardFunction:
    """Score generated text against a reference text.

    The reward is the sum of three terms:

    * ``+5 *`` the TF-IDF cosine similarity to the reference text,
    * ``-`` the GPT-2 perplexity of the generated text,
    * ``-10 *`` the number of token occurrences belonging to a token that
      appears more than once.

    Args:
        reference_text: Text the generated output is compared against; the
            TF-IDF vocabulary is fitted on it.
        model_name_or_path: Identifier handed to ``from_pretrained`` for the
            tokenizer and the language model used for perplexity.
    """

    def __init__(self, reference_text, model_name_or_path):
        self.model_name_or_path = model_name_or_path
        # Load the tokenizer once; reloading it on every call is needlessly slow.
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        self.reference_tokens = self.tokenizer.encode(reference_text, return_tensors='pt')
        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
        self.vectorizer = TfidfVectorizer()
        self.reference_embedding = self.vectorizer.fit_transform([reference_text]).toarray()[0]

    def __call__(self, generated_text):
        """Return the reward for ``generated_text`` as a float."""
        token_ids = self.tokenizer.encode(generated_text)

        # Reward based on cosine similarity to the reference text.
        generated_embedding = self.vectorizer.transform([generated_text]).toarray()[0]
        similarity = cosine_similarity([self.reference_embedding], [generated_embedding])[0][0]
        reward = similarity * 5  # Adjust the weight according to your needs

        # Penalize perplexity. The LM loss predicts each token from the ones
        # before it, so it is only defined for at least two tokens; shorter
        # texts skip this term rather than contribute NaN.
        if len(token_ids) >= 2:
            generated_tokens = torch.tensor([token_ids], dtype=torch.long)
            with torch.no_grad():
                outputs = self.model(input_ids=generated_tokens, labels=generated_tokens)
                loss = outputs[0]
            reward -= torch.exp(loss).item()

        # Penalize repeated tokens: every occurrence of a token that shows up
        # more than once counts towards the penalty.
        token_counts = collections.Counter(token_ids)
        repeated_tokens = sum(count for count in token_counts.values() if count > 1)
        reward -= 10 * repeated_tokens

        return float(reward)

class PPOTrainer:
    """Collect rollouts from a text environment and feed them to ``update``.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``,
            ``max_length`` and ``tokenizer``.
        model: Policy module. Calling it with token ids and an attention
            mask must return a sequence whose first element holds the
            action probabilities.
        reward_fn: Reward callable; stored on the trainer but not called by
            it (the environment computes the rewards).
        lr, betas, eps: Adam optimizer settings.
        gamma: Discount factor used when computing returns.
        clip_param, value_loss_coef, entropy_coef: PPO loss coefficients,
            stored for use by ``update``.
        num_epochs: Number of passes over each rollout.
        batch_size: Mini-batch size for each pass.
    """

    def __init__(self, env, model, reward_fn, lr=1e-4, betas=(0.9, 0.999), eps=1e-5, gamma=0.99, clip_param=0.2, value_loss_coef=0.5, entropy_coef=0.01, num_epochs=10, batch_size=64):
        self.env = env
        self.model = model
        self.reward_fn = reward_fn
        self.gamma = gamma
        self.clip_param = clip_param
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.num_epochs = num_epochs
        self.batch_size = batch_size

        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)

    def update(self, obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch):
        """Update the model from one mini-batch (not implemented yet).

        TODO: compute the PPO loss from the batch and step ``self.optimizer``.
        Until then this is a no-op, so ``train`` collects rollouts but does
        not change any model parameter.
        """
        pass

    @staticmethod
    def select_action(action_probs, temperature=1.2):
        """Sample an action index from a temperature-scaled distribution.

        Args:
            action_probs: One-dimensional sequence of non-negative action
                probabilities.
            temperature: Positive scaling factor; values below 1 sharpen the
                distribution, values above 1 flatten it.

        Returns:
            The sampled action index as an ``int``.

        Raises:
            ValueError: If ``temperature`` is not positive or the scaled
                probabilities cannot be normalised.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        probs = np.asarray(action_probs, dtype=np.float64)
        # p ** (1 / T) equals exp(log(p) / T) but maps p == 0 to 0 without
        # going through log(0).
        scaled_probs = np.power(probs, 1.0 / temperature)
        total = scaled_probs.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError("action_probs must contain at least one positive, finite probability")
        return int(np.random.choice(len(scaled_probs), p=scaled_probs / total))

    def _evaluate(self, obs):
        """Run the model on ``obs``; return (next-action probabilities, value)."""
        tokenizer = self.env.tokenizer
        token_ids = list(tokenizer.encode(obs))
        if not token_ids:
            # An empty observation encodes to zero tokens, which the model
            # cannot consume; fall back to the tokenizer's BOS/EOS token.
            fallback_id = getattr(tokenizer, "bos_token_id", None)
            if fallback_id is None:
                fallback_id = getattr(tokenizer, "eos_token_id", None)
            if fallback_id is None:
                raise ValueError("cannot encode an empty observation: tokenizer has no BOS/EOS token")
            token_ids = [fallback_id]

        obs_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_mask = torch.ones(obs_tensor.shape, dtype=torch.long)

        # Rollouts only need numbers, not gradients.
        with torch.no_grad():
            outputs = self.model(obs_tensor, attention_mask=attention_mask)

        action_probs = outputs[0].squeeze(0)
        if action_probs.dim() > 1:
            # One distribution per input position: the next action is drawn
            # from the distribution at the last position.
            action_probs = action_probs[-1]

        # NOTE(review): a second output is only treated as a value estimate
        # when it is a tensor; any other second output (or none) yields 0.0.
        # Confirm against the model in use.
        value = 0.0
        if len(outputs) > 1 and torch.is_tensor(outputs[1]) and outputs[1].numel() > 0:
            value = outputs[1].reshape(-1)[-1].item()

        return action_probs.cpu().numpy(), value

    def train(self, num_steps):
        """Run ``num_steps`` rollout/update rounds and return the model.

        Each round collects up to ``env.max_length`` transitions, computes
        returns and advantages, and passes shuffled mini-batches to
        ``update`` for ``num_epochs`` epochs.
        """
        storage = RolloutStorage(self.env.max_length)
        obs = self.env.reset()
        done = False

        for _ in range(num_steps):
            if done:
                obs = self.env.reset()

            for _ in range(self.env.max_length):
                action_probs, _ = self._evaluate(obs)
                action = self.select_action(action_probs, temperature=0.8)

                next_obs, reward, done, _ = self.env.step(action)
                storage.store(obs, action, action_probs, reward, done)

                obs = next_obs

                if done:
                    break

            # Bootstrap from the model's estimate only when the episode is
            # still running; a finished episode has no future reward.
            bootstrap_value = 0.0
            if not done:
                _, bootstrap_value = self._evaluate(obs)
            storage.store_last_observation(bootstrap_value)

            returns = storage.compute_returns(self.gamma)
            advantages = storage.compute_advantages(returns)
            # sample() reads returns and advantages from attributes on the
            # storage, so attach the freshly computed lists before sampling.
            storage.returns = returns
            storage.advantages = advantages

            for _ in range(self.num_epochs):
                indices = np.arange(len(storage))
                np.random.shuffle(indices)

                for batch_start in range(0, len(storage), self.batch_size):
                    batch_indices = indices[batch_start: batch_start + self.batch_size]
                    obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch = storage.sample(batch_indices)
                    self.update(obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch)

            storage.clear()
        return self.model
    
class RolloutStorage:
    """In-memory buffer for the transitions of one rollout.

    Args:
        max_length: Nominal rollout length; stored but not enforced.
    """

    def __init__(self, max_length):
        self.max_length = max_length
        self.clear()

    def store(self, obs, action, action_prob, reward, done):
        """Append one transition to the buffer."""
        self.obs.append(obs)
        self.actions.append(action)
        self.action_probs.append(action_prob)
        self.rewards.append(reward)
        self.dones.append(done)

    def store_last_observation(self, value):
        """Record the bootstrap value used for the step after the rollout."""
        self.last_observation = value

    def clear(self):
        """Drop all stored transitions and everything derived from them."""
        self.obs = []
        self.actions = []
        self.action_probs = []
        self.rewards = []
        self.dones = []
        # Derived quantities read by sample(); filled in by compute_returns()
        # and compute_advantages().
        self.returns = []
        self.advantages = []
        # Bootstrap value; 0.0 until store_last_observation() is called.
        self.last_observation = 0.0

    def __len__(self):
        return len(self.obs)

    def compute_returns(self, gamma):
        """Compute discounted returns for every stored step.

        Walks the rollout backwards, starting from the stored bootstrap
        value and restarting from zero at every step flagged as done.

        Args:
            gamma: Discount factor.

        Returns:
            A list with one return per stored transition; the same list is
            kept on ``self.returns`` for ``sample``.
        """
        returns = []
        next_value = self.last_observation
        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                # Nothing beyond a terminal step contributes to its return.
                next_value = 0
            next_value = reward + gamma * next_value
            returns.append(next_value)
        returns.reverse()
        self.returns = returns
        return returns

    def compute_advantages(self, returns):
        """Compute one advantage per step as ``return - stored action_prob``.

        NOTE(review): the stored action probability is what gets subtracted
        here; a value estimate may have been intended as the baseline.
        Confirm before relying on these numbers.

        Args:
            returns: Returns as produced by ``compute_returns``.

        Returns:
            The list of advantages; ``returns`` and the advantages are also
            kept on the storage for ``sample``.
        """
        self.returns = list(returns)
        self.advantages = [ret - prob for ret, prob in zip(returns, self.action_probs)]
        return self.advantages

    def sample(self, indices):
        """Gather the stored fields at ``indices``.

        Returns:
            ``(obs, actions, action_probs, advantages, returns)``, each a
            list ordered like ``indices``. ``compute_returns`` and
            ``compute_advantages`` must have run for the current contents.
        """
        obs_batch = [self.obs[i] for i in indices]
        action_batch = [self.actions[i] for i in indices]
        action_prob_batch = [self.action_probs[i] for i in indices]
        advantage_batch = [self.advantages[i] for i in indices]
        return_batch = [self.returns[i] for i in indices]
        return obs_batch, action_batch, action_prob_batch, advantage_batch, return_batch

# Text the reward function compares every generated text against.
reference_text = "This is an example of reference text. The model should generate different non-repititive answer"
model_name_or_path = "gpt2"

reward_fn = RewardFunction(reference_text, model_name_or_path)
env = TextGenerationEnvironment(model_name_or_path, reward_fn)
model = ModifiedGPT(model_name_or_path)

num_steps = 1000
trainer = PPOTrainer(env, model, reward_fn)
trained_model = trainer.train(num_steps)

# generate_text() samples from env.model, which is a separate pretrained copy
# that training never touches. Point it at the trained model's GPT-2 backbone
# so the printed sample reflects the model that was just trained.
env.model = trained_model.gpt

input_text = "This is a test input"
generated_text = env.generate_text(input_text)
print("Generated text:", generated_text)