<a href="https://colab.research.google.com/github/alecbidaran/Pytorch_excersies/blob/main/RLHF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the helper repository; the `install` module imported below is not defined
# in this notebook, so it presumably comes from this checkout (TODO confirm).
# NOTE(review): the recorded output shows git refusing to clone because
# 'notebooks' already exists, i.e. this cell is not idempotent on re-run.
!git clone https://github.com/nlp-with-transformers/notebooks.git
# Make the checkout the working directory for the rest of the session.
%cd notebooks
# Star import: the only name from `install` used in this cell is install_requirements().
from install import *
# Per the recorded output this installs the base requirements and Git LFS.
install_requirements()

fatal: destination path 'notebooks' already exists and is not an empty directory.
/content/notebooks
⏳ Installing base requirements ...
✅ Base requirements installed!
⏳ Installing Git LFS ...
✅ Git LFS installed!


In [None]:
# hide
# `utils` is not defined in this notebook; presumably it ships with the
# repository cloned in the first cell (TODO confirm). The recorded output shows
# setup_chapter() reporting the transformers and datasets versions in use.
from utils import *
setup_chapter()

Using transformers v4.11.3
Using datasets v1.16.1


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
# Checkpoint used for the tokenizer and the causal language model.
model_name = "gpt2"

# Prefer the GPU when one is visible; otherwise everything stays on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights first, then move the module to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [None]:
from transformers import pipeline
# Sentiment classifier whose predictions are turned into rewards by get_reward();
# the labels seen in this notebook's outputs are POS, NEG and NEU.
sentiment_pipeline = pipeline(
    task='text-classification',
    model="finiteautomata/bertweet-base-sentiment-analysis",
)

In [None]:
def get_reward(text,mode):
  """Score text with the sentiment classifier, keeping only the wanted polarity.

  Args:
    text: a string or list of strings, passed straight to `sentiment_pipeline`.
    mode: "+ve" to reward POS predictions, "-ve" to reward NEG predictions.

  Returns:
    A float32 tensor on `device` with one entry per classifier result: the
    classifier's score where the predicted label matches the requested
    polarity, and 0 everywhere else.

  Raises:
    ValueError: if `mode` is neither "+ve" nor "-ve".
  """
  # Validate before running the classifier. An unknown mode used to fall through
  # both `if` branches and fail with an UnboundLocalError on `labels`, and only
  # after the (expensive) pipeline call had already been made.
  if mode=="+ve":
    wanted="POS"
  elif mode=="-ve":
    wanted="NEG"
  else:
    raise ValueError(f'mode must be "+ve" or "-ve", got {mode!r}')
  ress=sentiment_pipeline(text)
  labels=torch.tensor([res['label']==wanted for res in ress],dtype=torch.float32).to(device)
  scores=torch.tensor([res['score'] for res in ress],dtype=torch.float32).to(device)
  # The 0/1 label mask zeroes the score of every prediction with the wrong polarity.
  reward=labels*scores
  return reward



In [None]:
# A handful of sentences to sanity-check the classifier's labels and scores.
inps = [
    "I'm the king of the world!",
    "I'll be back.",
    "The cake is a lie",
    "To be forgotten is worse than death",
    "All happy families are alike; each unhappy family is unhappy in its own way.",
    "You don't need a reason to help people",
]
res = sentiment_pipeline(inps)

# Attach every input sentence to its prediction so each printed record is self-describing.
for i, sentence in enumerate(inps):
  res[i]['text'] = sentence
  print(res[i])

