In [1]:
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tianshou as ts
from pettingzoo.atari import pong_v3
from tianshou.env import DummyVectorEnv, PettingZooEnv

In [3]:
# Gymnasium id of the single-agent Pong env. The next cell builds `env` from it;
# `env` supplies the observation/action spaces and the early-stop reward threshold.
task = "ALE/Pong-v5"

# Number of parallel environments used for training and for evaluation.
TRAIN_ENV_NUM = 8
TEST_ENV_NUM = 100


def make_pong_env():
    """Build one two-player PettingZoo Pong env wrapped for tianshou."""
    return ts.env.PettingZooEnv(pong_v3.env(num_players=2))


# DummyVectorEnv takes env factories (callables), not env instances.
train_envs = ts.env.DummyVectorEnv([make_pong_env for _ in range(TRAIN_ENV_NUM)])
test_envs = ts.env.DummyVectorEnv([make_pong_env for _ in range(TEST_ENV_NUM)])

In [3]:
# Single-agent Gymnasium Pong; later cells read only its spaces and `spec` from it.
# NOTE(review): training runs on PettingZoo pong_v3, not on this env — confirm the
# per-agent observation/action spaces match before relying on them.
env = gym.make(id=task)

In [4]:
class Net(nn.Module):
    """Fully connected Q-network: flattened observation -> one logit per action.

    Args:
        state_shape: observation shape (tuple) or size (int); its product is
            the number of input features.
        action_shape: action shape (tuple) or count (int); its product is the
            number of output logits.
    """

    def __init__(self, state_shape, action_shape):
        super().__init__()
        # np.prod returns a NumPy scalar; cast so nn.Linear stores plain ints.
        in_features = int(np.prod(state_shape))
        out_features = int(np.prod(action_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, out_features),
        )

    def forward(self, obs, state=None, info=None):
        """Return ``(logits, state)`` for a batch of observations.

        Args:
            obs: array-like or tensor whose first dimension is the batch; all
                remaining dimensions are flattened into one feature vector.
            state: passed through unchanged (this network is stateless).
            info: accepted for signature compatibility and ignored.

        Returns:
            Tuple of a ``(batch, n_actions)`` float tensor and ``state``.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        # reshape (not view) also works when a tensor input is non-contiguous.
        logits = self.model(obs.reshape(batch, -1))
        return logits, state

# Network input/output sizes come from `env`: use the space's `.shape`, and fall
# back to `.n` when the shape is empty/None (e.g. Discrete spaces).
action_shape = env.action_space.shape or env.action_space.n
state_shape = env.observation_space.shape or env.observation_space.n
net = Net(state_shape, action_shape)
# NOTE: this rebinds `optim`, shadowing the `torch.optim` module imported in the
# first cell; the policy cell reads this optimizer under that name.
optim = torch.optim.Adam(params=net.parameters(), lr=1e-3)


In [5]:
# DQN over `net`: discount 0.9, 3-step returns, target_update_freq of 320.
# NOTE(review): train_envs wrap a two-player PettingZoo env; a single DQNPolicy
# presumably needs tianshou's MultiAgentPolicyManager to handle per-agent
# rewards — TODO confirm.
policy = ts.policy.DQNPolicy(
    net,
    optim,
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320,
)

In [6]:
# One replay sub-buffer per training env. VectorReplayBuffer splits total_size
# across buffer_num sub-buffers, so a buffer_num larger than the number of envs
# leaves the extra sub-buffers unused and shrinks the usable capacity.
# exploration_noise=True makes the collectors apply the policy's epsilon-greedy
# noise, so the set_eps() calls in train_fn/test_fn actually take effect.
train_collector = ts.data.Collector(
    policy,
    train_envs,
    ts.data.VectorReplayBuffer(total_size=20000, buffer_num=len(train_envs)),
    exploration_noise=True,
)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

In [7]:
# The env spec may not define a reward threshold (None). Comparing a float with
# None raises TypeError, so only stop early when a threshold actually exists.
reward_threshold = env.spec.reward_threshold if env.spec is not None else None


def stop_fn(mean_rewards):
    """Return True once the mean test reward reaches the env's reward threshold."""
    return reward_threshold is not None and mean_rewards >= reward_threshold


result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=100, batch_size=64,
    # epsilon used by the policy while collecting training data / while testing
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=stop_fn)
print(f'Finished training! Use {result["duration"]}')



: 

: 