In [45]:
import torch
from transformers import AutoTokenizer, AutoModel
from trl_custom import PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
import wandb
from typing import Dict, Tuple, Optional, List

class StegEnv():
    """Batched toy environment for a key-passing task between an encoder and a decoder.

    Each episode samples one random key (``key_length`` digit tokens) per batch
    row, wraps it into an encoder query, and later scores the decoder's
    response by counting how many of the first ``key_length`` tokens equal the
    sampled key.
    """

    def __init__(self,
                 tokenizer: "AutoTokenizer",
                 batch_size: int = 16,
                 device: str = 'cpu',
                 trainer=None,
                 model_ref=None,
                 ):
        """
        Args:
            tokenizer: callable returning a mapping with an ``'input_ids'``
                tensor when called with ``return_tensors='pt', padding=True``.
            batch_size: number of (key, prompt) pairs sampled per episode.
            device: device every pre-tokenized tensor is moved to.
            trainer: optional back-reference to the owning trainer. It is only
                stored; nothing in this class reads it at the moment. It must
                have a default because it follows defaulted parameters.
            model_ref: optional reference model. It is only stored; accepted so
                that callers passing ``model_ref=...`` keep working.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        self.trainer = trainer
        self.model_ref = model_ref

        self.key_length = 1
        self.enc_response_len = 4
        self.dec_response_len = max(self.key_length, 4) # must be >= 4 for now because of ppo_train

        if self.key_length < self.dec_response_len:
            print("Warning: key length is less than decoder response length. Reward will only use first [key_length] tokens.")

        self.key_tokens = [' 0', ' 1', ' 2', ' 3', ' 4', ' 5', ' 6', ' 7', ' 8', ' 9']
        self.prompts = [" 0 7 3 8 4", "Yesterday I went to ", "The weather today is ", "What is your favorite "]

        # Prompts are padded to a common length by the tokenizer; key tokens
        # are flattened to one id per string (assumes each key string maps to
        # exactly one token -- TODO confirm for other tokenizers).
        self.prompts_pt = self.tokenize_batch(self.prompts)
        self.key_tokens_pt = self.tokenize_batch(self.key_tokens, squeeze=True)

        # Filled in by reset(); kept as None so _get_obs() is safe before that.
        self.prompt_batch = None
        self.key_batch = None
        self.query_batch = None

        key_buff = "Here is the key:"
        prompt_buff = " prompt:"
        encode_buff = ". Now repeat the key:"
        message_buff = "Message:"
        decode_buff = ". Now extract the key:"

        # Fixed text fragments, tokenized once and tiled to (batch_size, n_tokens)
        # so they can be concatenated directly with per-episode tensors.
        self.buffs = dict()

        self.buffs["key"] = self.tokenize_batch(key_buff, repeat=True)
        self.buffs["prompt"] = self.tokenize_batch(prompt_buff, repeat=True)
        self.buffs["encode"] = self.tokenize_batch(encode_buff, repeat=True)
        self.buffs["message"] = self.tokenize_batch(message_buff, repeat=True)
        self.buffs["decode"] = self.tokenize_batch(decode_buff, repeat=True)

    def tokenize_batch(self, string, repeat=False, squeeze=False):
        """Tokenize a string (or list of strings) and return the input ids on ``self.device``.

        Args:
            string: a single string or a list of strings.
            repeat: tile the result ``batch_size`` times along dim 0.
            squeeze: drop the trailing token axis when it has size 1.
        """
        tokenized = self.tokenizer(string, return_tensors='pt', padding=True)['input_ids']
        if repeat:
            tokenized = tokenized.repeat(self.batch_size, 1)
        if squeeze:
            # Only the last axis is squeezed: a bare squeeze() would also
            # collapse the batch axis when a single string is passed.
            tokenized = tokenized.squeeze(-1)
        return tokenized.to(self.device)

    def _get_encoder_query(self, key, prompt):
        """
        Builds encoder queries as [key buffer][key][encode buffer], i.e.
        "Here is the key:<key>. Now repeat the key:".

        ``prompt`` is accepted but currently not part of the query.
        """
        return torch.cat((self.buffs["key"], key, self.buffs["encode"]), dim=-1) #self.buffs["encode"]
        # return torch.cat((self.key_buff_batch, key, prompt_buff_batch, self.prompt_batch), dim=-1)

    def _get_decoder_query(self, prompt, response):
        """
        Builds decoder queries as [message buffer][response prefix][decode buffer].

        Only the first ``key_length`` tokens of ``response`` are forwarded;
        ``prompt`` is accepted but currently unused.
        """
        #return torch.cat((response, self.key_buff_batch), dim=-1)
        return torch.cat((self.buffs["message"], response[:, :self.key_length], self.buffs["decode"]), dim=-1)

    def _get_obs(self):
        """Return the current episode's query, prompt and key batches as a dict."""
        return {
            "query": self.query_batch,
            "prompt": self.prompt_batch,
            "key": self.key_batch
        }

    def reset(self, ):
        """Sample a fresh batch of prompts and keys and return the observation dict."""
        prompt_idxs = torch.randint(len(self.prompts_pt), size=(self.batch_size,))
        key_idxs = torch.randint(len(self.key_tokens_pt), size=(self.batch_size, self.key_length))

        self.prompt_batch = self.prompts_pt[prompt_idxs]
        self.key_batch = self.key_tokens_pt[key_idxs]
        print('key batch', self.key_batch.shape)
        self.query_batch = self._get_encoder_query(self.key_batch, self.prompt_batch)

        obs = self._get_obs()

        return obs

    def _reward_function(self, enc_query, enc_response, dec_query, dec_response):
        """
        Score the decoder: one point per position (within the first
        ``key_length`` tokens) where the response equals the sampled key.

        The encoder and decoder share the same reward tensor.
        """
        dec_response = dec_response[:, :self.key_length] # only use first [key_length] tokens
        print("decoder_response: ", dec_response)
        print("self key batch: ", self.key_batch)
        #reward_encoder = self.trainer.model_ref(enc_query, enc_response)
        # ref_logprobs, _, _, _ = self.batched_forward_pass(self.trainer.ref_model, queries, responses, model_inputs)
        reward_encoder = reward_decoder = (dec_response == self.key_batch).sum(dim=-1).float()
        return (reward_encoder, reward_decoder)

    def step(self, enc_query, enc_response, dec_query, dec_response):
        """Return ``(encoder_reward, decoder_reward)`` for the current episode."""
        reward = self._reward_function(enc_query, enc_response, dec_query, dec_response)
        return reward


