In [1]:
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelWithLMHead

from environment import Chat
from actor_critic_net import ActorCritic
from agent import PPO
from utils import trainer, Predictor

In [2]:
# Pick the compute device: Apple MPS first, then the first CUDA GPU, else CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Device set to : mps")
elif not torch.cuda.is_available():
    # Neither accelerator backend is usable, so run on the CPU.
    device = torch.device('cpu')
    print("Device set to : cpu")
else:
    device = torch.device('cuda:0')
    # Release cached CUDA allocations before any model is loaded.
    torch.cuda.empty_cache()
    print(f"Device set to : {torch.cuda.get_device_name(device)}")

Device set to : mps


In [None]:
# Load the tokenizer and the language model to be fine-tuned from the same
# checkpoint, and move the model to the selected device.
# AutoModelForCausalLM is used instead of the deprecated AutoModelWithLMHead:
# the model is used for left-to-right text generation, so the causal-LM auto
# class is the matching one.
tokenizer = AutoTokenizer.from_pretrained('huggingtweets/elonmusk')
model = AutoModelForCausalLM.from_pretrained('huggingtweets/elonmusk').to(device)

# Sentiment classifier that is handed to the environment as `reward_pipe`.
# NOTE(review): recent transformers releases deprecate `return_all_scores` in
# favour of `top_k=None`, but the two may nest their results differently.
# The flag is left unchanged until Chat's handling of the pipeline output has
# been checked.
feedback_pipe = pipeline('sentiment-analysis',
                        model="cardiffnlp/twitter-roberta-base-sentiment",
                        tokenizer="cardiffnlp/twitter-roberta-base-sentiment",
                        return_all_scores=True,
                        device=device)

In [4]:
# ---------------------------------------------------------------------------
# Environment hyperparameters
# ---------------------------------------------------------------------------
max_ep_len = 20    # maximum number of timesteps in a single episode

# Prompts handed to the environment as the initial observations for episodes.
obs_base = [
    'I think Tesla is',
    'I think dogecoin is',
    'I think BTC is',
    'I think Twitter is',
    'I think ElonMusk is',
]

# ---------------------------------------------------------------------------
# PPO hyperparameters
# ---------------------------------------------------------------------------
epochs = 50       # number of policy-update epochs (K) per PPO update
eps_clip = 0.2    # PPO clipping parameter
gamma = 1         # discount factor

# Actor and critic networks share the same learning rate.
lr_actor = lr_critic = 5e-6

# ---------------------------------------------------------------------------
# Trainer settings
# ---------------------------------------------------------------------------
max_training_timesteps = 30_000
print_freq = 100

# Policy updates and checkpoint saves use the same interval: 100 episodes'
# worth of timesteps at the maximum episode length.
update_timestep = save_model_freq = max_ep_len * 100

checkpoint_path = "PPO_model.pth"
eval_obs = 'I think Tesla is'    # prompt used to check performance after a model update


In [5]:
# Environment built from the language model, its tokenizer, the sentiment
# pipeline used for rewards, the seed prompts and the generation length cap.
env = Chat(
    llm=model,
    tokenizer=tokenizer,
    reward_pipe=feedback_pipe,
    obs_base=obs_base,
    max_gen_len=max_ep_len,
)

# Actor-critic policy built around the same language model, on the chosen device.
policy = ActorCritic(model).to(device)

# PPO agent configured from the hyperparameter cell.
ppo_agent = PPO(
    policy=policy,
    lr_actor=lr_actor,
    lr_critic=lr_critic,
    gamma=gamma,
    K_epochs=epochs,
    eps_clip=eps_clip,
)

### Train with RLHF

In [6]:
# Run PPO training; every schedule value comes from the hyperparameter cell.
trainer(
    env, ppo_agent,
    # run length
    max_training_timesteps=max_training_timesteps, max_ep_len=max_ep_len,
    # update / checkpoint schedule
    update_timestep=update_timestep, save_model_timestep=save_model_freq,
    checkpoint_path=checkpoint_path,
    # progress reporting
    print_freq=print_freq, eval_obs=eval_obs,
)