{'label': 'POS', 'score': 0.9771729707717896, 'text': "I'm the king of the
world!"}
{'label': 'POS', 'score': 0.5481614470481873, 'text': "I'll be back."}
{'label': 'NEG', 'score': 0.7581191658973694, 'text': 'The cake is a lie'}
{'label': 'NEG', 'score': 0.8209365606307983, 'text': 'To be forgotten is worse
than death'}
{'label': 'NEU', 'score': 0.7874236702919006, 'text': 'All happy families are
alike; each unhappy family is unhappy in its own way.'}
{'label': 'NEU', 'score': 0.8731080889701843, 'text': "You don't need a reason
to help people"}


In [None]:
# Positive-mode reward for "I'll be back.": the classifier run above labelled it
# POS, so the reward shown below is simply that prediction's score.
get_reward(inps[1], '+ve')


tensor([0.5482], device='cuda:0')

In [None]:
import copy
# Snapshot the current weights of `model`; the copy is kept as a fixed reference.
ref_model = copy.deepcopy(model).to(device)
# The reference is never trained, so gradients are switched off for all of its weights.
for param in ref_model.parameters():
  param.requires_grad_(False)

In [None]:
def logp_from_logits(output,labels):
  """Return the log-probability assigned to each label token.

  Args:
    output: logits whose last dimension indexes the vocabulary, e.g.
      (batch, seq, vocab) as used by the callers in this notebook.
    labels: integer token ids shaped like `output` without its last dimension.

  Returns:
    A tensor shaped like `labels` containing log_softmax(output) evaluated at
    each label id.
  """
  log_probs=F.log_softmax(output,dim=-1)
  # Gather along the last (vocabulary) axis instead of a hard-coded dim 2:
  # identical for 3-D logits, and also valid for un-batched (seq, vocab) or
  # higher-rank inputs, which previously raised an index error.
  return torch.gather(log_probs,-1,labels.unsqueeze(-1)).squeeze(-1)


def generate(context,max_rate):
    """Sample a sequence from the policy `rlhf` and score it token by token.

    Args:
        context: prompt token ids; the caller passes the tokenizer's 2-D
            `input_ids` tensor, already moved to `device`.
        max_rate: forwarded to `rlhf.generate` as `max_length`.

    Returns:
        labels: the token ids returned by `rlhf.generate`.
        log_proba: log-probability the policy assigns to every token of
            `labels` after the first one, given the tokens before it. Computed
            by a regular forward pass, so gradients flow through it.
        entropy: softmax over the vocabulary of the policy's logits on
            `context`.
            NOTE(review): despite its name this is a probability tensor, not an
            entropy. Each vocabulary slice sums to 1, so its `.mean()` is the
            constant 1/vocab_size and carries no gradient. It is left as-is so
            callers see the same value.
    """
    labels=rlhf.generate(context,max_length=max_rate,do_sample=True,top_p=0.9,top_k=20)
    att_mask=torch.ones_like(labels)
    output=rlhf(labels,attention_mask=att_mask)
    # The logits at position t predict token t+1, hence the one-step shift:
    # drop the last logit and the first label.
    # The earlier version concatenated this onto an empty accumulator tensor,
    # which was a no-op; the log-probs are now returned directly.
    log_proba=logp_from_logits(output.logits[:, :-1, :],labels[:, 1:])
    entropy=F.softmax(rlhf(context).logits,dim=-1)
    return labels,log_proba,entropy

def refrence_generate(context,max_rate):
    """Sample a sequence from the frozen `ref_model` and return its token log-probs.

    Args:
        context: prompt token ids; the caller passes the tokenizer's 2-D
            `input_ids` tensor, already moved to `device`.
        max_rate: forwarded to `ref_model.generate` as `max_length`.

    Returns:
        Log-probability that `ref_model` assigns to every token after the first
        of the sequence it sampled itself. The sampled ids are not returned, so
        these values describe the reference model's own sample rather than any
        sequence drawn from the policy.
    """
    with torch.no_grad():
        ref_labels=ref_model.generate(context,max_length=max_rate,do_sample=True,top_p=0.9,top_k=20)
        att_mask=torch.ones_like(ref_labels).to(device)
        ref_output=ref_model(ref_labels,attention_mask=att_mask)
        # Same one-step shift as in generate(): logits at t score token t+1.
        # The previous no-op concatenation onto an empty tensor has been removed.
        log_ref_proba=logp_from_logits(ref_output.logits[:, :-1, :],ref_labels[:, 1:])
    return log_ref_proba



In [None]:
import numpy as np
# The policy being fine-tuned is the `model` object itself (an alias, not a copy).
rlhf = model
# Ensure every policy weight is trainable before handing the weights to the optimizer.
for param in rlhf.parameters():
  param.requires_grad_(True)

optimizer = torch.optim.AdamW(
    rlhf.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-6,
)


In [None]:
def update_policy(input_txt,n_steps=10,max_length=64,kl_coef=0.1,clip_eps=0.2,
                  entropy_coef=1e-3,mode="-ve"):
  """Run several optimisation steps of the policy on samples from one prompt.

  Every step samples a sequence with `generate`, scores its "."-separated
  fragments with `get_reward`, subtracts a penalty for drifting away from the
  frozen `ref_model`, and applies one optimizer update.

  Args:
    input_txt: prompt string.
    n_steps: number of sample/update steps (previously fixed at 10).
    max_length: `max_rate` passed to `generate` (previously fixed at 64).
    kl_coef: weight of the log-prob difference to the reference (was 0.1).
    clip_eps: half-width of the clamp window around 1 (was 0.2).
    entropy_coef: weight of the third value returned by `generate` (was 1e-3).
    mode: polarity passed to `get_reward` (was always "-ve").

  Returns:
    (losses, actor_rewards, states): per-step loss values, per-step summed
    rewards as numpy scalars, and the token ids sampled in the final step
    (None when `n_steps` is 0).
  """
  losses=[]
  actor_rewards=[]
  states=None  # stays None if no step is run, instead of being unbound at return
  input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
  for _ in range(n_steps):
    optimizer.zero_grad()
    states,log_proba,entropy=generate(input_ids,max_rate=max_length)
    seq_log=torch.sum(log_proba)
    # Score the SAME sampled tokens under the frozen reference model. The
    # earlier code compared against the log-prob of a different sequence that
    # the reference model had sampled on its own, so the "penalty" was mostly
    # sampling noise rather than a measure of policy drift.
    with torch.no_grad():
      ref_logits=ref_model(states,attention_mask=torch.ones_like(states)).logits
      seq_log_ref=torch.sum(logp_from_logits(ref_logits[:, :-1, :],states[:, 1:]))
    # One reward per "."-separated fragment of the decoded sample, summed.
    sentences=tokenizer.decode(states[0].tolist()).split(".")
    rewards=torch.cat([get_reward(s,mode) for s in sentences]).sum()
    # NOTE(review): `ratio` is a penalised reward, not a PPO probability ratio,
    # yet it is clamped to [1-clip_eps, 1+clip_eps] below, and gradients also
    # flow through it via `seq_log`. Confirm this objective is what is intended.
    ratio=(rewards-kl_coef*(seq_log-seq_log_ref))
    token_probs=log_proba.exp().squeeze(0)
    pol1=ratio*token_probs
    pol2=torch.clamp(ratio,1-clip_eps,1+clip_eps)*token_probs
    loss=-0.5*torch.min(pol1,pol2).sum()+entropy_coef*entropy.mean()
    loss.backward()
    optimizer.step()
    actor_rewards.append(rewards.detach().cpu().numpy())
    losses.append(loss.item())
  return losses,actor_rewards,states


In [None]:
eval_interval_rlhf = 10
max_iters_rlhf = 100
# start with 'The'
input_txt = "The"
count=0
plot_loss=[]
# `step` instead of `iter`, so the builtin iter() is not shadowed at notebook scope.
for step in range(max_iters_rlhf):
  loss,actor_rewards,states=update_policy(input_txt)
  # Report on every `eval_interval_rlhf`-th iteration. The earlier bare
  # `iter%eval_interval_rlhf` test was inverted: it was truthy on every
  # iteration except the multiples of the interval.
  if step%eval_interval_rlhf==0:
    plot_loss.append(loss)
    print('\n')
    print(f'loss: {np.mean(loss)}')
    print(f'rewards:{np.mean(actor_rewards)}')
    print(f'outputs: {tokenizer.decode(states[0].tolist())}')
    # Early stop once two reports have seen a mean reward above 0.9.
    if np.mean(actor_rewards)>0.9:
      count+=1
    if count==2:
      break




loss: 5.523060083389282
rewards:1.0237746238708496
outputs: The first stage of the internet is now dominated by virtual reality
headsets like Oculus Rift, HTC Vive, and Oculus Rift Pro. The new generation of
gaming rigs are also using the Rift Touchpad with SteamVR to connect to virtual
reality games with HTC Vive's SteamVR app and Oculus Touchpad controller. These
rigs are also


loss: -4.805917572975159
rewards:0.5005065202713013
outputs: The state's largest prison population was also plagued by gang
violence, including an armed robbery.

The FBI's Counter-Terrorism Center, meanwhile, is run by convicted felon Jeffrey
Dahmer. The state prison in Chicago was home to the infamous Boston Marathon
bombing.

With the internet, cellphones and even


loss: -7.798302173614502
rewards:0.8612753748893738
outputs: The first major American city to ban the LGBT community.

In 2015, President Obama's former national security adviser Michael Flynn was
accused of having ties to Russia.

And while 

In [None]:
# Sample one continuation from the fine-tuned policy to inspect the result.
max_length = 128
input_txt = """The
"""
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
with torch.no_grad():
  # Use the constant declared above; it was previously ignored in favour of a
  # second hard-coded 128.
  output_temp=rlhf.generate(input_ids, max_length=max_length, do_sample=True,top_k=20,top_p=0.9)
print(tokenizer.decode(output_temp[0]))

The

(and other) scenarios in which you might be tempted to cheat on. But it's not so
easy, so

maybe not the world's biggest risk.

As the internet's biggest user of sex, the worst of those things may be over.

In case you were just getting too creative, here's how to fight against
sex-induced stress.

But there's one place where you're just making up.

And it's not all that scary.

In fact, there's one thing the Internet's doing that's not doing well.

A little bit of