class StegPPOTrainer():
    """Drives PPO training of one model that plays both encoder and decoder in ``StegEnv``.

    Per episode: reset the environment, generate an encoder response, build
    the decoder query from it, generate a decoder response, compute rewards,
    and hand everything to the wrapped ``PPOTrainer``.
    """

    def __init__(self,
                 config: dict,
                 model: "AutoModelForCausalLMWithValueHead",
                 model_ref: "AutoModelForCausalLMWithValueHead",
                 tokenizer: "AutoTokenizer",
                 ):
        """
        Args:
            config: dict with the keys ``'batch_size'``, ``'episodes'``,
                ``'device'``, ``'multi_agent'``, ``'learning_rate'`` and
                ``'steps'``.
            model: policy model used for generation and PPO updates.
            model_ref: reference model handed to ``PPOTrainer``.
            tokenizer: tokenizer shared by the environment and the PPO trainer.
        """
        self.model = model
        self.model_ref = model_ref
        self.tokenizer = tokenizer

        batch_size = config['batch_size']
        self.episodes = config['episodes']
        self.device = config['device']
        self.multi_agent = config['multi_agent']

        # initialize environment
        # Use the values read from `config`, not same-named notebook globals.
        self.env = StegEnv(
            tokenizer=self.tokenizer,
            batch_size=batch_size,
            device=self.device,
            trainer=self,
        )

        self.enc_gen_kwargs = {
            "min_length": -1,
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": self.env.enc_response_len
        }

        # get tokens to suppress: every vocab id that is not a key token, so
        # the decoder can only emit key tokens. A set keeps the membership
        # test O(1) instead of one tensor comparison per vocab id.
        allowed_ids = set(self.env.key_tokens_pt.flatten().tolist())
        self.suppress_tokens = [i for i in range(self.tokenizer.vocab_size) if i not in allowed_ids]

        self.dec_gen_kwargs = {
            **self.enc_gen_kwargs,
            "suppress_tokens": self.suppress_tokens,
            "max_new_tokens": self.env.dec_response_len,
        }

        ppo_config = PPOConfig(
            batch_size=batch_size * 2 if self.multi_agent else batch_size, # double for encoder + decoder responses
            learning_rate=config['learning_rate'],
            steps=config['steps'],
        )

        # NOTE: `PPOTrainer` is not imported next to this class; it is looked
        # up in the module/notebook globals when __init__ runs.
        self.ppo_trainer = PPOTrainer(ppo_config, self.model, self.model_ref, self.tokenizer)

    def log_stats(
        self,
        stats: dict,
        rewards: List[torch.FloatTensor],
    ):
        """
        Collect the training stats for one episode. Call it at the end of each epoch.

        Args:
            stats: stats dict returned by the PPO step.
            rewards: reward tensor, or a list of scalar rewards.

        Returns:
            dict with ``stats`` (bf16 tensors cast to fp32) plus
            ``env/reward_mean``, ``env/reward_std`` and ``env/reward_dist``.
            Sending it to wandb is currently disabled.
        """
        logs = {}

        # Log stats
        if not isinstance(rewards, torch.Tensor):
            # float() accepts both plain numbers and scalar tensors.
            rewards = torch.tensor([float(r) for r in rewards]).to(self.device)

        logs.update(stats)

        # manually cast in fp32 for bf16 torch tensors
        for k, v in logs.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
                logs[k] = v.float()

        logs["env/reward_mean"] = torch.mean(rewards).item()
        logs["env/reward_std"] = torch.std(rewards).item()
        logs["env/reward_dist"] = rewards.cpu().numpy()

        #wandb.log(logs)
        return logs

    def get_model_responses(self, obs):
        """
        Generate the encoder response, then the decoder response conditioned on it.

        Returns:
            ``(encoder_query, encoder_response, decoder_query, decoder_response)``;
            each response keeps only the last ``max_new_tokens`` positions of
            the generated sequence.
        """
        encoder_query = obs['query']
        encoder_response = self.model.generate(encoder_query, **self.enc_gen_kwargs) # should this be ppo_trainer.generate????
        encoder_response = encoder_response[:, -self.enc_gen_kwargs["max_new_tokens"]:]

        decoder_query = self.env._get_decoder_query(obs['prompt'], encoder_response)
        decoder_response = self.model.generate(decoder_query, **self.dec_gen_kwargs)
        decoder_response = decoder_response[:, -self.dec_gen_kwargs["max_new_tokens"]:]

        return encoder_query, encoder_response, decoder_query, decoder_response

    def train(self):
        """Run ``self.episodes`` episodes of rollout, reward computation and PPO update."""
        for _ in range(self.episodes):
            obs = self.env.reset()
            enc_query, enc_response, dec_query, dec_response = self.get_model_responses(obs)
            enc_reward, dec_reward = self.env.step(enc_query, enc_response, dec_query, dec_response)

            print('-----------------------------------------------------------------------')
            print('prompt, keys:')
            print(self.tokenizer.batch_decode(obs['prompt']))
            print(self.tokenizer.batch_decode(obs['key']))
            print('\nencoder:')
            print(self.tokenizer.batch_decode(obs['query']))
            print(self.tokenizer.batch_decode(enc_response))
            print(enc_reward)
            print('\ndecoder:')
            print(self.tokenizer.batch_decode(dec_query))
            print(self.tokenizer.batch_decode(dec_response))
            print(dec_reward)
            print()

            if self.multi_agent:
                # Encoder and decoder samples are stacked into one PPO batch.
                query = list(enc_query) + list(dec_query)
                response = list(enc_response) + list(dec_response)
                reward = list(enc_reward) + list(dec_reward)

                stats = self.ppo_trainer.step(query, response, reward)

            else:
                # This branch passes encoder and decoder samples as separate
                # arguments -- presumably the trl_custom step signature; verify.
                stats = self.ppo_trainer.step(list(enc_query), list(enc_response), list(enc_reward), list(dec_query), list(dec_response), list(dec_reward))

            self.log_stats(stats, dec_reward)
        

