In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["DATA_PATH"] = "../assets/"

In [2]:
import random
from collections import deque

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from tqdm.auto import tqdm

from src.game.wrapped_flappy_bird import GameState
from src.models.DoubleDQN import QualityEstimator, policy, ConvQualityEstimator
from src.data.replay_memory import ReplayMemory
from src.pipelines.utils import get_state, get_image_state
from src.pipelines.train import optimize_model_double_dqn
from src.data.shemas import ConfigData

pygame 2.4.0 (SDL 2.26.4, Python 3.10.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Seed every RNG this run may draw from (network init, epsilon-greedy exploration,
# and Python's `random` in case the game module uses it) so sessions are reproducible.
# GPU kernels can still be nondeterministic — see the PyTorch reproducibility notes.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# TensorBoard writer used by the training loop to log per-session scalars.
writer = SummaryWriter()

In [4]:
# Hyper-parameters for the Double-DQN run, validated by ConfigData.
config_dict = dict(
    sessions_num=10000,        # number of game sessions in the training loop
    gamma=0.95,                # presumably the discount factor — consumed by optimize_model_double_dqn
    lr=1e-6,                   # initial AdamW learning rate
    lr_decay=0.997,            # LambdaLR multiplier, applied as lr_decay ** step
    state_dim=5,               # feature size for the MLP estimator (used when hid_channel is None)
    action_dim=2,              # number of discrete actions
    hid_channel=[32, 32, 64],  # conv channels; a non-None value selects ConvQualityEstimator
    hid_dim=[64, 512],         # hidden layer sizes of the estimator head
    eps_init=1,                # starting epsilon of the linear exploration schedule
    eps_last=1e-4,             # floor of the epsilon schedule
    # NOTE(review): train_step reads `cfg.eps_max_iter` (singular) while this key is
    # plural — confirm which name ConfigData declares.
    eps_max_iters=1000,
    temperature=1,             # not read in this notebook — TODO confirm where it is used
    batch_size=32,
    grad_clip=100,
    memory_size=10000,         # ReplayMemory capacity
    model_load_path=None,      # e.g. "../models_backup/" to resume from saved models
    model_version="_v10",      # suffix appended to checkpoint names when loading
    model_save_path="../models/",
    model_swap_time=10,        # steps before the online/target roles of the two nets swap
    max_session_score=150,     # a session stops once the game score reaches this
    player_flap_acc=-9,        # flap acceleration assigned to the game
    dropout=0.2,
    img_num_in_state=4,        # number of frames stacked into one observation
    device=DEVICE,
)

cfg = ConfigData(**config_dict)

In [5]:
action_terminal = GameState()
curr_flap_acc = cfg.player_flap_acc
action_terminal.playerFlapAcc = curr_flap_acc


def _build_q_network():
    """Return a freshly initialised Q-network on DEVICE.

    A non-None ``cfg.hid_channel`` selects the convolutional estimator over
    ``cfg.img_num_in_state`` stacked frames; otherwise the MLP estimator over
    ``cfg.state_dim`` features is built.
    """
    if cfg.hid_channel is not None:
        net = ConvQualityEstimator(
            in_channels=cfg.img_num_in_state,
            action_dim=cfg.action_dim,
            hid_channel=cfg.hid_channel,
            hid_dims=cfg.hid_dim,
            dropout=cfg.dropout,
        )
    else:
        net = QualityEstimator(
            cfg.state_dim,
            cfg.action_dim,
            cfg.hid_dim,
            cfg.dropout,
        )
    return net.to(DEVICE)


def _load_q_network(name):
    """Load the checkpoint ``model_load_path + name + model_version`` onto DEVICE.

    The training loop saves whole modules with ``torch.save(model, path)``, so the
    file is a pickled ``nn.Module`` rather than a ``state_dict``. That requires
    ``weights_only=False``, which newer torch releases no longer default to.
    ``map_location`` lets a checkpoint written on a GPU machine be restored on a
    CPU-only one. Unpickling can execute arbitrary code: only load trusted files.
    """
    path = cfg.model_load_path + name + cfg.model_version
    return torch.load(path, map_location=DEVICE, weights_only=False).to(DEVICE)


if cfg.model_load_path is not None:
    model_qa = _load_q_network("model_qa")
    model_qb = _load_q_network("model_qb")
else:
    model_qa = _build_q_network()
    model_qb = _build_q_network()

optimizer_qa = optim.AdamW(model_qa.parameters(), lr=cfg.lr)
optimizer_qb = optim.AdamW(model_qb.parameters(), lr=cfg.lr)


def _lr_factor(step):
    # train_step calls scheduler.step() once per training step of the online model.
    # The decay is therefore per step, not per session: lr_t = cfg.lr * cfg.lr_decay ** t.
    # NOTE(review): with lr_decay=0.997 the LR drops below 1% of cfg.lr after ~1500
    # steps — confirm this aggressive schedule is intended.
    return cfg.lr_decay ** step


scheduler_qa = torch.optim.lr_scheduler.LambdaLR(optimizer_qa, lr_lambda=_lr_factor)
scheduler_qb = torch.optim.lr_scheduler.LambdaLR(optimizer_qb, lr_lambda=_lr_factor)

In [6]:
def train_step(
    model,
    target_model,
    state,
    memory,
    optimizer,
    sheduler,
    eps_update_iter,
    cfg: ConfigData
):
    """Play one game step with ``model`` and run one Double-DQN optimisation step.

    Args:
        model: online Q-network; it picks the action and is the one optimised.
        target_model: the other network, passed to ``optimize_model_double_dqn``
            as the target.
        state: current observation. For the image pipeline (``cfg.hid_channel``
            is not None) this is the frame ``deque``; it is updated IN PLACE with the
            new frame. Otherwise it is the value returned by ``get_state``.
        memory: replay memory that receives the transition.
        optimizer: optimizer of ``model``.
        sheduler: LR scheduler of ``optimizer``; stepped once per call.
        eps_update_iter: progress counter for the linear epsilon schedule.
        cfg: run configuration.

    Returns:
        ``(reward, has_failed, loss_value)``:
        - ``reward`` is a ``(1, 1)`` tensor on ``cfg.device``.
        - ``has_failed`` is the terminal flag from ``frame_step``.
        - ``loss_value`` is whatever ``optimize_model_double_dqn`` returns (may be ``None``).

    Note:
        In the vector pipeline this function cannot rebind the caller's ``state``.
        The caller must refresh it (e.g. with ``get_state``) after every call;
        otherwise every step acts on a stale observation.
    """
    tensor_state = torch.cat(list(state)).unsqueeze(0).to(cfg.device)

    # Linear decay from eps_init towards 0 over eps_max_iter updates, floored at eps_last.
    # NOTE(review): the config dict defines "eps_max_iters" (plural) — confirm which
    # name ConfigData actually exposes.
    epsilon = max(
        cfg.eps_init - cfg.eps_init / cfg.eps_max_iter * eps_update_iter,
        cfg.eps_last,
    )
    with torch.no_grad():
        action, _ = policy(model, tensor_state, epsilon=epsilon)

    _, reward, has_failed = action_terminal.frame_step(action)
    reward = torch.tensor(reward, device=cfg.device).view(1, 1)

    if cfg.hid_channel is not None:
        # The caller's deque has maxlen=img_num_in_state, so appending drops the oldest frame.
        state.append(get_image_state(action_terminal))
        next_tensor_state = torch.cat(list(state)).unsqueeze(0)
        # Keep both ends of the transition on the CPU so replay batches never mix devices.
        memory.push(tensor_state.to("cpu"), action, next_tensor_state.to("cpu"), reward)
    else:
        next_state = get_state(action_terminal)
        memory.push(state, action, next_state, reward)

    loss_value = optimize_model_double_dqn(model, target_model, memory, optimizer, cfg)
    sheduler.step()

    return reward, has_failed, loss_value

In [7]:
total_rewards = []

eps_update_iter_a = 0
eps_update_iter_b = 0
memory = ReplayMemory(cfg.memory_size)

iteration = 0

SAVE_EVERY = 200          # sessions between model checkpoints
CURRICULUM_EVERY = 2000   # sessions between flap-acceleration curriculum steps
FINAL_FLAP_ACC = -9       # curriculum stops once the flap acceleration reaches this

try:
    for session_idx in tqdm(range(cfg.sessions_num)):

        # Initialize new session
        has_failed = False
        max_score = 0
        total_reward = 0.0
        mean_loss = 0.0
        session_len = 0
        action_terminal.playerFlapAcc = curr_flap_acc

        action_terminal.frame_step(0)

        if cfg.hid_channel is not None:
            # Fill the frame stack by reading the current frame img_num_in_state times
            # (no game steps in between).
            state = deque([], maxlen=cfg.img_num_in_state)
            for _ in range(cfg.img_num_in_state):
                state.append(get_image_state(action_terminal))
        else:
            state = get_state(action_terminal)

        # Session run
        while not has_failed:
            # The two networks swap roles every `model_swap_time` steps: the online one
            # acts and is optimised, the other serves as its target.
            if iteration % (2 * cfg.model_swap_time) < cfg.model_swap_time:
                online_model, target_net, step_optimizer, step_scheduler, step_eps_iter = (
                    model_qa, model_qb, optimizer_qa, scheduler_qa, eps_update_iter_a
                )
            else:
                online_model, target_net, step_optimizer, step_scheduler, step_eps_iter = (
                    model_qb, model_qa, optimizer_qb, scheduler_qb, eps_update_iter_b
                )

            reward, has_failed, loss_value = train_step(
                online_model,
                target_net,
                state,
                memory,
                step_optimizer,
                step_scheduler,
                step_eps_iter,
                cfg,
            )
            iteration += 1

            if cfg.hid_channel is None:
                # train_step cannot rebind this name (only the image deque is updated in
                # place). Advance the vector state here; otherwise every step of the
                # session acts on, and stores, the session's initial state.
                # Assumes get_state is a side-effect-free read of the game — TODO confirm.
                state = get_state(action_terminal)

            if loss_value is not None:
                mean_loss += loss_value
                session_len += 1

            # reward is a (1, 1) tensor; .item() keeps the running total a plain float,
            # so total_rewards can be plotted directly.
            total_reward += reward.item()
            max_score = max(max_score, action_terminal.score)
            if max_score >= cfg.max_session_score:
                # NOTE(review): this exits without a crash and nothing here resets the
                # game. The next session likely resumes with score >= max_session_score
                # and stops after a single step. Confirm how GameState should be reset here.
                break

        # Post session updates
        if (session_idx + 1) % SAVE_EVERY == 0:
            torch.save(model_qa, cfg.model_save_path + "model_qa")
            torch.save(model_qb, cfg.model_save_path + "model_qb")

        # Curriculum: step the flap acceleration down to FINAL_FLAP_ACC and restart the
        # epsilon schedule. Since curr_flap_acc starts at cfg.player_flap_acc (-9 in the
        # current config), this does not fire with the current settings.
        if (session_idx + 1) % CURRICULUM_EVERY == 0 and curr_flap_acc > FINAL_FLAP_ACC:
            curr_flap_acc -= 1
            action_terminal.playerFlapAcc = curr_flap_acc
            eps_update_iter_a = 0
            eps_update_iter_b = 0

        # Epsilon progress is counted in sessions, not steps.
        eps_update_iter_a += 1
        eps_update_iter_b += 1

        total_rewards.append(total_reward)
        if session_len > 0:
            mean_loss /= session_len

        writer.add_scalar(f"Total reward {cfg=}", total_reward, session_idx)
        writer.add_scalar(f"Max score {cfg=}", max_score, session_idx)
        writer.add_scalar(f"Mean session loss {cfg=}", mean_loss, session_idx)
finally:
    # Runs are typically stopped with KeyboardInterrupt; make sure logged scalars reach disk.
    writer.flush()
    

  0%|          | 0/10000 [00:00<?, ?it/s]



KeyboardInterrupt: 