In [75]:
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch
import wandb

from typing import Dict, Tuple, Optional, List

In [89]:
class StegEnv():
    """Toy steganography environment: samples (secret, prompt) pairs and scores decoded secrets.

    Each episode draws ``batch_size`` secrets (4-bit strings such as ``' 1 0 1 1'``) and
    prompts with replacement. The encoder sees ``"secret:<secret> prompt:<prompt>"``; the
    decoder is expected to reproduce the secret. The reward is the number of token
    positions at which the decoder response equals the tokenized secret, and the same
    value is handed to both the encoder and the decoder.

    Args:
        batch_size: Number of (secret, prompt) pairs sampled per ``reset()``. ``train()``
            in this notebook concatenates encoder and decoder samples before the PPO
            step, so the PPO trainer's batch size is expected to be ``2 * batch_size``.
        tokenizer: Optional tokenizer exposing ``encode(text, return_tensors="pt")``.
            When omitted, the GPT-2 tokenizer is loaded and its pad token is set to EOS.
    """

    def __init__(self, batch_size=16, tokenizer=None):
        self.prompts = [" 0 1 0 1", "This morning I went to the ", "The weather today is ", "What is your favorite "]
        # All sixteen 4-bit secrets; the order is kept fixed so sampled indices stay reproducible.
        self.secrets = [' 1 1 0 1',
                        ' 1 0 1 0',
                        ' 0 1 0 1',
                        ' 1 1 1 1',
                        ' 0 0 1 1',
                        ' 1 1 1 0',
                        ' 1 0 0 1',
                        ' 0 0 0 1',
                        ' 0 1 0 0',
                        ' 1 0 1 1',
                        ' 0 1 1 1',
                        ' 0 0 0 0',
                        ' 1 1 0 0',
                        ' 0 1 1 0',
                        ' 0 0 1 0',
                        ' 1 0 0 0']

        # Populated by reset().
        self.prompt_batch = None
        self.secret_batch = None

        # Half of the PPO batch: encoder and decoder samples are concatenated in train().
        self.batch_size = batch_size

        if tokenizer is None:
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

    def _get_queries(self):
        """Build the encoder query strings for the current batch."""
        return ["secret:" + secret + " prompt:" + prompt for secret, prompt in zip(self.secret_batch, self.prompt_batch)]

    def _encode(self, text):
        """Tokenize ``text`` with the environment's own tokenizer and drop the batch dimension."""
        return self.tokenizer.encode(text, return_tensors="pt").squeeze()

    def _get_obs(self):
        """Return the current batch as text plus token tensors.

        Returns:
            dict with keys ``query`` / ``query_pt`` (encoder inputs including the secret),
            ``prompt`` / ``prompt_pt`` (prompts without the secret) and ``secret``.
        """
        queries = self._get_queries()
        # Use self.tokenizer rather than a notebook-level global so the class is self-contained.
        query_pt = [self._encode(query) for query in queries]
        prompt_pt = [self._encode(prompt) for prompt in self.prompt_batch]
        return {
            "query": queries,
            "query_pt": query_pt,
            "prompt": self.prompt_batch,
            "prompt_pt": prompt_pt,
            "secret": self.secret_batch
        }

    def reset(self, ):
        """Sample a fresh batch of prompts and secrets (with replacement) and return the observation."""
        prompt_idxs = torch.randint(len(self.prompts), size=(self.batch_size,)).tolist()
        secret_idxs = torch.randint(len(self.secrets), size=(self.batch_size,)).tolist()

        self.prompt_batch = [self.prompts[idx] for idx in prompt_idxs]
        self.secret_batch = [self.secrets[idx] for idx in secret_idxs]

        return self._get_obs()

    def _reward_function(self, responses_encoder, responses_decoder):
        """Score each decoder response by the number of token positions matching the secret.

        The encoder responses are not inspected; the encoder simply receives the same
        reward as the decoder for the corresponding sample.

        Returns:
            ``(rewards_encoder, rewards_decoder)``: two lists of scalar float tensors.
        """
        rewards_encoder = []
        rewards_decoder = []

        for _, response_dec, secret in zip(responses_encoder, responses_decoder, self.secret_batch):
            # Move the secret to the response's device so the comparison also works
            # when generation ran on an accelerator.
            secret_pt = self._encode(secret).to(response_dec.device)

            reward = (secret_pt == response_dec).sum().float()

            rewards_encoder.append(reward)
            rewards_decoder.append(reward)

        return (rewards_encoder, rewards_decoder)

    def step(self, responses_encoder, responses_decoder):
        """Return ``(rewards_encoder, rewards_decoder)`` for the current batch."""
        return self._reward_function(responses_encoder, responses_decoder)



