In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# Exported in the first cell, i.e. before the project imports in the next cell
# run. Presumably the src.* modules read DATA_PATH to locate the game assets
# — TODO confirm against the project code.
os.environ["DATA_PATH"] = "../assets/"

In [2]:
import torch
import torch.optim as optim

from tqdm.auto import tqdm

from src.game.wrapped_flappy_bird import GameState
from src.models.DoubleDQN import QualityEstimator, policy
from src.data.replay_memory import ReplayMemory
from src.pipelines.utils import get_state
from src.pipelines.train import optimize_model

from torch.utils.tensorboard import SummaryWriter

pygame 2.4.0 (SDL 2.26.4, Python 3.10.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Pick the compute device once; everything below moves tensors/models to DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# TensorBoard writer using the default log directory chosen by SummaryWriter.
writer = SummaryWriter()

In [4]:
# Single source of truth for the run's hyper-parameters. Key order is kept
# stable because the whole dict is interpolated into the TensorBoard tags.
config = dict(
    sessions_num=10000,
    gamma=0.99,
    lr=1e-4,
    lr_decay=0.998,
    state_dim=4,
    action_dim=2,
    hid_dim=[64, 128, 128, 64],
    eps_init=0.05,
    eps_last=1e-4,
    eps_max_iters=250,
    temperature=1,
    batch_size=300,
    grad_clip=50,
    memory_size=100000,
    model_load_path="../models_backup/",
    model_version="_v04",
    model_save_path="../models/",
    model_swap_time=10,
    max_session_score=150,
)

# Module-level aliases used by the cells below, grouped by purpose.

# -- session control
SESSIONS_NUM = config["sessions_num"]
MAX_SESSION_SCORE = config["max_session_score"]
MODEL_SWAP_TIME = config["model_swap_time"]

# -- optimisation
LR = config["lr"]
LR_DECAY = config["lr_decay"]
GAMMA = config["gamma"]
GRAD_CLIP = config["grad_clip"]
BATCH_SIZE = config["batch_size"]

# -- exploration
EPS_INIT = config["eps_init"]
EPS_LAST = config["eps_last"]
EPS_MAX_ITERS = config["eps_max_iters"]
TEMP = config["temperature"]

# -- storage
MEMORY_SIZE = config["memory_size"]
MODEL_SAVE_PATH = config["model_save_path"]

In [5]:
# Game environment shared by every training session. The flap acceleration
# starts at -7; the training cell lowers it further as sessions progress.
action_terminal = GameState()
action_terminal.playerFlapAcc = -7

# Two Q estimators with identical architecture. Either resume both from a
# checkpoint or build them fresh -- building first and then overwriting (as
# before) only wasted an initialisation.
_load_dir = config["model_load_path"]
if _load_dir is not None:
    _version = config["model_version"]
    # map_location makes a checkpoint saved on another device (e.g. CUDA) load
    # on the current DEVICE instead of failing or landing on the wrong device.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust. Recent PyTorch releases default to
    # weights_only=True, which rejects whole-module pickles like these; pass
    # weights_only=False (or switch to state_dict checkpoints) if that applies.
    model_qa = torch.load(_load_dir + "model_qa" + _version, map_location=DEVICE)
    model_qb = torch.load(_load_dir + "model_qb" + _version, map_location=DEVICE)
else:
    model_qa = QualityEstimator(
        config["state_dim"],
        config["action_dim"],
        config["hid_dim"]
    ).to(DEVICE)
    model_qb = QualityEstimator(
        config["state_dim"],
        config["action_dim"],
        config["hid_dim"]
    ).to(DEVICE)

# One optimizer and one LR schedule per estimator.
optimizer_qa = optim.Adam(model_qa.parameters(), lr=LR)
optimizer_qb = optim.Adam(model_qb.parameters(), lr=LR)

# Multiplicative decay: lr = LR * LR_DECAY ** k, where k is the number of
# scheduler.step() calls made so far (the training cell steps once per frame).
scheduler_qa = torch.optim.lr_scheduler.LambdaLR(optimizer_qa, lr_lambda=lambda epoch: LR_DECAY ** epoch)
scheduler_qb = torch.optim.lr_scheduler.LambdaLR(optimizer_qb, lr_lambda=lambda epoch: LR_DECAY ** epoch)

In [6]:
# Training loop: plays SESSIONS_NUM game sessions, alternating every
# MODEL_SWAP_TIME sessions which of the two estimators acts and is optimised
# (the other one is passed to optimize_model as its second argument).

CHECKPOINT_EVERY = 1000     # sessions between model checkpoints
FLAP_ACC_STEP_EVERY = 500   # sessions between flap-acceleration changes
MIN_FLAP_ACC = -9           # playerFlapAcc is not lowered below this value

total_rewards = []
memory = ReplayMemory(MEMORY_SIZE)

# Per-estimator frame counters driving the linear epsilon decay.
eps_update_iter_a = 0
eps_update_iter_b = 0

# The tags embed the full config; they do not change during the run, so build
# them once instead of re-formatting the dict every session.
total_reward_tag = f"Total reward {config=}"
max_score_tag = f"Max score {config=}"

for session_idx in tqdm(range(SESSIONS_NUM)):

    # Initialize new session
    is_failed = False
    max_score = 0
    total_reward = 0.0
    state = get_state(action_terminal)

    # The active estimator depends only on the session index, so choose it once
    # per session rather than re-evaluating the condition on every frame.
    train_model_a = session_idx % (2 * MODEL_SWAP_TIME) < MODEL_SWAP_TIME
    if train_model_a:
        active_model, other_model = model_qa, model_qb
        active_optimizer, active_scheduler = optimizer_qa, scheduler_qa
    else:
        active_model, other_model = model_qb, model_qa
        active_optimizer, active_scheduler = optimizer_qb, scheduler_qb

    # Session run
    while not is_failed:
        eps_iter = eps_update_iter_a if train_model_a else eps_update_iter_b
        # Linear decay from EPS_INIT to zero over EPS_MAX_ITERS frames,
        # floored at EPS_LAST.
        epsilon = max((EPS_INIT - EPS_INIT / EPS_MAX_ITERS * eps_iter), EPS_LAST)

        action, predicted_reward = policy(active_model, state, epsilon=epsilon)
        _, step_reward, is_failed = action_terminal.frame_step(action)
        reward = torch.tensor([step_reward], device=DEVICE)
        next_state = get_state(action_terminal)

        memory.push(state, action, next_state, reward)
        state = next_state

        optimize_model(active_model, other_model, memory, active_optimizer, BATCH_SIZE, GAMMA, GRAD_CLIP)
        active_scheduler.step()
        if train_model_a:
            eps_update_iter_a += 1
        else:
            eps_update_iter_b += 1

        # Accumulate the raw environment reward, not the device tensor: the
        # per-session total stays a plain Python number, so `total_rewards`
        # and the TensorBoard scalars do not hold on to (GPU) tensors.
        total_reward += step_reward
        max_score = max(max_score, action_terminal.score)
        if max_score >= MAX_SESSION_SCORE:
            # NOTE(review): this leaves the game mid-run. If GameState only
            # resets `score` after a crash, the next session will already see
            # score >= MAX_SESSION_SCORE on its first frame and stop
            # immediately -- confirm against GameState.
            break

    # Post session updates
    if (session_idx + 1) % CHECKPOINT_EVERY == 0:
        torch.save(model_qa, MODEL_SAVE_PATH+"model_qa")
        torch.save(model_qb, MODEL_SAVE_PATH+"model_qb")

    # Periodically lower playerFlapAcc by one and restart the epsilon schedule
    # for both estimators.
    if (session_idx + 1) % FLAP_ACC_STEP_EVERY == 0 and action_terminal.playerFlapAcc > MIN_FLAP_ACC:
        action_terminal.playerFlapAcc -= 1
        eps_update_iter_a = 0
        eps_update_iter_b = 0
        print(action_terminal.playerFlapAcc)

    total_rewards.append(total_reward)
    writer.add_scalar(total_reward_tag, total_reward, session_idx)
    writer.add_scalar(max_score_tag, max_score, session_idx)
    

  0%|          | 0/10000 [00:00<?, ?it/s]



KeyboardInterrupt: 