# Network Creation

In [68]:
!pip install torch



In [69]:
import torch
import matplotlib.pyplot as plt
from torch.distributions.categorical import Categorical

In [70]:
from torch import nn

class ActorNetwork(nn.Module):
    """LSTM RNN for generating words for wordle solver"""
    def __init__(self, input_size, hidden_size, projection=1, num_layers=5, dropout_probability=0):
        super().__init__()
        self.network = nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout_probability, proj_size=projection, batch_first=True)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        output, (hidden, _) = self.network(x)
        probab = hidden + 1 / 2  # takes output of softmax [-1, 1] and squish to [0, 1]
        return probab 

class CriticNet(nn.Module):
    """Network representing the critic"""
    def __init__(self, in_shape):
        super().__init__()
        self.v_network = nn.Sequential(
            nn.Linear(in_shape, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
            )

    def forward(self, x):
        return self.v_network(x)

# Toy Problem for RNN

In [71]:
net = ActorNetwork(1, 32, 26, 5)
input = torch.tensor([[
    [1],
    [2],
    [3],
    [4],
    [5]
]]).float()
print(input.shape)
net(input)

torch.Size([1, 5, 1])


tensor([[[0.4656, 0.3937, 0.4987, 0.3833, 0.4901, 0.4142, 0.4987, 0.6309,
          0.3831, 0.5823, 0.4060, 0.5631, 0.4866, 0.4684, 0.4633, 0.4279,
          0.4047, 0.3852, 0.5362, 0.5652, 0.4464, 0.7684, 0.5890, 0.3515,
          0.5704, 0.6449]],

        [[0.4841, 0.5214, 0.4661, 0.5200, 0.4924, 0.5212, 0.4898, 0.5571,
          0.4689, 0.5300, 0.4880, 0.5031, 0.5181, 0.4707, 0.4217, 0.5299,
          0.5038, 0.4702, 0.5230, 0.5298, 0.4833, 0.5319, 0.4852, 0.4545,
          0.5095, 0.4580]],

        [[0.4411, 0.4790, 0.4566, 0.5235, 0.4340, 0.5517, 0.4228, 0.4405,
          0.5119, 0.4916, 0.4878, 0.4330, 0.5193, 0.5714, 0.4672, 0.5206,
          0.5194, 0.5121, 0.4927, 0.4860, 0.4905, 0.6324, 0.5537, 0.5567,
          0.4364, 0.4717]],

        [[0.4916, 0.4940, 0.5514, 0.5895, 0.4329, 0.5074, 0.4269, 0.5238,
          0.4813, 0.4677, 0.5173, 0.5753, 0.5110, 0.5180, 0.4847, 0.4367,
          0.5213, 0.4474, 0.5367, 0.4359, 0.5049, 0.5063, 0.4869, 0.5569,
          0.4791, 0.4537]

# Figuring out the environment

In [72]:
import gym
import gym_wordle
wordle = gym.make('Wordle-v0')

In [73]:
obs1 = wordle.reset()
print(obs1)
obs = wordle.step([2, 17, 0, 19, 4])
wordle.render()
print(obs)
wordle.hidden_word

{'board': array([[-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1]]), 'alphabet': array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1])}
###################################################
c r a t e 

a b c d e f g h i j k l m n o p q r s t u v w x y z 
###################################################

({'board': array([[ 1,  0,  1,  0,  1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1]]), 'alphabet': array([ 1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        0, -1,  0, -1, -1, -1, -1, -1, -1])}, 0.0, False, {})


(15, 0, 2, 4, 18)

In [74]:
wordle.step([17, 4, 1, 0, 1])

({'board': array([[ 1,  0,  1,  0,  1],
         [ 0,  1,  0,  1,  0],
         [-1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1]]),
  'alphabet': array([ 1,  0,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          0, -1,  0, -1, -1, -1, -1, -1, -1])},
 0.0,
 False,
 {})

In [75]:
import numpy as np
rnn = ActorNetwork(5, 32, 26, 5)
o = np.array([obs[0]['board']])
o = torch.from_numpy(o).to(torch.float32)
print(o.shape)
print(rnn(o))

torch.Size([1, 6, 5])
tensor([[[0.5853, 0.4456, 0.4240, 0.4744, 0.5212, 0.5089, 0.4545, 0.4641,
          0.4128, 0.4321, 0.5713, 0.5634, 0.5272, 0.6250, 0.5229, 0.4626,
          0.5495, 0.5620, 0.4732, 0.5064, 0.4168, 0.6081, 0.4763, 0.4224,
          0.5129, 0.4923]],

        [[0.4497, 0.5023, 0.5358, 0.4692, 0.4782, 0.4542, 0.5261, 0.5350,
          0.4895, 0.4837, 0.5212, 0.5068, 0.5567, 0.4713, 0.5391, 0.4520,
          0.5167, 0.5069, 0.4726, 0.5571, 0.4997, 0.5072, 0.4409, 0.4582,
          0.5108, 0.5212]],

        [[0.4696, 0.4925, 0.4252, 0.5290, 0.5100, 0.5467, 0.4834, 0.4586,
          0.4226, 0.4837, 0.5285, 0.4993, 0.5134, 0.4231, 0.4727, 0.4910,
          0.4560, 0.5456, 0.4733, 0.5646, 0.5257, 0.4949, 0.4416, 0.4440,
          0.5177, 0.5031]],

        [[0.4664, 0.5043, 0.4721, 0.5029, 0.5818, 0.5272, 0.4602, 0.5432,
          0.5220, 0.5243, 0.5182, 0.5130, 0.5495, 0.4329, 0.5482, 0.4737,
          0.4859, 0.4792, 0.4899, 0.4898, 0.5207, 0.5119, 0.4858, 0.5332,
   

In [76]:
prob = Categorical(rnn(o))
prob.sample()

tensor([[10],
        [18],
        [15],
        [ 1],
        [10]])

# Competing A2C solvers

In [77]:
def train(replay, q_val):
    vals = torch.vstack(replay.vals)
    
    rewards_tensor = torch.tensor(np.asarray(replay.rewards, dtype=np.float32))
    is_terminal_tensor = torch.tensor(np.asarray(replay.dones, dtype=np.int32))
        
    q_vals_tensor = rewards_tensor + discount_factor * q_val * (1 - is_terminal_tensor)
        
    advantage = q_vals_tensor - vals
    
    critic_loss = (advantage ** 2).mean()
    # critic_loss.requires_grad = True
    adam_critic.zero_grad()
    critic_loss.backward(retain_graph=True)
    adam_critic.step()
    
    
    log_probabs_tensor = torch.vstack(replay.log_probs)
    actor_loss = (-log_probabs_tensor*advantage.detach()).mean()
    # actor_loss.requires_grad = True
    adam_actor.zero_grad()
    actor_loss.backward()
    adam_actor.step()

In [78]:
class Advantage_ActorCritic():
    def __init__(self, world: gym.Env, policy_net: ActorNetwork, critic_net: CriticNet,
                 encoder, policy_alpha, critic_alpha, gamma, max_reward):
        # environment info
        self.world = world
        self.encoder = encoder
        self.max_reward = max_reward

        # actor and critic
        self.actor = policy_net
        self.critic = critic_net
        self.error_buffer = list()
        self.policy_optimizer = torch.optim.Adam(policy_net.network.parameters(), lr=policy_alpha)
        self.v_optimizer = torch.optim.Adam(critic_net.v_network.parameters(), lr=critic_alpha)

        # training info
        self.gamma = gamma
        self.episodes = 0

    def train(self, iterations):
        converged = False
        rewards = list()
        recents = torch.zeros(10)
        i = 0
        while not converged:
            r = self.episode()
            rewards.append(r)

            if i < 10:
                recents[i] = r
                i += 1

            else:
                recents.roll(-1, 0)
                recents[9] = r
            
            # convergence check
            if len(rewards) > 10 and recents.mean() >= self.max_reward:
                converged = True
            converged = True if iterations == self.episodes else False
        
        plt.plot(rewards)
        plt.show()


    def episode(self, training=True):
        done = False
        state = self.world.reset().copy()
        episode_reward = 0
        self.episodes += 1
        while not done:
            # take on policy action
            encoded = self.encoder(state)
            net_out = self.actor(encoded)

            dist = Categorical(net_out)
            action = dist.sample()
            try:
                state_prime, reward, done, _ = self.world.step(action)
            except AssertionError:
                reward = -.1
                state_prime = state.copy()

            # fill buffer
            self.error_buffer.append((state, state_prime, reward, dist.log_prob(action)))

            # prepare for next iteration
            episode_reward += reward
            state = state_prime.copy()
            if training:
                self.__net_update()

        return episode_reward

    # Critic Loss
    def td_error(self, value, value_prime, reward):
        def loss_fn():
            print(value.shape, value_prime.shape) 
            return reward + self.gamma * value_prime - value
        return loss_fn()

    # Actor error
    def policy_error(self, prob, error):
        # def loss_fn():
        print(prob.shape)
        print(error.shape)
        log_probabs_tensor = prob.reshape(-1)
        print(log_probabs_tensor.shape)
        return (-log_probabs_tensor * error).mean()
        # return loss_fn()

    def __net_update(self):
        # calculate error for policy and critic
        self.critic.v_network.zero_grad()
        self.actor.network.zero_grad()

        state, state_prime, reward, prob = self.error_buffer.pop(0)
        print(f'Encoder -> {self.encoder(state)}')
        value = self.critic(self.encoder(state))
        value_prime = self.critic(self.encoder(state_prime))

        td_error = self.td_error(value, value_prime, reward)
        print(f'TD Error: {td_error}')
        print(f'Prob: {prob}')


        policy_error = self.policy_error(prob, td_error) 
        print(f'P Error: {policy_error.reshape(-1)}')

        # backprop
        if self.episodes > 100:
            td_error.backward()
        else:
            policy_error.backward()
            self.policy_optimizer.step()
        self.v_optimizer.step()

# Solving the Environment

In [79]:
critic = CriticNet(30)
actor = ActorNetwork(30, 32, 26)
agent = Advantage_ActorCritic(wordle, actor, critic, 
lambda x: torch.tensor([x['board']]).to(torch.float32).reshape(1, 1,30), 1e-3, 1e-3, 0.9, 1)

In [80]:
agent.episode()

Encoder -> tensor([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1.]]])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
TD Error: tensor([[[-0.1257]]], grad_fn=<SubBackward0>)
Prob: tensor([[-3.3226],
        [-3.1932],
        [-3.2705],
        [-3.2310],
        [-3.2165]], grad_fn=<SqueezeBackward1>)
torch.Size([5, 1])
torch.Size([1, 1, 1])
torch.Size([5])
P Error: tensor([-0.4082], grad_fn=<ReshapeAliasBackward0>)
Encoder -> tensor([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1.]]])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
TD Error: tensor([[[-0.1333]]], grad_fn=<SubBackward0>)
Prob: tensor([[-3.1066],
        [-3.2001],
        [-3.2017],
        [-3.2551],
        [-3.2366]], grad_fn=<SqueezeBackward1>)
torch.Size([5, 1])
torch.Size([1, 1

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

