# Network Creation

In [22]:
!pip install torch



In [23]:
import torch
import matplotlib.pyplot as plt
from torch.distributions.categorical import Categorical



In [24]:
from torch import nn

class ActorNetwork(nn.Module):
    """LSTM RNN for generating words for wordle solver"""
    def __init__(self, input_size, hidden_size, projection=1, num_layers=5, dropout_probability=0):
        super().__init__()
        self.network = nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout_probability, proj_size=projection, batch_first=True)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        output, (hidden, _) = self.network(x)
        probab = hidden + 1 / 2  # takes output of softmax [-1, 1] and squish to [0, 1]
        return probab 

class CriticNet(nn.Module):
    """Network representing the critic"""
    def __init__(self, in_shape):
        super().__init__()
        self.v_network = nn.Sequential(
            nn.Linear(in_shape, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
)

    def forward(self, x):
        return self.v_network(x)

# Toy Problem for RNN

In [25]:
net = ActorNetwork(1, 32, 26, 5)
input = torch.tensor([[
    [1],
    [2],
    [3],
    [4],
    [5]
]]).float()
print(input.shape)
net(input)

torch.Size([1, 5, 1])


tensor([[[0.3039, 0.6392, 0.2936, 0.6497, 0.6121, 0.7042, 0.4416, 0.5929,
          0.5138, 0.3772, 0.3447, 0.3463, 0.3341, 0.5110, 0.3609, 0.3104,
          0.5703, 0.5693, 0.6720, 0.6156, 0.4893, 0.3500, 0.1665, 0.5290,
          0.5409, 0.5089]],

        [[0.4988, 0.4683, 0.4982, 0.5202, 0.4895, 0.5097, 0.4798, 0.5433,
          0.4934, 0.3952, 0.4510, 0.4762, 0.4450, 0.4498, 0.4992, 0.5374,
          0.4759, 0.4853, 0.4968, 0.4878, 0.5454, 0.5031, 0.5336, 0.4991,
          0.4892, 0.5303]],

        [[0.4848, 0.4907, 0.4758, 0.4641, 0.4902, 0.5321, 0.5268, 0.4712,
          0.5088, 0.5583, 0.4120, 0.4575, 0.4140, 0.4422, 0.4863, 0.4762,
          0.4752, 0.4268, 0.4947, 0.5559, 0.5061, 0.4637, 0.4903, 0.4751,
          0.4729, 0.4853]],

        [[0.5059, 0.5526, 0.4515, 0.5014, 0.5445, 0.5274, 0.4267, 0.4368,
          0.4556, 0.4447, 0.5065, 0.5927, 0.5003, 0.5773, 0.5009, 0.4859,
          0.5056, 0.5319, 0.4666, 0.4745, 0.4980, 0.4886, 0.5245, 0.5242,
          0.5260, 0.4557]

# Figuring out the environment

In [32]:
import gym
import gym_wordle
wordle = gym.make("Wordle-v0")

FileNotFoundError: [Errno 2] No such file or directory: '/Users/shanky/miniconda3/envs/tf/lib/python3.8/site-packages/gym_wordle/dictionary/guess_list.npy'

In [28]:
obs1 = wordle.reset()
print(obs1)
obs = wordle.step([2, 17, 0, 19, 4])
wordle.render()
print(obs)
wordle.hidden_word

[[0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]


AssertionError: Invalid word!

In [None]:
wordle.step([17, 4, 1, 0, 1])

AssertionError: Invalid word!

In [None]:
import numpy as np
rnn = ActorNetwork(5, 32, 26, 5)
o = np.array([obs[0]['board']])
o = torch.from_numpy(o).to(torch.float32)
print(o.shape)
print(rnn(o))

torch.Size([1, 6, 5])
tensor([[[0.5784, 0.5548, 0.4521, 0.5546, 0.4620, 0.3900, 0.4581, 0.4895,
          0.3628, 0.6697, 0.5521, 0.5222, 0.5605, 0.5359, 0.5742, 0.4070,
          0.5257, 0.4200, 0.4461, 0.5238, 0.3695, 0.4998, 0.4732, 0.4871,
          0.5293, 0.4226]],

        [[0.5291, 0.4355, 0.4914, 0.5404, 0.5746, 0.4720, 0.4706, 0.4990,
          0.5314, 0.5156, 0.5034, 0.4374, 0.5129, 0.4606, 0.5518, 0.4645,
          0.4489, 0.5344, 0.5514, 0.4465, 0.4661, 0.4793, 0.4780, 0.5296,
          0.4005, 0.5002]],

        [[0.4930, 0.5260, 0.4909, 0.3779, 0.4686, 0.4809, 0.5326, 0.5496,
          0.5125, 0.5935, 0.4286, 0.5348, 0.4980, 0.4997, 0.4767, 0.5332,
          0.5729, 0.4449, 0.5696, 0.4686, 0.5495, 0.5180, 0.5247, 0.4759,
          0.4945, 0.4484]],

        [[0.5176, 0.4395, 0.4981, 0.4888, 0.5461, 0.5065, 0.5027, 0.4784,
          0.4654, 0.5807, 0.5061, 0.5163, 0.5095, 0.4966, 0.5762, 0.4662,
          0.4296, 0.5343, 0.4741, 0.5453, 0.5063, 0.4983, 0.4748, 0.4978,
   

In [None]:
prob = Categorical(rnn(o))
prob.sample()

tensor([[ 7],
        [ 8],
        [12],
        [ 4],
        [20]])

# Competing A2C solvers

In [None]:
def train(replay, q_val):
    vals = torch.vstack(replay.vals)
    
    rewards_tensor = torch.tensor(np.asarray(replay.rewards, dtype=np.float32))
    is_terminal_tensor = torch.tensor(np.asarray(replay.dones, dtype=np.int32))
        
    q_vals_tensor = rewards_tensor + discount_factor * q_val * (1 - is_terminal_tensor)
        
    advantage = q_vals_tensor - vals
    
    critic_loss = (advantage ** 2).mean()
    # critic_loss.requires_grad = True
    adam_critic.zero_grad()
    critic_loss.backward(retain_graph=True)
    adam_critic.step()
    
    
    log_probabs_tensor = torch.vstack(replay.log_probs)
    actor_loss = (-log_probabs_tensor*advantage.detach()).mean()
    # actor_loss.requires_grad = True
    adam_actor.zero_grad()
    actor_loss.backward()
    adam_actor.step()

In [None]:
class Advantage_ActorCritic():
    def __init__(self, world: gym.Env, policy_net: ActorNetwork, critic_net: CriticNet,
                 encoder, policy_alpha, critic_alpha, gamma, max_reward):
        # environment info
        self.world = world
        self.encoder = encoder
        self.max_reward = max_reward

        # actor and critic
        self.actor = policy_net
        self.critic = critic_net
        self.error_buffer = list()
        self.policy_optimizer = torch.optim.Adam(policy_net.network.parameters(), lr=policy_alpha)
        self.v_optimizer = torch.optim.Adam(critic_net.v_network.parameters(), lr=critic_alpha)

        # training info
        self.gamma = gamma
        self.episodes = 0

    def train(self, iterations):
        converged = False
        rewards = list()
        recents = torch.zeros(10)
        i = 0
        while not converged:
            r = self.episode()
            rewards.append(r)

            if i < 10:
                recents[i] = r
                i += 1

            else:
                recents.roll(-1, 0)
                recents[9] = r
            
            # convergence check
            if len(rewards) > 10 and recents.mean() >= self.max_reward:
                converged = True
            converged = True if iterations == self.episodes else False
        
        plt.plot(rewards)
        plt.show()


    def episode(self, training=True):
        done = False
        state = self.world.reset().copy()
        episode_reward = 0
        self.episodes += 1
        while not done:
            # take on policy action
            encoded = self.encoder(state)
            net_out = self.actor(encoded)

            dist = Categorical(net_out)
            action = dist.sample()
            try:
                state_prime, reward, done, _ = self.world.step(action)
            except AssertionError:
                reward = -.1
                state_prime = state.copy()

            # fill buffer
            self.error_buffer.append((state, state_prime, reward, dist.log_prob(action)))

            # prepare for next iteration
            episode_reward += reward
            state = state_prime.copy()
            if training:
                self.__net_update()

        return episode_reward

    def td_error(self, value, value_prime, reward):
        def loss_fn():
            print(value.shape, value_prime.shape) 
            return reward + self.gamma * value_prime - value
        return loss_fn()

    def policy_error(self, prob, error):
        def loss_fn():
            print(prob)
            print(error)
            return -torch.log(prob) * error
        return loss_fn()

    def __net_update(self):
        # calculate error for policy and critic
        self.critic.v_network.zero_grad()
        self.actor.network.zero_grad()

        state, state_prime, reward, prob = self.error_buffer.pop(0)
        print(self.encoder(state))
        value = self.critic(self.encoder(state))
        value_prime = self.critic(self.encoder(state_prime))

        td_error = self.td_error(value, value_prime, reward)
        policy_error = self.policy_error(prob, td_error) 

        # backprop
        if self.episodes > 100:
            td_error.backward()
        else:
            policy_error.backward()
            self.policy_optimizer.step()
        self.v_optimizer.step()

# Solving the Environment

In [None]:
critic = CriticNet(30)
actor = ActorNetwork(30, 32, 26)
agent = Advantage_ActorCritic(wordle, actor, critic, 
lambda x: torch.tensor([x['board']]).to(torch.float32).reshape(1, 1,30), 1e-3, 1e-3, 0.9, 1)

In [None]:
agent.episode()

tensor([[[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
          -1., -1.]]])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
tensor([[-3.1386],
        [-3.2099],
        [-3.1924],
        [-3.2840],
        [-3.2970]], grad_fn=<SqueezeBackward1>)
tensor([[[-0.0904]]], grad_fn=<SubBackward0>)


RuntimeError: grad can be implicitly created only for scalar outputs