In [1]:
import torch
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from gymnasium import Env, spaces
from tiny_shakespeare import TinyShakespeareDataset
import torch.nn.functional as F
import matplotlib.pyplot as plt

from token_level_ppo import ThoughtsFormerPolicy, TokenLevelPPO


def token_batched_reshape_with_offset(x: torch.Tensor, max_seq_length: int, thoughts_taken: int) -> torch.Tensor:
    thoughts = thoughts_taken + 1
    max_thoughts = x.size(1) // max_seq_length
    x = x[:,:max_seq_length*thoughts].view(x.size(0), max_seq_length, thoughts)
    return F.pad(x,(0, (max_thoughts - thoughts)))
        
class ThoughtsFormerEnv(Env):
    def __init__(self, vocab_size, max_sequence_length, max_thought_length):
        super(ThoughtsFormerEnv, self).__init__()
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length 
        self.max_thought_length = max_thought_length
        self.max_context_length = max_sequence_length * (max_thought_length+1)
        
        # Logits
        self.action_space = spaces.Box(low=-100, high=100, shape=(max_sequence_length,vocab_size), dtype=np.float32)

        
        self.observation_space = spaces.Dict({
            "state" : spaces.MultiDiscrete([vocab_size] * self.max_context_length),
            "thought_step" : spaces.Discrete(max_thought_length+1)
        })
        
        self.dataset = TinyShakespeareDataset(max_sequence_length,window_offset=max_sequence_length//4)
        self.dataset_len = len(self.dataset)
        self.dataset_iter = 0
        
        self.thought_step = 0

    def reset(self, seed=None):
        super().reset(seed=seed)  # Ensures Gymnasium's seeding is properly handled
        
        self.state, self.labels = self.dataset[self.dataset_iter]
        self.state = F.pad(self.state, (0,self.max_context_length-self.max_sequence_length))
     
        # prepare self.state and massively elongate
        self.dataset_iter += 1
        if self.dataset == self.dataset_len:
            self.dataset_iter = 0

        self.thought_step = 0
        
        obs = {
            'state' : self.state.numpy(),
            'thought_step' : self.thought_step
        }
        return obs, {}

    def step(self, action):
        self.state = self.state.view(1,-1)
        
        probs = F.softmax(torch.from_numpy(action), dim = -1)
        sampled_tokens = torch.multinomial(
            probs.view(-1, probs.size(-1)),
            num_samples=1
        ).view(-1, probs.size(0))
            
        
        if self.thought_step == self.max_thought_length:
            reward = self.reward(action)
            # print(reward.shape)
            done = True
        else:
            reward = torch.zeros(self.max_sequence_length)
            done = False
            # Add the thought!
            self.state = token_batched_reshape_with_offset(self.state, self.max_sequence_length, self.thought_step) # (batch x max_seq_len, (max_thought_len+1))
            # before
            # plt.imshow(self.state[0,:10]); plt.show()
            
            self.state[:,:,self.thought_step+1] = sampled_tokens
            # plt.imshow(self.state[0,:10]); plt.show()
            self.state =  self.state.view(1,-1)
        
        self.state = self.state.view(-1)
        obs = {
            'state' : self.state.numpy(),
            'thought_step' : self.thought_step
        }
        
        info = {'reward' : reward, 'actions_taken' : sampled_tokens.squeeze(0)} 
        
        # print(info)
        self.thought_step += 1
        return obs, -np.inf, done, False, info #obs, reward, done, truncated, info

    def reward(self, action):
        return -F.cross_entropy(torch.tensor(action), self.labels, reduction='none')
    

env = ThoughtsFormerEnv(vocab_size=50257, max_sequence_length=512,max_thought_length=1)

# ppo = TokenLevelPPO(ThoughtsFormerPolicy, env, n_steps=4, batch_size=2, max_sequence_length=512, verbose=2)


  fn()
Token indices sequence length is longer than the specified maximum sequence length for this model (301966 > 1024). Running this sequence through the model will result in indexing errors


In [2]:
ppo = TokenLevelPPO.load(max_sequence_length=512, path="ppo_thoughtsformer2", env=env, n_steps=4,batch_size=2)

Exception: code() argument 13 must be str, not int
Exception: code() argument 13 must be str, not int


Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




In [46]:
ppo.learn(1)    

KeyboardInterrupt: 

In [30]:
x = TinyShakespeareDataset(512,1024*4)

Token indices sequence length is longer than the specified maximum sequence length for this model (301966 > 1024). Running this sequence through the model will result in indexing errors


In [31]:
obs = {'state' : x[0][0].view(1,-1).to('cuda'), 'thought_step' : 1}
a,b = ppo.policy.forward(obs)

In [39]:
n = F.softmax(a,dim=-1)
n[:,:,n.argmax(dim=-1).flatten()]

tensor([[[0.0055, 0.0048, 0.0055,  ..., 0.0021, 0.0021, 0.0021],
         [0.0047, 0.0052, 0.0047,  ..., 0.0024, 0.0024, 0.0024],
         [0.0053, 0.0051, 0.0053,  ..., 0.0027, 0.0027, 0.0027],
         ...,
         [0.0104, 0.0040, 0.0104,  ..., 0.0233, 0.0233, 0.0233],
         [0.0106, 0.0039, 0.0106,  ..., 0.0228, 0.0228, 0.0228],
         [0.0110, 0.0038, 0.0110,  ..., 0.0236, 0.0236, 0.0236]]],
       device='cuda:0', grad_fn=<IndexBackward0>)

In [43]:
from transformers import GPT2Tokenizer
p = GPT2Tokenizer.from_pretrained("gpt2")



In [45]:
p.decode(x)

','

In [27]:
logger.name_to_value

defaultdict(float, {})

In [2]:
import torch
import torch.nn as nn
from thoughtsformer import ThoughtsFormer, simple_batched_reshape_with_offset
from stable_baselines3.common.policies import ActorCriticPolicy

class _ThoughtsFormerPolicy(ActorCriticPolicy):
    def __init__(self):
        self.model = ThoughtsFormer.from_pretrained_GPT2()
        
    def forward(self, obs: dict):
        assert 'state' in obs and 'thought_step' in obs, f"'state' and 'thought_step' should be keys of the observation dictionary"
        return self.model.forward_ppo_with_tokens(obs['state'], torch.zeros_like(obs['state']), obs['thought_step'])
    
    def _get_action_dist_from_obs(self, obs):
        return self.forward(obs)[0]

    def _get_value_from_obs(self, obs):
        return self.forward(obs)[1]
    
env = ThoughtsFormerEnv()

TypeError: ThoughtsFormerEnv.__init__() missing 3 required positional arguments: 'vocab_size', 'max_sequence_length', and 'max_thought_length'

In [5]:

# Check the environment to ensure compatibility
from token_level_ppo import TokenLevelPPO, TokenLevelRolloutBuffer, ThoughtsFormerPolicy
ppo = TokenLevelPPO(ThoughtsFormerPolicy, env, n_steps=2, batch_size=2, max_sequence_length=512, verbose=2)


state torch.Size([11264])


TypeError: TokenLevelPPO.collect_rollouts() got an unexpected keyword argument 'n_rollout_steps'

In [13]:
import torch.nn.functional as F
x = torch.tensor([1,2,3,4,0,0,0,0,0,0,0,0,0,0,0,0]).view(1,-1)


def simple_batched_reshape_with_offset(x: torch.Tensor, max_seq_length: int, thoughts_taken: int) -> torch.Tensor:
    thoughts = thoughts_taken + 1
    max_thoughts = x.size(1) // max_seq_length
    x = x[:,:max_seq_length*thoughts].view(x.size(0), max_seq_length, thoughts)
    return F.pad(x,(0, (max_thoughts - thoughts)))

simple_batched_reshape_with_offset(x,4,1)

tensor([[[1, 2, 0, 0],
         [3, 4, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])

In [9]:
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from gymnasium import Env, spaces

class CustomEnv(Env):
    def __init__(self):
        super(CustomEnv, self).__init__()
        # Define action and observation space
        self.action_space = spaces.Discrete(4)  # 4 actions (up, right, down, left)
        self.observation_space = spaces.MultiDiscrete([3, 4])  # (y, x) coordinates
        self.state = (2, 0)  # Initial state

    def reset(self, seed=None):
        super().reset(seed=seed)  # Ensures Gymnasium's seeding is properly handled
        """Resets the environment to the initial state."""
        self.state = (2, 0)
        return np.array(self.state), {}

    def step(self, action):
        """Take an action in the environment."""
        y, x = self.state

        if action == 0:  # up
            y += 1
        elif action == 1:  # right
            x += 1
        elif action == 2:  # down
            y -= 1
        elif action == 3:  # left
            x -= 1

        # Keep coordinates within bounds
        y, x = max(0, min(2, y)), max(0, min(3, x))
        self.state = (y, x)

        # Compute reward and done
        reward, done = self.reward(self.state
                                   )
        return np.array(self.state), reward, done, False, {}

    def reward(self, state):
        if state == (0, 3):
            return 1, True  # Goal state with reward
        elif state == (1, 3):
            return -1, True  # Failure state with negative reward
        else:
            return 0, False  # No reward, episode continues

# Check the environment to ensure compatibility
env = CustomEnv()
check_env(env)

# Create the PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=1000)

# Test the trained model
obs = env.reset()[0]
done = False
while not done:
    action, _states = model.predict(obs)
    obs, reward, done, _, info = env.step(action)
    print(f"State: {obs}, Reward: {reward}")

env.close()


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 31.4     |
|    ep_rew_mean     | -0.292   |
| time/              |          |
|    fps             | 327      |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 2048     |
---------------------------------
State: [2 1], Reward: 0
State: [1 1], Reward: 0
State: [1 2], Reward: 0
State: [1 1], Reward: 0
State: [1 0], Reward: 0
State: [0 0], Reward: 0
State: [0 1], Reward: 0
State: [0 1], Reward: 0
State: [0 1], Reward: 0
State: [0 0], Reward: 0
State: [0 0], Reward: 0
State: [0 0], Reward: 0
State: [0 0], Reward: 0
State: [0 0], Reward: 0
State: [0 1], Reward: 0
State: [0 1], Reward: 0
State: [0 2], Reward: 0
State: [0 1], Reward: 0
State: [0 2], Reward: 0
State: [0 1], Reward: 0
State: [1 1], Reward: 0
State: [1 2], Reward: 0
State: [0 2], Reward: 0
State: [1 2],

In [110]:
value = np.zeros((3,4))
policy = np.zeros((3,4))

In [None]:
import numpy as np
from enum import Enum
class Actions(Enum):
    up = 0,
    right = 1,
    down = 2,
    left  = 3,
    none = -1

class env():
    
    actions = (Actions.up, Actions.right, Actions.down, Actions.left)
    states = (
            ((0,0), (0,1), (0,2), (0,3)),
            ((1,0), (1,1), (1,2), (1,3)),
            ((2,0), (2,1), (2,2), (2,3))
    )
    
    iter_states = (
            (0,0), (0,1), (0,2), (0,3),
            (1,0), (1,1), (1,2), (1,3),
            (2,0), (2,1), (2,2), (2,3)
    )
    
    def __init__(self):
        pass
    def step(self, state, action):
        y, x = state
        
        if action == Actions.up:
            y += 1
        if action == Actions.right:
            x += 1
        if action == Actions.down:
            y -= 1
        if action == Actions.left:
            x -= 1

        y, x = max(y,0), max(x,0)
        y, x = min(y,2), min(x,3)
        state_next = (y,x)
        reward, done = self.reward((y,x))
        return state_next, reward, done
        
    def reward(self, next_state):
        if next_state == (0,3):
            return 1, True
        elif next_state == (1,3):
            return -1, True
        else:
            return 0, False

In [113]:
H = 3
gamma = 0.9
e = env()
for i in range(H):
    for state in env.iter_states:
        
        if state == (0, 3) or state == (1, 3):
            continue
        
        if state == (1, 1):
            continue
        
        max_reward = 0
        p = Actions.none.value
        for idx, action in enumerate(env.actions):
            n_state, rew = e.step(state, action)
            n = rew + gamma * value[n_state]
            if n > max_reward:
                max_reward = n
                p = action.value[0]
        policy[state] = p
        value[state] = max_reward

In [114]:
value

array([[0.81  , 0.9   , 1.    , 0.    ],
       [0.729 , 0.    , 0.9   , 0.    ],
       [0.6561, 0.729 , 0.81  , 0.729 ]])

In [104]:
policy

class Actions(Enum):
    down = 0,
    right = 1,
    up = 2,
    left  = 3,
    none = -1


array([[1., 1., 1., 0.],
       [2., 0., 2., 0.],
       [1., 1., 2., 3.]])