# Flappy sim environment

In [None]:
from coderbot_sim.flappy.widget import FlappySim

env = FlappySim()
env.render()

# This is an example of a hard coded solution
actions = [a for a in [0, 0, 1, 0, 0, 1, 0, 0, 0] for _ in range(10)]

for i in range(4):
    for t in range(len(actions)):
        action = actions[t]
        state = await env.step(action)
        # print(f"t={t:03d}", f"state={state}")

In [None]:
from coderbot_sim.flappy.tk import FlappyTkFrontend

env = FlappyTkFrontend()
env.render()

# This is an example of a hard coded solution
actions = [a for a in [0, 0, 1, 0, 0, 1, 0, 0, 0] for _ in range(10)]

for i in range(4):
    for t in range(len(actions)):
        action = actions[t]
        state = await env.step(action)
        # print(f"t={t:03d}", f"state={state}")

In [None]:
# Install PyTorch into the running kernel's environment (needed by the A2C cell that follows).
%pip install torch

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from coderbot_sim.flappy import FlappyEnv

class Actor(nn.Module):
    """Policy network: maps an observation vector to action probabilities.

    One hidden ReLU layer followed by a linear head and a softmax over the
    last dimension. The defaults (4 inputs, 64 hidden units, 2 actions)
    reproduce the original fixed architecture; pass other sizes to reuse the
    network with a different observation or action space.
    """

    def __init__(self, obs_dim=4, hidden=64, n_actions=2):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden)
        self.fc2 = nn.Linear(hidden, n_actions)

    def forward(self, x):
        """Return action probabilities; the last dimension of ``x`` holds the features."""
        x = torch.relu(self.fc1(x))
        logits = self.fc2(x)
        return torch.softmax(logits, dim=-1)

class Critic(nn.Module):
    """State-value network: maps an observation vector to a scalar estimate.

    One hidden ReLU layer followed by a single-output linear head, so a 1-D
    observation yields a tensor of shape (1,). The defaults (4 inputs, 64
    hidden units) reproduce the original fixed architecture.
    """

    def __init__(self, obs_dim=4, hidden=64):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return the value estimate(s); the last dimension of ``x`` holds the features."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


def make_obs(state):
    """Flatten a state dict into a 4-element float32 observation vector.

    Layout: ``[bird_y, bird_vel, pipe_x, pipe_bottom]``. The pipe features
    come from the first entry of ``state["pipes"]``: ``pipe_x`` is the ``"x"``
    of its element 0 and ``pipe_bottom`` is the ``"y"`` of its element 1.
    Both are 0 when there are no pipes.
    """
    pipes = state["pipes"]
    pipe_x = pipe_bottom = 0
    if len(pipes) != 0:
        first_pipe = pipes[0]
        pipe_x = first_pipe[0]["x"]
        pipe_bottom = first_pipe[1]["y"]

    features = (state["bird_y"], state["bird_vel"], pipe_x, pipe_bottom)
    return np.array(features, dtype=np.float32)


env = FlappyEnv()

# Policy (actor) and value baseline (critic) are separate networks, each
# updated by its own Adam optimizer with the same learning rate.
actor, critic = Actor(), Critic()
opt_actor, opt_critic = (
    optim.Adam(actor.parameters(), lr=1e-4),
    optim.Adam(critic.parameters(), lr=1e-4),
)

# Discount factor applied when computing returns in train_a2c.
gamma = 0.99

def _discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_{T-1}] with G_t = r_t + gamma * G_{t+1} and G_T = 0."""
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def train_a2c(env, episodes=2000):
    """Train the module-level ``actor``/``critic`` with episodic advantage actor-critic.

    Each episode is rolled out until ``env.step`` reports ``done``. Monte Carlo
    discounted returns (using the module-level ``gamma``) serve as critic
    targets. The advantage ``return - value`` (detached) weights the policy
    gradient. The networks are updated with ``opt_actor`` / ``opt_critic``.

    Shaped reward per step:
      * +0.1 for an ordinary step,
      * -5 on the step that ends the episode,
      * +10 on a step where the score increases (this overrides the -5 if both
        happen on the same step).

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning a
            dict with at least ``"done"`` and ``"score"``, plus the keys read
            by ``make_obs``) and a ``score`` attribute.
        episodes: Number of episodes to run.

    Returns:
        A list with the total shaped reward of each episode.
    """
    episode_rewards = []
    for ep in range(episodes):
        state = make_obs(env.reset())
        # Track the previous score locally. Reading env.score after env.step()
        # has run would (assuming step updates it) already give the new
        # score, so the comparison could never detect a point.
        # Presumably 0 right after reset -- TODO confirm.
        prev_score = env.score
        log_probs, values, rewards = [], [], []

        while True:
            s = torch.tensor(state, dtype=torch.float32)
            dist = Categorical(actor(s))
            action = dist.sample()
            log_probs.append(dist.log_prob(action))
            values.append(critic(s))  # shape (1,) for a 1-D observation

            next_state_dict = env.step(action.item())
            done = next_state_dict["done"]
            score = next_state_dict["score"]

            reward = 0.1
            if done:
                reward = -5
            if score > prev_score:
                reward = 10
            prev_score = score
            rewards.append(reward)

            state = make_obs(next_state_dict)
            if done:
                break

        returns = torch.tensor(_discounted_returns(rewards, gamma), dtype=torch.float32)
        value_estimates = torch.cat(values)  # (T,): one estimate per step
        advantage = returns - value_estimates

        # Detach the advantage for the actor so the policy loss does not
        # backpropagate into the critic; the critic regresses onto the returns.
        actor_loss = -(torch.stack(log_probs) * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()

        opt_actor.zero_grad()
        actor_loss.backward()
        opt_actor.step()

        opt_critic.zero_grad()
        critic_loss.backward()
        opt_critic.step()

        total = sum(rewards)
        episode_rewards.append(total)
        print(f"Episode {ep}, Reward={total:.2f}")

    return episode_rewards


# train_a2c is a regular (synchronous) function: call it directly. Awaiting
# its result raises TypeError once training finishes, because the value is
# not awaitable. Bind the result so the notebook does not echo it as cell output.
history = train_a2c(env)