# Network Creation

In [None]:
!pip install torch ./Gym-Wordle

In [None]:
import torch
import matplotlib.pyplot as plt
from torch import normal

In [None]:
from torch import nn

class ActorNetwork(nn.Module):
    """Feed-forward (MLP) policy network for the Wordle solver.

    Two hidden layers of 32 units with ReLU, followed by a softmax over the
    last dimension, so every output vector is positive and sums to 1.

    Args:
        input_size: length of the encoded state vector.
        output_size: number of outputs (default 2).
    """

    def __init__(self, input_size, output_size=2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, output_size),
            # dim=-1 normalises each sample's outputs: identical to dim=0 for a
            # single 1-D input, and correct (per-row) for batched input.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return the softmax outputs for ``x`` (shape ``(..., output_size)``)."""
        return self.network(x)

class CriticNet(nn.Module):
    """State-value critic: maps an encoded state to one scalar value estimate.

    Architecture: Linear(in_shape, 64) -> ReLU -> Linear(64, 64) -> ReLU
    -> Linear(64, 1), stored as ``self.v_network``.
    """

    def __init__(self, in_shape):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(in_shape, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.v_network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate for ``x`` (trailing dimension of size 1)."""
        value = self.v_network(x)
        return value

# Exploring the environment

In [None]:
import gym
import gym_wordle
# Create the Wordle environment. The 'Wordle-v0' id is presumably registered
# with gym as a side effect of `import gym_wordle` above — confirm.
wordle = gym.make('Wordle-v0')

In [None]:
# Smoke-test the environment: sample one action from N(mu, sigma), take it,
# and render the resulting board.
obs = wordle.reset()
mu, sigma = 1.0, 2.0
action = normal(mu, sigma, size=(1,), generator=None, out=None)
# Discrete actions are 0 .. n-1, so clip to n - 1; clipping to n could round
# to an out-of-range action.
action = action.clip(0, wordle.action_space.n - 1)
action_index = int(action.round().item())
print(action.round().item())
step_result = wordle.step(action_index)  # kept for inspection in later cells
wordle.render()
print(obs)  # the observation returned by reset(), i.e. before the step

# Competing A2C solvers

In [None]:
def _global_default(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it to train() explicitly"
        ) from None


def _discounted_returns(rewards, dones, bootstrap, gamma):
    """Return the discounted return for every step of a rollout.

    Works backwards from ``bootstrap`` (the value estimate after the last
    step): ``G_t = r_t + gamma * G_{t+1} * (1 - done_t)``, so a terminal step
    cuts off everything after it.
    """
    returns = [0.0] * len(rewards)
    running = float(bootstrap)
    for t in reversed(range(len(rewards))):
        running = float(rewards[t]) + gamma * running * (1.0 - float(dones[t]))
        returns[t] = running
    return returns


def train(replay, q_val, discount_factor=None, adam_critic=None, adam_actor=None):
    """Take one A2C update (critic step, then actor step) from a finished rollout.

    Args:
        replay: object with parallel per-step lists ``vals`` (critic outputs,
            carrying grad), ``rewards``, ``dones`` and ``log_probs`` (actor
            log-probabilities, carrying grad).
        q_val: value estimate for the state after the last step (bootstrap).
            It is converted to a Python float, so it never receives gradient.
        discount_factor, adam_critic, adam_actor: the discount and the two
            optimisers. When omitted they are looked up as module-level
            globals of the same names, as the original version did.

    Raises:
        NameError: a setting was neither passed nor defined globally.
    """
    if not replay.rewards:
        return  # empty rollout: nothing to learn from

    gamma = _global_default("discount_factor", discount_factor)
    critic_optimizer = _global_default("adam_critic", adam_critic)
    actor_optimizer = _global_default("adam_actor", adam_actor)

    vals = torch.vstack(replay.vals)
    # Targets use the same (T, 1) layout as ``vals``; a flat (T,) tensor would
    # broadcast against (T, 1) into a (T, T) matrix.
    returns = torch.tensor(
        _discounted_returns(replay.rewards, replay.dones, q_val, gamma),
        dtype=vals.dtype,
    ).view_as(vals)
    advantage = returns - vals

    critic_loss = (advantage ** 2).mean()
    critic_optimizer.zero_grad()
    # retain_graph kept from the original in case vals and log_probs share nodes.
    critic_loss.backward(retain_graph=True)
    critic_optimizer.step()

    log_probs = torch.vstack(replay.log_probs)
    # Detached advantage: the actor loss must not push gradients into the critic.
    actor_loss = (-log_probs * advantage.detach()).mean()
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

In [None]:
class Advantage_ActorCritic():
    """Online one-step advantage actor-critic (A2C) agent.

    The actor maps an encoded state to two numbers, used as the mean and the
    standard deviation of a Normal distribution. A sample from it is clipped
    to ``[0, world.action_space.n - 1]`` and rounded to get a discrete action.
    The critic estimates state values for the TD(0) error
    ``r + gamma * V(s') * (1 - done) - V(s)``. Its square is the critic loss,
    and its detached value is the advantage in the policy gradient.

    NOTE(review): if ``policy_net`` ends in a softmax (as ActorNetwork does),
    the mean is confined to (0, 1), so almost every sampled action rounds to
    0 or 1 — confirm this parametrisation is intended.
    """

    def __init__(self, world: gym.Env, policy_net: ActorNetwork, critic_net: CriticNet,
                 encoder, policy_alpha, critic_alpha, gamma, max_reward):
        """
        Args:
            world: environment exposing ``reset()``, ``action_space.n`` and
                ``step(action)`` returning ``(state, reward, done, info)``.
            policy_net: actor; its ``.network`` parameters are optimised with Adam.
            critic_net: critic; its ``.v_network`` parameters are optimised with Adam.
            encoder: callable turning a raw environment state into a network input.
            policy_alpha: Adam learning rate for the actor.
            critic_alpha: Adam learning rate for the critic.
            gamma: discount factor.
            max_reward: mean reward over the last 10 episodes at which
                ``train`` stops early.
        """
        # environment info
        self.world = world
        self.encoder = encoder
        self.max_reward = max_reward

        # actor and critic
        self.actor = policy_net
        self.critic = critic_net
        # Transitions awaiting an update: (state, state_prime, reward, log_prob, done)
        self.error_buffer = list()
        self.policy_optimizer = torch.optim.Adam(policy_net.network.parameters(), lr=policy_alpha)
        self.v_optimizer = torch.optim.Adam(critic_net.v_network.parameters(), lr=critic_alpha)

        # training info
        self.gamma = gamma
        self.episodes = 0  # total episodes run by this agent

    def train(self, iterations):
        """Run training episodes, then plot the per-episode rewards.

        Stops when the agent's total episode count reaches ``iterations`` or
        the mean reward of the last 10 episodes reaches ``self.max_reward``,
        whichever comes first.
        """
        from collections import deque

        rewards = list()
        recent = deque(maxlen=10)  # sliding window of the last 10 episode rewards
        while True:
            r = self.episode()
            rewards.append(r)
            recent.append(r)

            # convergence check (only once the window is full)
            if len(recent) == recent.maxlen and sum(recent) / len(recent) >= self.max_reward:
                break
            # `>=` rather than `==` so the loop also ends if episodes ran before this call.
            if self.episodes >= iterations:
                break

        plt.plot(rewards)
        plt.show()

    def episode(self, training=True):
        """Play one episode, updating both networks after every step if ``training``.

        Returns:
            The sum of rewards collected during the episode.
        """
        from torch.distributions import Normal

        done = False
        state = self.world.reset().copy()
        episode_reward = 0
        self.episodes += 1
        # Discrete actions are 0 .. n-1; clipping to n could yield an invalid action.
        max_action = self.world.action_space.n - 1
        while not done:
            # take on-policy action
            dist = self.actor(self.encoder(state))
            # Floor the std so an output that underflows to 0 cannot make the
            # Normal invalid.
            policy = Normal(dist[0], dist[1].clamp_min(1e-6))
            sample = policy.sample()
            # The policy gradient needs the log-probability of the sampled
            # value, not the raw network outputs.
            log_prob = policy.log_prob(sample)
            action = int(sample.clamp(0, max_action).round().item())
            try:
                state_prime, reward, done, _ = self.world.step(action)
            except AssertionError:
                # Environment rejected the step (presumably an invalid action):
                # small penalty, state unchanged.
                reward = -.1
                state_prime = state.copy()

            episode_reward += reward
            if training:
                # Buffer only while training; otherwise transitions (and their
                # autograd graphs) would accumulate without ever being consumed.
                self.error_buffer.append((state, state_prime, reward, log_prob, done))
                self.__net_update()
            state = state_prime.copy()

        return episode_reward

    # Critic Loss
    def td_error(self, value, value_prime, reward, done=False):
        """Return the TD(0) error ``reward + gamma * value_prime * (1 - done) - value``.

        ``done`` (default False, matching the previous behaviour) drops the
        bootstrap term for terminal transitions.
        """
        return reward + self.gamma * value_prime * (1 - float(done)) - value

    # Actor error
    def policy_error(self, prob, error):
        """Return the policy-gradient loss ``mean(-log_prob * advantage)``.

        Args:
            prob: log-probability tensor of the taken action(s); flattened first.
            error: advantage estimate. Detach it before calling, so this loss
                does not backpropagate into the critic.
        """
        log_probs = prob.reshape(-1)
        return (-log_probs * error).mean()

    def __net_update(self):
        """Pop the oldest buffered transition and take one critic and one actor step."""
        state, state_prime, reward, log_prob, done = self.error_buffer.pop(0)

        value = self.critic(self.encoder(state))
        # Semi-gradient TD: the bootstrap target is treated as a constant.
        with torch.no_grad():
            value_prime = self.critic(self.encoder(state_prime))
        td_error = self.td_error(value, value_prime, reward, done)

        # Critic: minimise the squared TD error (the raw error is unbounded below).
        critic_loss = td_error.pow(2).mean()
        self.v_optimizer.zero_grad()
        critic_loss.backward()
        self.v_optimizer.step()

        # Actor: the advantage is detached so this step only updates the actor.
        policy_error = self.policy_error(log_prob, td_error.detach())
        self.policy_optimizer.zero_grad()
        policy_error.backward()
        self.policy_optimizer.step()

# Solving the Environment

In [None]:
# Build the networks and the agent. Both networks take a 60-long input, which
# is presumably the size of the flattened Wordle-v0 observation — confirm.
# The critic is created before the actor, as before, so random init is unchanged.
critic = CriticNet(60)
actor = ActorNetwork(60, 2)
agent = Advantage_ActorCritic(
    wordle,
    actor,
    critic,
    encoder=lambda x: torch.tensor(x.flatten()).to(torch.float32),
    policy_alpha=1e-3,
    critic_alpha=1e-3,
    gamma=0.9,
    max_reward=1,
)

In [None]:
# Run a single episode with the default `training=True`, so the networks are
# updated after every step; the cell displays the episode's total reward.
agent.episode()