In [13]:
import os
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.getcwd(), '..', 'transformers')
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [14]:
import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import logging
import sys

In [15]:
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

checkpoint = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.eval()
model.cuda()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [16]:
# sentiment = pipeline('sentiment-analysis',model="cardiffnlp/twitter-roberta-base-sentiment",tokenizer="cardiffnlp/twitter-roberta-base-sentiment",device=0,return_all_scores=True)
sentiment_path_1 = "cardiffnlp/twitter-roberta-base-sentiment" # 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
sentiment_path_2 = 'nlptown/bert-base-multilingual-uncased-sentiment'
sentiment_1 = pipeline('sentiment-analysis',model=sentiment_path_1,tokenizer=sentiment_path_1, top_k=None)
sentiment_2 = pipeline('sentiment-analysis',model=sentiment_path_2,tokenizer=sentiment_path_2, top_k=None)

In [17]:
def return_score_nlptown(sentence, percent=1):
    scores = sentiment_2(sentence)[0]
    req_score = 0
    for dict_inner in scores:
        if dict_inner['label'] == '1 star':
            req_score += percent * dict_inner['score']
        if dict_inner['label'] == '2 star':
            req_score += (1 - percent) * dict_inner['score']

    return req_score

def return_score_cardiff_roberta_xlm(sentence):
    scores = sentiment_1(sentence)[0]
    req_score = 0
    for dict_inner in scores:
        if dict_inner['label'] == 'negative':
            req_score += 10 * dict_inner['score']
        if dict_inner['label'] == 'positive':
            req_score -= 10 * dict_inner['score']
        if dict_inner['label'] == 'neutral':
            req_score -= 5 * dict_inner['score']

    return req_score

In [18]:
class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish): # predicted will be the list of predicted token
      reward = 0
      if finish or len(predicted_list[0]) >= self.env_max_length:
        predicted_text = tokenizer.convert_tokens_to_string(predicted_list[0])
        # sentiment classifier
        # reward = sentiment(input_item['input']+predicted_text)[0][0]['score'] * 10
        reward = return_score_cardiff_roberta_xlm(input_item['input']+predicted_text)
        reward = min(reward, return_score_nlptown(input_item['input']+predicted_text))
      return [reward]

In [19]:
observation_list = [{"input": "Dogecoin is "}]
env = MyRLEnv(
    model, tokenizer, 
    observation_input=observation_list, 
    max_length=25, 
    compare_sample=1
)
actor = TextRLActor(
    env, model, tokenizer, 
    act_deterministically=False, 
    optimizer='adamw',
    temperature=0.8,
    top_k=100,
    top_p=0.85
)
agent = actor.agent_ppo(
    update_interval=2, 
    minibatch_size=2, 
    epochs=10,
    lr=3e-4
)

In [20]:
test_phrase = "Dogecoin is "
print(test_phrase)
actor.predict({'input': test_phrase})

Dogecoin is 


['!!! !!! And now they have been hacked by a bunch of idiots, who claim to be the most powerful crypto currency,']

In [21]:
train_agent_with_evaluation(
    agent,
    env,
    steps=1000,
    eval_n_steps=None,
    eval_n_episodes=1,       
    train_max_episode_len=100,  
    eval_interval=10,
    outdir='checkpoint', 
)

outdir:checkpoint step:27 episode:0 R:0
statistics:[('average_value', -1.1544421), ('average_entropy', 1.0630808), ('average_value_loss', 0.044745933401300134), ('average_policy_loss', -0.13790720641613008), ('n_updates', 130), ('explained_variance', -29.32518143973233)]
evaluation episode 0 length:27 R:0
The best score is updated -3.4028235e+38 -> 0.0
Saved the agent to checkpoint/best
outdir:checkpoint step:54 episode:1 R:0
statistics:[('average_value', -0.55661476), ('average_entropy', 0.554348), ('average_value_loss', 3.0574019629741045e-05), ('average_policy_loss', -5.960464477539063e-09), ('n_updates', 270), ('explained_variance', nan)]
evaluation episode 0 length:27 R:0
outdir:checkpoint step:81 episode:2 R:0
statistics:[('average_value', -0.45449057), ('average_entropy', 0.66024405), ('average_value_loss', 0.007434408800691017), ('average_policy_loss', -0.1364479586482048), ('n_updates', 400), ('explained_variance', -12.121727398388678)]
evaluation episode 0 length:27 R:0
outdi

(<textrl.actor.TextPPO at 0x7ff800073a90>,
 [{'average_value': -1.1544421,
   'average_entropy': 1.0630808,
   'average_value_loss': 0.044745933401300134,
   'average_policy_loss': -0.13790720641613008,
   'n_updates': 130,
   'explained_variance': -29.32518143973233,
   'eval_score': 0.0},
  {'average_value': -0.55661476,
   'average_entropy': 0.554348,
   'average_value_loss': 3.0574019629741045e-05,
   'average_policy_loss': -5.960464477539063e-09,
   'n_updates': 270,
   'explained_variance': nan,
   'eval_score': 0.0},
  {'average_value': -0.45449057,
   'average_entropy': 0.66024405,
   'average_value_loss': 0.007434408800691017,
   'average_policy_loss': -0.1364479586482048,
   'n_updates': 400,
   'explained_variance': -12.121727398388678,
   'eval_score': 0.0},
  {'average_value': -0.36626735,
   'average_entropy': 0.5828616,
   'average_value_loss': 0.002342577498893661,
   'average_policy_loss': -0.0152191624045372,
   'n_updates': 540,
   'explained_variance': nan,
   'eval

In [22]:
agent.load('./checkpoint/best')
print(observation_list[0])
actor.predict(observation_list[0])

{'input': 'Dogecoin is '}


['urchins just just like it just just just just just just just just just just just just just just just just just just just']