In [3]:
import torch
from transformers import GPT2Tokenizer
from trl import PPOConfig, create_reference_model, AutoModelForCausalLMWithValueHead, PPOTrainer
from trl.core import respond_to_batch

# 1. load a pretrained model
model_encoder = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
#model_decoder = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')

model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# 2. initialize trainer
ppo_config = {'batch_size': 16}
config = PPOConfig(**ppo_config)
ppo_encoder_trainer = PPOTrainer(config, model_encoder, model_ref, tokenizer)
#ppo_decoder_trainer = PPOTrainer(config, model_decoder, model_ref, tokenizer) # remove model_ref?

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [2]:
!export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

In [298]:
import torch
from transformers import GPT2Tokenizer

class StegEnv():
    def __init__(self):
        self.prompts = [" 0 1 0 1", "This morning I went to the ", "The weather today is ", "What is your favorite "]
        self.secrets = [' 1 1 0 1',
                        ' 1 0 1 0',
                        ' 0 1 0 1',
                        ' 1 1 1 1',
                        ' 0 0 1 1',
                        ' 1 1 1 0',
                        ' 1 0 0 1',
                        ' 0 0 0 1',
                        ' 0 1 0 0',
                        ' 1 0 1 1',
                        ' 0 1 1 1',
                        ' 0 0 0 0',
                        ' 1 1 0 0',
                        ' 0 1 1 0',
                        ' 0 0 1 0',
                        ' 1 0 0 0']

        self.prompt_batch = None
        self.secret_batch = None

        self.batch_size = 16 # must be the same as value set in ppo_config

        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def _get_queries(self):
        return ["secret:" + secret + " prompt:" + prompt for secret, prompt in zip(self.secret_batch, self.prompt_batch)]
    
    def _get_obs(self):
        queries = self._get_queries()
        query_pt = [tokenizer.encode(query, return_tensors="pt").squeeze() for query in queries]
        prompt_pt = [tokenizer.encode(prompt, return_tensors="pt").squeeze() for prompt in self.prompt_batch]
        return {
            "query": queries,
            "query_pt": query_pt,
            "prompt": self.prompt_batch,
            "prompt_pt": prompt_pt,
            "secret": self.secret_batch
        }

    def reset(self, ):
        prompt_idxs = torch.randint(len(self.prompts), size=(self.batch_size,))
        secret_idxs = torch.randint(len(self.secrets), size=(self.batch_size,))

        self.prompt_batch = [self.prompts[idx] for idx in prompt_idxs]
        self.secret_batch = [self.secrets[idx] for idx in secret_idxs]

        print(self.prompt_batch, self.secret_batch)

        obs = self._get_obs()

        return obs

    def _reward_function(self, responses_encoder, responses_decoder):

        rewards_encoder = []
        rewards_decoder = []

        for response_enc, response_dec, secret in zip(responses_encoder, responses_decoder, self.secret_batch):
            secret_pt = tokenizer.encode(secret, return_tensors="pt").squeeze()

            # if response_dec_txt == secret:
            #     reward_enc, reward_dec = 1.0, 1.0
            # else:
            #     reward_enc, reward_dec = 0.0, 0.0

            reward_enc = reward_dec = (secret_pt == response_dec).sum()

            rewards_encoder.append(torch.tensor(reward_enc))
            rewards_decoder.append(torch.tensor(reward_dec))

        return (rewards_encoder, rewards_decoder)

    def step(self, responses_encoder, responses_decoder):

        reward = self._reward_function(responses_encoder, responses_decoder)
        return reward



In [310]:
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

def get_responses(obs, ppo_encoder_trainer):

    enc_response_len = 4
    dec_response_len = 4
    
    encoder_queries = obs['query_pt']
    decoder_queries = []
    encoder_responses = []
    decoder_responses = []

    for query_encoder_pt, prompt_encoder_pt in zip(encoder_queries, obs['prompt_pt']):

        # get encoder response
        generation_kwargs["max_new_tokens"] = enc_response_len
        response_encoder_pt = ppo_encoder_trainer.generate(query_encoder_pt, **generation_kwargs)
        response_encoder_pt = response_encoder_pt.squeeze()[-enc_response_len:]
        encoder_responses.append(response_encoder_pt)

        # construct decoder query
        query_decoder_pt = torch.cat((prompt_encoder_pt, response_encoder_pt))
        decoder_queries.append(query_decoder_pt)

        # get decoder response
        generation_kwargs["max_new_tokens"] = dec_response_len
        response_decoder_pt = ppo_encoder_trainer.generate(query_decoder_pt, **generation_kwargs)
        response_decoder_pt = response_decoder_pt.squeeze()[-dec_response_len:]
        
        # response_decoder_txt = tokenizer.decode(response_decoder_pt.squeeze())
        decoder_responses.append(response_decoder_pt)

    return encoder_queries, encoder_responses, decoder_queries, decoder_responses

