In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoModelForCausalLM , AutoTokenizer
from gym import Env , spaces
import numpy as np

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


In [2]:
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
policy_model=AutoModelForCausalLM.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/641 [00:00<?, ?B/s]

  return datetime.utcnow().replace(tzinfo=utc)


model.safetensors:   0%|          | 0.00/351M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [3]:
print(tokenizer.vocab_size)
print(policy_model.config.hidden_size)

50257
768


In [4]:
policy_model=policy_model.to("cuda")

In [5]:
#---Environment---
class ChatbotEnv(Env):
  def __init__(self):
    super(ChatbotEnv,self).__init__()
    self.tokenizer = tokenizer
    self.model=policy_model
    self.action_space=spaces.Discrete(self.tokenizer.vocab_size)
    self.observation_space=spaces.Box(low=np.inf,high=np.inf,shape=(self.model.config.hidden_size,),dtype=np.float32)
    self.conversation_history=[]

  def reset(self):
    self.conversation_history =[]
    return self._get_observation()

  def step(self,action):
    response = self._decode_action(action)
    self.conversation_history.append(response)
    reward = self._calculate_reward(response)
    done = len(self.conversation_history) >=5
    obs = self._get_observation()
    return obs , reward, done, {}

  def _decode_action(self,action_token_id):
    return self.tokenizer.decode([action_token_id],skip_special_tokens=True)

  def _get_observation(self):
    if not self.conversation_history:
      return np.zeros(self.observation_space.shape,dtype=np.float32)
    text = " ".join(self.conversation_history)
    inputs = self.tokenizer(text,return_tensors="pt",truncation=True,max_length=512).to("cuda")
    with torch.no_grad():
      outputs = policy_model.transformer(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze(0).detach().cpu().numpy().astype(np.float32).flatten()

  def _calculate_reward(self,response):
    return len(response.split()) / 10.0

In [6]:
#---sample---
env_sample = ChatbotEnv()
obs = env_sample.reset()
print("initial observation shape: ",obs.shape)

initial observation shape:  (768,)


In [7]:
#--------value network (critic)----------
class ValueNetwork(nn.Module):
  def __init__(self,input_dim=768):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(input_dim,256),
        nn.Tanh(),
        nn.Linear(256,1)
    )

  def forward(self,x):
    return self.net(x)

In [8]:
#----PPO Trainer----
class PPOTrainer:
  def __init__(self,policy_model,value_model,tokenizer,lr=1e-5,clip_eps=0.2):
    self.policy=policy_model
    self.value_net = value_model
    self.tokenizer = tokenizer
    self.clip_eps = clip_eps
    self.optimizer = torch.optim.Adam(
        list(self.policy.parameters()) + list(self.value_net.parameters()),
        lr = lr
    )

  def compute_log_probs(self,inputs,actions):
    with torch.no_grad():
      logits = self.policy(**inputs).logits
    dist = torch.distributions.Categorical(logits=logits[:,-1,:])
    return dist.log_prob(actions) , dist

  def train(self,observations , texts,actions,rewards,old_log_probs):
    obs_tensor = torch.tensor(observations,dtype=torch.float32)
    actions = torch.tensor(actions)
    actions = actions.to("cuda")
    rewards = torch.tensor(rewards,dtype=torch.float32)
    old_log_probs = torch.stack(old_log_probs)

    # estimate value
    values = self.value_net(obs_tensor).squeeze()
    advantages = rewards - values.detach()
    advantages = advantages.to("cuda")

    # get new log probs
    inputs = tokenizer(texts , return_tensors="pt",padding=True,truncation=True).to("cuda")
    logits = self.policy(**inputs).logits
    dists = torch.distributions.Categorical(logits=logits[:,-1,:])
    new_log_probs = dists.log_prob(actions)

    # PPO Loss
    ratios = torch.exp(new_log_probs-old_log_probs)
    ratios = ratios.to("cuda")
    surr1 = ratios*advantages
    surr2 = torch.clamp(ratios,1-self.clip_eps,1+self.clip_eps)*advantages
    policy_loss = -torch.min(surr1,surr2).mean()
    print(f"policy_loss : {policy_loss}")
    value_loss = F.mse_loss(values , rewards)
    print(f"value_loss : {value_loss}")
    loss = policy_loss + value_loss
    print(f"loss : {loss}")

    # update
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()


In [9]:
env = ChatbotEnv()
value_net = ValueNetwork()
trainer = PPOTrainer(policy_model , value_net,tokenizer)

In [11]:
# -- Training Loop --
for epoch in range(10):

  observations = []
  texts = []
  actions = []
  rewards = []
  old_log_probs = []

  # rollouts
  obs = env.reset()
  for _ in range(100):
    #if not env.conversation_history:
     # text_input =  np.zeros(env.observation_space.shape,dtype=np.float32)
    text_input = "".join(env.conversation_history)  or "hello"
    inputs = tokenizer(text_input,return_tensors="pt",truncation=True).to("cuda")
    logits = policy_model(**inputs).logits
    dist = torch.distributions.Categorical(logits=logits[:,-1,:])
    action = dist.sample()
    log_prob = dist.log_prob(action)

    obs , reward, done, info = env.step(action.item())

    observations.append(obs)
    actions.append(action)
    rewards.append(reward)
    texts.append(text_input)
    old_log_probs.append(log_prob)

    if done :
      obs = env.reset()

  trainer.train(observations,texts,actions,rewards,old_log_probs)
  print(f"Epoch {epoch} done\n------------------")


policy_loss : 184.87246704101562
value_loss : 0.04459132254123688
loss : 184.91705322265625
Epoch 0 done
------------------
policy_loss : 224.5618896484375
value_loss : 0.0593060627579689
loss : 224.62120056152344
Epoch 1 done
------------------
policy_loss : 8.639435768127441
value_loss : 0.049078211188316345
loss : 8.68851375579834
Epoch 2 done
------------------
policy_loss : 18.305387496948242
value_loss : 0.05003967881202698
loss : 18.355426788330078
Epoch 3 done
------------------
policy_loss : 133.70272827148438
value_loss : 0.03206789866089821
loss : 133.73480224609375
Epoch 4 done
------------------
policy_loss : 176.78652954101562
value_loss : 0.03516224026679993
loss : 176.82168579101562
Epoch 5 done
------------------
policy_loss : 0.09778323769569397
value_loss : 0.013048294000327587
loss : 0.11083152890205383
Epoch 6 done
------------------
policy_loss : 3.5736701488494873
value_loss : 0.03998035565018654
loss : 3.6136505603790283
Epoch 7 done
------------------
policy_lo

In [12]:
torch.save(policy_model, "DialoGPT_Train_With_PPO1.pth")