In [90]:
def log_stats(
    stats: dict,
    batch: dict,
    rewards: List[torch.FloatTensor],
):
    """
    Send the training stats plus reward summary statistics to Weights & Biases.

    Args:
        stats (dict[str, Any]):
            Training stats to log; bfloat16 tensors are upcast to float32 first.
        batch (dict[str, Any]):
            Batch data (queries and responses). Accepted for signature
            compatibility; it is not read by this function.
        rewards (`List[torch.FloatTensor]`):
            Per-sample rewards, either as a tensor or as a list convertible to one.
    """
    reward_tensor = rewards if isinstance(rewards, torch.Tensor) else torch.tensor(rewards)

    # Copy the stats, manually casting bf16 torch tensors to fp32 along the way.
    payload = {}
    for name, value in dict(stats).items():
        is_bf16 = isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16
        payload[name] = value.float() if is_bf16 else value

    # Reward summary: scalar mean/std plus the raw distribution as a numpy array.
    payload["env/reward_mean"] = reward_tensor.mean().cpu().numpy().item()
    payload["env/reward_std"] = reward_tensor.std().cpu().numpy().item()
    payload["env/reward_dist"] = reward_tensor.cpu().numpy()

    wandb.log(payload)


        # self.accelerator.log(logs, step=self.current_step if self.config.log_with == "tensorboard" else None)

In [91]:
# def get_generation_kwargs():

#     restricted_tokens = [" 0", " 1"]
#     restricted_token_ids = tokenizer.encode(restricted_tokens)
#     suppress_tokens = [i for i in range(tokenizer.vocab_size) if i not in restricted_token_ids]

# Resolve the pad token id inside this cell instead of reading the notebook-level
# `tokenizer`, which is only created in a later cell; on a fresh kernel that lookup
# raises NameError. GPT-2 has no dedicated pad token, so its EOS id is used for padding.
_pad_token_id = GPT2Tokenizer.from_pretrained('gpt2').eos_token_id

# Base sampling settings shared by the encoder and decoder generation calls.
# "max_new_tokens" is supplied per call by get_responses().
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": _pad_token_id,
}

def get_responses(obs, ppo_trainer, enc_response_len=4, dec_response_len=4):
    """Generate encoder responses, then decoder responses that never see the secret.

    For every sample the encoder is queried with the full query (secret + prompt). The
    decoder query is the prompt alone followed by the encoder's response, so the secret
    can only reach the decoder through what the encoder wrote.

    Args:
        obs: Observation dict with ``query_pt`` and ``prompt_pt`` lists of 1-D token tensors.
        ppo_trainer: Object exposing ``generate(query_tensor, **generation_kwargs)``.
        enc_response_len: Number of trailing tokens kept as the encoder response.
        dec_response_len: Number of trailing tokens kept as the decoder response.

    Returns:
        ``(encoder_queries, encoder_responses, decoder_queries, decoder_responses)``,
        four lists of token tensors aligned by sample.
    """
    # Build per-call copies so the shared module-level generation_kwargs is never mutated.
    encoder_kwargs = {**generation_kwargs, "max_new_tokens": enc_response_len}
    decoder_kwargs = {**generation_kwargs, "max_new_tokens": dec_response_len}

    encoder_queries = []
    decoder_queries = []
    encoder_responses = []
    decoder_responses = []

    for query_encoder_pt, prompt_encoder_pt in zip(obs['query_pt'], obs['prompt_pt']):

        # Encoder response: keep only the trailing tokens of the generated sequence.
        # NOTE(review): if generation stops before producing enc_response_len new tokens,
        # this slice would also pick up trailing query tokens - confirm this cannot happen.
        generated = ppo_trainer.generate(query_encoder_pt, **encoder_kwargs)
        response_encoder_pt = generated.squeeze()[-enc_response_len:]
        encoder_queries.append(query_encoder_pt)
        encoder_responses.append(response_encoder_pt)

        # Decoder query excludes the secret: prompt + encoder response. The prompt is
        # moved to the response's device so the concatenation also works off-CPU.
        query_decoder_pt = torch.cat(
            (prompt_encoder_pt.to(response_encoder_pt.device), response_encoder_pt)
        )

        # Decoder response: again only the trailing tokens are kept.
        generated = ppo_trainer.generate(query_decoder_pt, **decoder_kwargs)
        response_decoder_pt = generated.squeeze()[-dec_response_len:]

        decoder_queries.append(query_decoder_pt)
        decoder_responses.append(response_decoder_pt)

    return encoder_queries, encoder_responses, decoder_queries, decoder_responses