def train(ppo_encoder_trainer, ppo_decoder_trainer):

    num_episodes = 10
    env = StegEnv()

    for episode in range(num_episodes):
        obs = env.reset()
        encoder_queries, encoder_responses, decoder_queries, decoder_responses = get_responses(obs, ppo_encoder_trainer)
        encoder_rewards, decoder_rewards = env.step(encoder_responses, decoder_responses)

        print([tokenizer.decode(res) for res in encoder_responses])
        print([tokenizer.decode(res) for res in decoder_responses])
        print(encoder_rewards, decoder_rewards)

        stats1 = ppo_encoder_trainer.step(encoder_queries, encoder_responses, encoder_rewards,
                                          decoder_queries, decoder_responses, decoder_rewards)
        #stats2 = ppo_decoder_trainer.step(decoder_queries, decoder_responses, decoder_rewards)
        # log the stats?
    

In [308]:
train(ppo_encoder_trainer, ppo_decoder_trainer)

['This morning I went to the ', ' 0 1 0 1', ' 0 1 0 1', 'The weather today is ', ' 0 1 0 1', ' 0 1 0 1', 'What is your favorite ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', ' 0 1 0 1', ' 0 1 0 1', 'The weather today is ', ' 0 1 0 1', 'The weather today is ', 'What is your favorite '] [' 1 1 1 1', ' 1 0 1 0', ' 0 0 0 1', ' 0 1 1 0', ' 1 1 1 0', ' 1 1 1 0', ' 1 1 0 0', ' 0 1 1 0', ' 0 0 0 1', ' 0 1 1 1', ' 1 0 0 0', ' 0 0 1 0', ' 0 0 0 1', ' 0 0 1 0', ' 1 1 1 1', ' 0 1 1 1']


  rewards_encoder.append(torch.tensor(reward_enc))
  rewards_decoder.append(torch.tensor(reward_dec))


['__________ bathroom and', ' prompting: 0 1', ' phs: 30', '?"\n\n\nJoseph', ' 0 1 from:', ' 1\n\nExp', '_______? Kim:', '??????????????\n', '_________<d', '_______?" response:', ' epDomain: 1', ' response: 0 1', '_____ + value from', ' prompt: 0 0', '__________sible', '?"\n"I']
[' went straight to the', ' 1 0 0 1', ' 1 0 0 2', ' quotes the Norse Among', '1021\n\n', 'anderList Issue numerous', ' Great, to me', 'x25 ghel', '/Def Navy ', '\n\nTerminology', ' 1.750 450', ' 0 0 1 foo', ' _____ + 1', 'F40 5 2', '.\nNow,', ' like sugar."\n']
[tensor(0), tensor(2), tensor(2), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(3), tensor(1), tensor(0), tensor(0), tensor(0)] [tensor(0), tensor(2), tensor(2), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(3), tensor(1), tensor(0), tensor(0), tensor(0)]


In [309]:
results1

{'objective/kl': 1.2134313583374023,
 'objective/kl_dist': array([-2.412953  ,  0.45053107, -0.05035055,  3.8920205 ,  0.49100325,
         0.3752538 ,  0.9358165 ,  1.1888177 ,  0.05746937,  2.751363  ,
         0.720742  ,  0.50512636, -1.9741118 ,  0.06191796,  2.7339253 ,
         9.688329  ], dtype=float32),
 'objective/logprobs': array([[-4.67308617e+00, -7.71740484e+00, -8.19761276e+00,
         -2.61621785e+00, -1.27646339e+00, -1.36304712e+01,
         -5.95183671e-01, -1.00772238e+01, -7.56769896e+00,
         -2.14764166e+00, -3.47605848e+00, -3.39988410e-01,
         -1.33591378e+00, -9.00196838e+00, -3.23415303e+00,
         -9.52454627e-01, -7.66464090e+00, -1.38595629e+00],
        [-4.67308617e+00, -7.71740484e+00, -6.61554337e+00,
         -3.67517710e+00, -1.11819017e+00, -1.34183121e+01,
         -5.08503675e-01, -2.75345325e+00, -2.09984326e+00,
         -6.35269523e-01, -1.14638114e+00, -7.99811029e+00,
         -1.53382376e-01, -6.44834399e-01, -4.35148627e-01,
  