# Train an Agent on a Gymnasium Environment

In [None]:
import numpy as np
import torch
from src.agent.reinforce import ReinforceAgent
from src.trainer.trainer import Trainer
from src.environment.gymnasium import GymnasiumEnvironment, VisualGymnasiumEnvironment

%matplotlib inline
import os

%load_ext autoreload
%autoreload 2

In [None]:
# Build the Breakout environment (RAM observations) and a REINFORCE agent,
# then train it. Training may be stopped early with a kernel interrupt
# (KeyboardInterrupt); in that case the environment is closed and the
# interrupt is swallowed so that later cells (e.g. saving the agent) still run.
env = GymnasiumEnvironment(name="ALE/Breakout-v5", obs_type="ram")
device = "cuda" if torch.cuda.is_available() else "cpu"
agent = ReinforceAgent(
    env=env,
    device=device,
    lr=5e-4,
    hidden_dims=[64],
    # NOTE(review): presumably a per-episode step cap — confirm in ReinforceAgent.
    max_steps=500,
)
trainer = Trainer(env=env, agent=agent, device=device)
try:
    # NOTE(review): eps_* presumably define an exploration schedule decaying
    # from eps_start toward eps_end by eps_decay — confirm against Trainer.train.
    trainer.train(
        n_epochs=5000,
        eps_start=1.0,
        eps_end=0.05,
        eps_decay=0.995,
    )
except KeyboardInterrupt:
    # The exception object is not needed; an interrupt just ends training early.
    env.close()
    print("Environment closed")

In [None]:
# Release the environment once training is finished.
# NOTE(review): if training was interrupted, the training cell has already
# called env.close(); this second call assumes GymnasiumEnvironment.close()
# tolerates being called twice — confirm.
env.close()

In [None]:
# Persist the trained agent. The destination (path/format) is whatever
# ReinforceAgent.save() uses by default — not visible from this notebook.
agent.save()