# In [None]:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.monitor import Monitor
import numpy as np
import torch
from environment3 import LifeStyleEnv
import gymnasium as gym
from stable_baselines3 import DQN

# Prefer the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

def mask_fn(env: gym.Env) -> np.ndarray:
    """Return the action mask reported by the innermost environment.

    Starting from ``env``, repeatedly follows the ``env`` attribute until an
    object without that attribute is reached, then returns the result of
    calling ``action_masks()`` on that object.
    """
    innermost = env
    while True:
        try:
            innermost = innermost.env
        except AttributeError:
            # No further ``env`` attribute: this is the bottom of the chain.
            break
    return innermost.action_masks()

def make_env(is_eval: bool = False):
    """Build a ``LifeStyleEnv`` wrapped in a ``Monitor``.

    Args:
        is_eval: When False (the default), the freshly constructed
            environment is validated with ``check_env`` before being wrapped.
            When True the validation is skipped.

    Returns:
        The ``Monitor``-wrapped environment.
    """
    raw_env = LifeStyleEnv()
    if not is_eval:
        # Validate the bare environment, not the Monitor wrapper: check_env
        # resets and steps the env, and a default Monitor
        # (allow_early_resets=False) raises RuntimeError on a reset() that
        # arrives before the episode has finished -- either inside the check
        # itself or on the first reset() issued by training afterwards.
        check_env(raw_env, warn=True)
    return Monitor(raw_env)

# Training environment (goes through check_env inside make_env).
env = make_env(is_eval=False)

# Separate environment instance used only for periodic evaluation.
eval_env = make_env(is_eval=True)

# Evaluation schedule and output locations handed to EvalCallback.
eval_settings = {
    "best_model_save_path": "./logs/dqn/dqn_best_model",
    "log_path": "./logs/dqn/dqn_results",
    "eval_freq": 5000,
    "n_eval_episodes": 10,
    "deterministic": True,
    "render": False,
}
eval_callback = EvalCallback(eval_env, **eval_settings)

# Hyperparameters forwarded unchanged to the DQN constructor.
dqn_settings = {
    "learning_rate": 0.0005,
    "batch_size": 128,
    "gamma": 0.95,
    "train_freq": 4,
    "gradient_steps": 1,
    "target_update_interval": 1000,
    "exploration_fraction": 0.3,
    "verbose": 1,
    "device": device,
    "policy_kwargs": {"net_arch": [256, 256]},
    "tensorboard_log": "./logs/dqn/dqn_tensorboard/",
}

# "MultiInputPolicy" is the policy identifier passed to DQN for this env.
model = DQN("MultiInputPolicy", env, **dqn_settings)

# Train for 100k environment steps; eval_callback is invoked during training.
model.learn(total_timesteps=100_000, callback=eval_callback)