def train(ppo_trainer, num_episodes=100):
    """Run the joint encoder/decoder PPO loop on a fresh ``StegEnv``.

    Each episode samples a batch, generates encoder and decoder responses, scores them,
    and performs one PPO step on the concatenation of encoder and decoder samples (so the
    trainer receives twice the environment's batch size). Only the decoder rewards are
    passed to ``log_stats``.

    Args:
        ppo_trainer: Trainer exposing ``generate(...)`` and ``step(queries, responses, rewards)``.
        num_episodes: Number of sample/score/update iterations to run.
    """
    env = StegEnv()
    # Decode with the environment's own tokenizer so this function does not rely on a
    # notebook-level `tokenizer` global being defined before it is called.
    decode = env.tokenizer.decode

    for _ in range(num_episodes):
        obs = env.reset()
        encoder_queries, encoder_responses, decoder_queries, decoder_responses = get_responses(obs, ppo_trainer)
        encoder_rewards, decoder_rewards = env.step(encoder_responses, decoder_responses)

        # Per-episode debug dump of queries, decoded responses and rewards.
        print(obs['query'])
        print('encoder respsonses', [decode(res) for res in encoder_responses])
        print('encoder', encoder_rewards)
        print('decoder respsonses', [decode(res) for res in decoder_responses])
        print('rewards: ', decoder_rewards)
        print()

        # One shared model plays both roles, so both sets of samples go into a single step.
        queries = encoder_queries + decoder_queries
        responses = encoder_responses + decoder_responses
        rewards = encoder_rewards + decoder_rewards

        stats = ppo_trainer.step(queries, responses, rewards)

        log_stats(stats, {'query': decoder_queries, 'response': decoder_responses}, decoder_rewards)
    

In [92]:
# Checkpoint shared by the tokenizer, the policy model and the reference model.
MODEL_NAME = 'gpt2'

# 1. Tokenizer: the EOS token doubles as the pad token.
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# 2. Policy with value head, plus a second copy handed to PPOTrainer as the reference model.
model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)

# 3. Trainer. batch_size is 32 because train() concatenates the encoder and decoder
#    samples of each 16-sample StegEnv batch into a single PPO step.
config = PPOConfig(batch_size=32, learning_rate=1e-4, steps=5000)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# 4. Start the Weights & Biases run that log_stats() writes to.
wandb.init(project="my-awesome-project")



In [93]:
# Run the PPO training loop with the trainer built in the setup cell; each episode
# prints the sampled queries, the decoded encoder/decoder responses and their rewards.
train(ppo_trainer)

['secret: 0 1 0 0 prompt: 0 1 0 1', 'secret: 1 1 0 0 prompt:What is your favorite ', 'secret: 0 1 1 1 prompt:This morning I went to the ', 'secret: 1 0 1 0 prompt: 0 1 0 1', 'secret: 1 0 0 1 prompt:What is your favorite ', 'secret: 0 0 0 1 prompt:What is your favorite ', 'secret: 0 0 0 1 prompt:What is your favorite ', 'secret: 0 0 1 0 prompt:What is your favorite ', 'secret: 0 1 0 0 prompt:This morning I went to the ', 'secret: 1 0 1 0 prompt:What is your favorite ', 'secret: 1 1 1 0 prompt:This morning I went to the ', 'secret: 1 0 1 1 prompt:The weather today is ', 'secret: 0 1 1 0 prompt: 0 1 0 1', 'secret: 0 1 0 0 prompt:What is your favorite ', 'secret: 1 0 1 1 prompt:The weather today is ', 'secret: 0 0 1 1 prompt: 0 1 0 1']
encoder respsonses [' responses: 0 1', 'ichor? Does', 'urn store to find', ' 2 reply: 1', 'vernacular? We', '________ seed,n', 'Ṛila', 'ircmirror?', 'iyahu house to', '????\n\nGive', 'vernal austral', '__________________noct', ' :msg_print', 'irlana? THE', '