In [48]:
# Inspect the concrete class of the reference model.
type(model_ref)

trl_custom.models.modeling_value_head.AutoModelForCausalLMWithValueHead

In [2]:
# Selects which PPOTrainer implementation gets imported below.
multi_agent = False

if not multi_agent:
    from trl_custom import PPOTrainer
else:
    from trl import PPOTrainer

# CUDA device index 0 when a GPU is available, otherwise the CPU.
device = 'cpu' if not torch.cuda.is_available() else 0

In [42]:
# Tokenizer first; it has no pad token assigned here, so EOS is reused for padding.
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
tokenizer.pad_token = tokenizer.eos_token

# Policy model and a separately loaded reference model, both on `device`.
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2-large').to(device)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2-large').to(device)

In [46]:
#wandb.init(project="my-awesome-project")

# Hyperparameters consumed by StegPPOTrainer.__init__.
config = dict(
    batch_size=16,
    learning_rate=1e-6,
    steps=48,
    episodes=1000,
    device=device,
    multi_agent=multi_agent,
)

steg_trainer = StegPPOTrainer(
    config,
    model,
    model_ref,
    tokenizer,
)



In [47]:
# Run the training loop; every episode prints the sampled keys, queries, responses and rewards.
steg_trainer.train()

key batch torch.Size([16, 1])
decoder_response:  tensor([[657],
        [657],
        [513],
        [767],
        [362],
        [352],
        [362],
        [352],
        [642],
        [642],
        [657],
        [767],
        [352],
        [657],
        [604],
        [657]], device='cuda:0')
