# Flappy sim environment

In [None]:
from coderbot_sim.flappy.widget import FlappySim

env = FlappySim()
env.render()

# This is an example of a hard coded solution
actions = [a for a in [0, 0, 1, 0, 0, 1, 0, 0, 0] * 2 for _ in range(10)]

for t in range(len(actions)):
    action = actions[t]
    state = await env.step(action)
    # print(f"t={t:03d}", f"state={state}")

In [None]:
from coderbot_sim.flappy.tk import FlappyTkFrontend

env = FlappyTkFrontend()
env.render()

# This is an example of a hard coded solution
actions = [a for a in [0, 0, 1, 0, 0, 1, 0, 0, 0] * 2 for _ in range(10)]

for t in range(len(actions)):
    action = actions[t]
    state = await env.step(action)
    # print(f"t={t:03d}", f"state={state}")

In [None]:
%pip install torch

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from coderbot_sim.flappy import FlappyEnv

class Actor(nn.Module):
    """Policy network mapping an observation to action probabilities.

    A single hidden ReLU layer followed by a softmax over the action logits.

    Args:
        obs_dim: Number of features in the observation vector.
        hidden_dim: Width of the hidden layer.
        n_actions: Number of discrete actions to produce probabilities for.
    """

    def __init__(self, obs_dim: int = 4, hidden_dim: int = 64, n_actions: int = 2):
        super().__init__()
        # Layer sizes are parameters (defaults match the original 4 -> 64 -> 2
        # architecture) so the network can be reused with other observation
        # or action spaces.
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities; the last dimension sums to 1."""
        x = torch.relu(self.fc1(x))
        # Softmax over the last dimension so both single observations and
        # batches of observations are handled.
        return torch.softmax(self.fc2(x), dim=-1)


class Critic(nn.Module):
    """Value network mapping an observation to a single scalar estimate.

    A single hidden ReLU layer followed by a linear output of size 1.

    Args:
        obs_dim: Number of features in the observation vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, obs_dim: int = 4, hidden_dim: int = 64):
        super().__init__()
        # Layer sizes are parameters (defaults match the original 4 -> 64 -> 1
        # architecture) so the network can be reused with other observation
        # spaces.
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the value estimate; the last dimension has size 1."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Divisors used to scale horizontal and vertical positions in make_obs.
# NOTE(review): presumably the playfield size in pixels; confirm against the env.
MAX_X = 800.0
MAX_Y = 600.0


def make_obs(state):
    """Convert a raw state dict into a 4-element float32 feature vector.

    The features are, in order: bird height, bird velocity, x position of the
    first pipe pair, and y position of that pair's lower pipe. Positions are
    divided by MAX_X / MAX_Y and the velocity by 400.0.

    When ``state["pipes"]`` is empty, the pipe is treated as sitting at
    ``MAX_X`` with its lower edge at half of ``MAX_Y``.
    """
    pipes = state["pipes"]
    if len(pipes) > 0:
        # Only the first pair in the list is considered.
        upper, lower = pipes[0]
        next_x, gap_floor = upper["x"], lower["y"]
    else:
        # No pipes present: fall back to placeholder values.
        next_x, gap_floor = MAX_X, MAX_Y / 2

    features = [
        state["bird_y"] / MAX_Y,
        state["bird_vel"] / 400.0,
        next_x / MAX_X,
        gap_floor / MAX_Y,
    ]
    return np.array(features, dtype=np.float32)

# Environment instance that is handed to train_a2c at the bottom of the cell.
env = FlappyEnv()

# Policy (actor) and value (critic) networks; train_a2c reads these globals.
actor = Actor()
critic = Critic()

# One Adam optimizer per network, both using the same learning rate.
opt_actor = optim.Adam(actor.parameters(), lr=1e-4)
opt_critic = optim.Adam(critic.parameters(), lr=1e-4)

# Discount factor applied when accumulating per-step rewards into returns.
gamma = 0.99

def _discounted_returns(rewards, discount):
    """Return the discounted reward-to-go for every step of one episode."""
    running = 0
    returns = []
    # Walk backwards so each step's return reuses the one that follows it.
    for reward in reversed(rewards):
        running = reward + discount * running
        returns.append(running)
    returns.reverse()
    return returns


def train_a2c(env, episodes=5000):
    """Train the module-level ``actor`` and ``critic`` on ``env``.

    Each episode is rolled out to completion, then both networks receive one
    gradient step: the actor on the advantage-weighted log-probabilities and
    the critic on the squared error against the full-episode discounted
    returns (no bootstrapping from the critic).

    Args:
        env: Environment whose ``reset()`` and ``step(action)`` return a state
            dict accepted by ``make_obs``; the dict returned by ``step`` must
            also carry a ``"done"`` flag.
        episodes: Number of episodes to run.

    Relies on the module-level ``actor``, ``critic``, ``opt_actor``,
    ``opt_critic``, ``gamma`` and ``make_obs``. Prints one progress line per
    episode and returns nothing.
    """
    for ep in range(episodes):
        state = make_obs(env.reset())

        log_probs = []
        values = []
        rewards = []

        done = False
        while not done:
            s = torch.tensor(state, dtype=torch.float32)

            # Sample an action from the current policy.
            dist = Categorical(actor(s))
            action = dist.sample()

            value = critic(s)
            log_prob = dist.log_prob(action)

            next_state_dict = env.step(action.item())

            log_probs.append(log_prob)
            values.append(value)
            # Reward = time alive: every step taken earns 1.0.
            rewards.append(1.0)

            state = make_obs(next_state_dict)
            done = next_state_dict["done"]

        returns = torch.tensor(
            _discounted_returns(rewards, gamma), dtype=torch.float32
        )
        # Drop only the trailing size-1 dimension so a one-step episode still
        # yields a 1-D tensor aligned with ``returns`` (a bare squeeze() would
        # collapse it to a 0-D scalar).
        values = torch.stack(values).squeeze(-1)

        # Detach so the actor loss does not backpropagate into the critic.
        advantage = returns - values.detach()

        actor_loss = -(torch.stack(log_probs) * advantage).mean()
        critic_loss = (returns - values).pow(2).mean()

        opt_actor.zero_grad()
        actor_loss.backward()
        opt_actor.step()

        opt_critic.zero_grad()
        critic_loss.backward()
        opt_critic.step()

        print(f"Episode {ep} | Time Alive = {len(rewards)} steps")


# Start training on the environment created above with the default episode count.
train_a2c(env)