Episode : 5 		 Timestep : 100 		 Average Reward : 1.65
Episode : 12 		 Timestep : 200 		 Average Reward : 0.69
Episode : 20 		 Timestep : 300 		 Average Reward : 0.57
Episode : 25 		 Timestep : 400 		 Average Reward : 1.16
Episode : 32 		 Timestep : 500 		 Average Reward : 1.09
Episode : 38 		 Timestep : 600 		 Average Reward : 1.98
Episode : 43 		 Timestep : 700 		 Average Reward : 3.29
Episode : 51 		 Timestep : 800 		 Average Reward : 1.45
Episode : 58 		 Timestep : 900 		 Average Reward : 1.12
Episode : 64 		 Timestep : 1000 		 Average Reward : 2.55
Episode : 70 		 Timestep : 1100 		 Average Reward : 2.33
Episode : 78 		 Timestep : 1200 		 Average Reward : 0.66
Episode : 85 		 Timestep : 1300 		 Average Reward : 0.56
Episode : 92 		 Timestep : 1400 		 Average Reward : 1.98
Episode : 100 		 Timestep : 1500 		 Average Reward : 1.94
Episode : 106 		 Timestep : 1600 		 Average Reward : 1.79
Episode : 112 		 Timestep : 1700 		 Average Reward : 0.36
Episode : 119 		 Timestep : 1800 		 Av

Episode : 765 		 Timestep : 11000 		 Average Reward : 0.26
Episode : 775 		 Timestep : 11100 		 Average Reward : 1.73
Episode : 782 		 Timestep : 11200 		 Average Reward : 2.65
Episode : 792 		 Timestep : 11300 		 Average Reward : 0.05
Episode : 800 		 Timestep : 11400 		 Average Reward : 3.8
Episode : 810 		 Timestep : 11500 		 Average Reward : 1.38
Episode : 816 		 Timestep : 11600 		 Average Reward : 1.92
Episode : 823 		 Timestep : 11700 		 Average Reward : 1.63
Episode : 830 		 Timestep : 11800 		 Average Reward : 0.49
Episode : 839 		 Timestep : 11900 		 Average Reward : 3.07
generated sentence: I think Tesla is great at Tesla stock price manipulation scams But hey, Tesla (Tesla AG), Twitter boy (company) 
 score: 0.0739921722561121
Episode : 849 		 Timestep : 12000 		 Average Reward : 1.4
--------------------------------------------------------------------------------------------
saving model at : PPO_model.pth
model saved
--------------------------------------------------------

Episode : 1851 		 Timestep : 21800 		 Average Reward : 5.19
Episode : 1858 		 Timestep : 21900 		 Average Reward : 2.25
generated sentence: I think Tesla is already getting ready for launch. Tesla & Tesla AI are close to commercialization<|endoftext|> 
 score: 0.023722532205283642
Episode : 1867 		 Timestep : 22000 		 Average Reward : 1.67
--------------------------------------------------------------------------------------------
saving model at : PPO_model.pth
model saved
--------------------------------------------------------------------------------------------
Episode : 1881 		 Timestep : 22100 		 Average Reward : 1.24
Episode : 1892 		 Timestep : 22200 		 Average Reward : 0.97
Episode : 1901 		 Timestep : 22300 		 Average Reward : 0.87
Episode : 1911 		 Timestep : 22400 		 Average Reward : 2.16
Episode : 1924 		 Timestep : 22500 		 Average Reward : 1.61
Episode : 1935 		 Timestep : 22600 		 Average Reward : 2.94
Episode : 1951 		 Timestep : 22700 		 Average Reward : 1.2
Episode :

### Inference

In [37]:
# Rebuild the environment for inference with the same settings used in training.
env = Chat(
    llm=model,
    tokenizer=tokenizer,
    reward_pipe=feedback_pipe,
    obs_base=obs_base,
    max_gen_len=max_ep_len,
)

# Fresh actor-critic wrapper with explicit sampling settings for generation.
policy = ActorCritic(
    model,
    temperature=1,
    top_p=1,
    top_k=0,
).to(device)

In [38]:
# Wrap the inference environment and policy in a Predictor.
predictor = Predictor(env, policy)
# Load the model stored at checkpoint_path (presumably the checkpoint written
# by trainer during training — confirm against utils.Predictor.load_model).
predictor.load_model(checkpoint_path)

In [40]:
# Generate from the evaluation prompt defined in the hyperparameter cell, so
# the inference prompt stays in sync with the one reported during training.
predictor.predict(eval_obs)

'I think Tesla is bad enough of it …<|endoftext|>'