self key batch:  tensor([[718],
        [807],
        [604],
        [513],
        [718],
        [860],
        [352],
        [604],
        [604],
        [352],
        [860],
        [860],
        [860],
        [657],
        [352],
        [362]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', 'What is your favorite ', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 'The weath



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [362],
        [352],
        [352],
        [807],
        [657],
        [352],
        [362],
        [604],
        [362],
        [352],
        [513],
        [642],
        [860],
        [352],
        [352]], device='cuda:0')
self key batch:  tensor([[642],
        [657],
        [807],
        [362],
        [807],
        [362],
        [767],
        [642],
        [657],
        [657],
        [807],
        [352],
        [860],
        [718],
        [352],
        [657]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['The weather today is ', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', 'Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 'What is your favorite ', 'The weather today is ', 'What is you



key batch torch.Size([16, 1])
decoder_response:  tensor([[604],
        [352],
        [657],
        [807],
        [352],
        [718],
        [657],
        [642],
        [352],
        [657],
        [657],
        [362],
        [352],
        [807],
        [513],
        [657]], device='cuda:0')
self key batch:  tensor([[657],
        [718],
        [767],
        [718],
        [860],
        [352],
        [718],
        [604],
        [352],
        [657],
        [642],
        [362],
        [642],
        [807],
        [352],
        [807]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', ' 0 7 3 8 4', 'The weather today is ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', 'What is your favorite ', 'What is your favorite ', 



key batch torch.Size([16, 1])
decoder_response:  tensor([[767],
        [657],
        [767],
        [860],
        [657],
        [352],
        [352],
        [657],
        [352],
        [807],
        [657],
        [362],
        [807],
        [362],
        [657],
        [657]], device='cuda:0')
self key batch:  tensor([[657],
        [362],
        [642],
        [352],
        [362],
        [807],
        [362],
        [604],
        [807],
        [513],
        [718],
        [642],
        [807],
        [807],
        [657],
        [642]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['The weather today is ', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'The weather today is ', 'What is your favorite ', 'Yesterday I went to ', 



key batch torch.Size([16, 1])
decoder_response:  tensor([[657],
        [657],
        [513],
        [352],
        [513],
        [807],
        [352],
        [642],
        [642],
        [352],
        [352],
        [513],
        [352],
        [657],
        [513],
        [718]], device='cuda:0')
self key batch:  tensor([[767],
        [513],
        [807],
        [657],
        [767],
        [642],
        [807],
        [362],
        [604],
        [767],
        [513],
        [362],
        [642],
        [642],
        [513],
        [860]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', ' 0 7 3 8 4', ' 0 7 3 8 4', ' 0 7 3 8 4', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', 'Yesterday I went to ', 'Wh



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [657],
        [642],
        [352],
        [604],
        [604],
        [352],
        [352],
        [657],
        [642],
        [352],
        [352],
        [860],
        [513],
        [352],
        [352]], device='cuda:0')
self key batch:  tensor([[718],
        [642],
        [352],
        [513],
        [642],
        [642],
        [513],
        [718],
        [362],
        [767],
        [642],
        [767],
        [767],
        [513],
        [352],
        [642]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'The weather today is ', 'The weather today is ', 'What is your favorite ', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ',



key batch torch.Size([16, 1])
decoder_response:  tensor([[362],
        [352],
        [604],
        [807],
        [352],
        [352],
        [513],
        [352],
        [604],
        [352],
        [604],
        [352],
        [352],
        [352],
        [657],
        [604]], device='cuda:0')
self key batch:  tensor([[642],
        [604],
        [657],
        [604],
        [807],
        [642],
        [767],
        [767],
        [718],
        [860],
        [362],
        [642],
        [718],
        [657],
        [657],
        [604]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'The weather today is ', 'The weather today is ', ' 0 7 3 8 4', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', 'Yesterday I went to ', 'The weather t



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [657],
        [352],
        [352],
        [657],
        [352],
        [513],
        [657],
        [352],
        [657],
        [352],
        [657],
        [657],
        [604],
        [718],
        [657]], device='cuda:0')
self key batch:  tensor([[604],
        [767],
        [807],
        [604],
        [767],
        [513],
        [352],
        [642],
        [513],
        [657],
        [642],
        [513],
        [513],
        [657],
        [807],
        [767]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'Yesterday I went to ', 'Yesterday I went to ', 'The weather today is ', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', 'What is your favorite ', 'The weather today is ', 'Th



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [513],
        [604],
        [718],
        [362],
        [657],
        [767],
        [352],
        [657],
        [657],
        [807],
        [352],
        [513],
        [657],
        [352],
        [352]], device='cuda:0')
self key batch:  tensor([[642],
        [604],
        [362],
        [642],
        [807],
        [513],
        [807],
        [657],
        [352],
        [362],
        [657],
        [362],
        [362],
        [807],
        [718],
        [604]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['What is your favorite ', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went 



key batch torch.Size([16, 1])
decoder_response:  tensor([[657],
        [657],
        [807],
        [352],
        [657],
        [362],
        [362],
        [657],
        [352],
        [352],
        [604],
        [352],
        [657],
        [352],
        [352],
        [860]], device='cuda:0')
self key batch:  tensor([[513],
        [807],
        [352],
        [604],
        [767],
        [513],
        [718],
        [362],
        [657],
        [352],
        [513],
        [513],
        [718],
        [362],
        [860],
        [718]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'The weather today is ', ' 0 7 3 8 4', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', 'Yesterday I 



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [604],
        [352],
        [657],
        [362],
        [657],
        [807],
        [352],
        [657],
        [657],
        [657],
        [604],
        [657],
        [513],
        [513],
        [352]], device='cuda:0')
self key batch:  tensor([[767],
        [657],
        [657],
        [604],
        [604],
        [718],
        [513],
        [352],
        [718],
        [807],
        [604],
        [604],
        [642],
        [767],
        [807],
        [657]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', ' 0 7 3 8 4']
[' 7', 



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [352],
        [767],
        [657],
        [642],
        [657],
        [642],
        [767],
        [767],
        [657],
        [352],
        [657],
        [352],
        [352],
        [718],
        [604]], device='cuda:0')
self key batch:  tensor([[807],
        [362],
        [718],
        [604],
        [362],
        [513],
        [807],
        [642],
        [718],
        [657],
        [642],
        [767],
        [352],
        [513],
        [807],
        [604]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', 'The weather today is ', 'Yesterday I w



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [513],
        [657],
        [642],
        [352],
        [657],
        [657],
        [657],
        [352],
        [657],
        [362],
        [352],
        [513],
        [352],
        [604],
        [657]], device='cuda:0')
self key batch:  tensor([[642],
        [362],
        [807],
        [657],
        [657],
        [807],
        [860],
        [767],
        [513],
        [513],
        [642],
        [860],
        [513],
        [352],
        [362],
        [362]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['The weather today is ', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', ' 0 7 3 8 4', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'Yesterday I went to ', ' 0 7 3 8 4', '



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [657],
        [642],
        [362],
        [718],
        [657],
        [642],
        [657],
        [352],
        [718],
        [352],
        [604],
        [642],
        [657],
        [362],
        [767]], device='cuda:0')
self key batch:  tensor([[362],
        [767],
        [657],
        [767],
        [352],
        [657],
        [767],
        [362],
        [807],
        [657],
        [642],
        [718],
        [513],
        [657],
        [767],
        [513]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['What is your favorite ', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'What is your favorite ', 'The weather today is '



key batch torch.Size([16, 1])
decoder_response:  tensor([[767],
        [352],
        [352],
        [352],
        [352],
        [642],
        [352],
        [362],
        [807],
        [807],
        [362],
        [362],
        [657],
        [657],
        [657],
        [352]], device='cuda:0')
self key batch:  tensor([[657],
        [604],
        [642],
        [807],
        [657],
        [513],
        [657],
        [604],
        [718],
        [513],
        [718],
        [513],
        [718],
        [718],
        [362],
        [718]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'Yesterday I went to ', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', 'What is your favori



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [352],
        [657],
        [513],
        [657],
        [657],
        [352],
        [604],
        [352],
        [657],
        [513],
        [352],
        [657],
        [513],
        [657],
        [352]], device='cuda:0')
self key batch:  tensor([[362],
        [860],
        [718],
        [657],
        [352],
        [718],
        [718],
        [604],
        [767],
        [513],
        [718],
        [604],
        [604],
        [513],
        [718],
        [807]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', 'The weather today is ', 'The weather today is ', ' 0 7 3 8 4', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'The weather tod



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [604],
        [352],
        [642],
        [657],
        [657],
        [362],
        [657],
        [352],
        [352],
        [807],
        [352],
        [352],
        [657],
        [352],
        [362]], device='cuda:0')
self key batch:  tensor([[767],
        [513],
        [642],
        [513],
        [352],
        [513],
        [718],
        [657],
        [657],
        [352],
        [513],
        [860],
        [767],
        [657],
        [657],
        [513]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'What is your favorite ', 'The weather today is ', 'The weather today is ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'The weather today is ', 'Yesterday I went to ', 'Ye



key batch torch.Size([16, 1])
decoder_response:  tensor([[604],
        [362],
        [513],
        [513],
        [352],
        [642],
        [657],
        [657],
        [352],
        [657],
        [718],
        [352],
        [352],
        [657],
        [352],
        [767]], device='cuda:0')
self key batch:  tensor([[718],
        [362],
        [513],
        [352],
        [642],
        [604],
        [718],
        [860],
        [767],
        [362],
        [807],
        [657],
        [642],
        [860],
        [604],
        [767]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'The weather today is ', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', 'The w



key batch torch.Size([16, 1])
decoder_response:  tensor([[657],
        [352],
        [657],
        [657],
        [657],
        [642],
        [657],
        [657],
        [513],
        [642],
        [352],
        [352],
        [657],
        [657],
        [362],
        [642]], device='cuda:0')
self key batch:  tensor([[718],
        [352],
        [362],
        [362],
        [642],
        [513],
        [657],
        [718],
        [860],
        [604],
        [767],
        [362],
        [642],
        [657],
        [767],
        [642]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['The weather today is ', 'The weather today is ', ' 0 7 3 8 4', 'The weather today is ', 'What is your favorite ', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'The weather today is ', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4',



key batch torch.Size([16, 1])
decoder_response:  tensor([[718],
        [657],
        [352],
        [352],
        [767],
        [807],
        [352],
        [352],
        [513],
        [513],
        [352],
        [657],
        [352],
        [352],
        [657],
        [352]], device='cuda:0')
self key batch:  tensor([[807],
        [642],
        [718],
        [604],
        [807],
        [513],
        [767],
        [352],
        [657],
        [657],
        [352],
        [604],
        [642],
        [513],
        [362],
        [362]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['The weather today is ', 'What is your favorite ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'The weather today is ', 'The weather today is ', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'Yesterday I went to ', 'Yesterday I went to ', ' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', 'The weather today is ']




key batch torch.Size([16, 1])
decoder_response:  tensor([[657],
        [718],
        [642],
        [657],
        [642],
        [513],
        [352],
        [767],
        [352],
        [362],
        [352],
        [352],
        [513],
        [657],
        [657],
        [352]], device='cuda:0')
self key batch:  tensor([[718],
        [513],
        [718],
        [604],
        [657],
        [642],
        [513],
        [718],
        [513],
        [604],
        [513],
        [807],
        [860],
        [767],
        [718],
        [352]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
['What is your favorite ', 'Yesterday I went to ', 'The weather today is ', ' 0 7 3 8 4', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', 'What is your favorite ', 'Yesterday I went to ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'The weather t



key batch torch.Size([16, 1])
decoder_response:  tensor([[352],
        [352],
        [807],
        [657],
        [352],
        [352],
        [657],
        [513],
        [657],
        [352],
        [807],
        [362],
        [657],
        [513],
        [352],
        [657]], device='cuda:0')
self key batch:  tensor([[642],
        [362],
        [807],
        [767],
        [513],
        [718],
        [807],
        [807],
        [767],
        [860],
        [604],
        [362],
        [807],
        [807],
        [657],
        [718]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'Yesterday I went to ', 'Yesterday I went to ', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', 'Yesterday I went to ', 'What is your favorite ', 'Yesterday I went to ', 'What is your favorite ', 'What is your favorite ', 'The weather today is ', ' 0 7 3 8 4', 'What is your favorite ', 



key batch torch.Size([16, 1])
decoder_response:  tensor([[718],
        [352],
        [767],
        [718],
        [860],
        [657],
        [657],
        [642],
        [352],
        [807],
        [604],
        [604],
        [352],
        [352],
        [362],
        [513]], device='cuda:0')
self key batch:  tensor([[513],
        [657],
        [718],
        [860],
        [604],
        [718],
        [860],
        [718],
        [860],
        [642],
        [657],
        [604],
        [513],
        [352],
        [767],
        [807]], device='cuda:0')
-----------------------------------------------------------------------
prompt, keys:
[' 0 7 3 8 4', 'Yesterday I went to ', 'The weather today is ', 'What is your favorite ', 'The weather today is ', 'What is your favorite ', 'Yesterday I went to ', ' 0 7 3 8 4', 'What is your favorite ', 'What is your favorite ', ' 0 7 3 8 4', 'Yesterday I went to ', ' 0 7 3 8 4', ' 0 7 3 8 4', 'The weather today is ', 'Yesterday



key batch torch.Size([16, 1])



KeyboardInterrupt



In [34]:
# Show the token ids produced for a prompt that ends in a space.
print(steg_trainer.env.tokenizer.encode('The weather today is '))

[464, 6193, 1909, 318, 220]


In [40]:
# Check how many characters token id 220 decodes to on its own.
print(len(str(steg_trainer.env.tokenizer.decode([220]))))

1
