In [13]:
import math

import ale_py
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Per the Gymnasium docs this call is a no-op whose purpose is to make the
# `ale_py` import visibly "used" (for IDEs/linters); importing `ale_py` is what
# makes the 'ALE/...' environment ids available to gym.make.
gym.register_envs(ale_py)

class Net(nn.Module):
    """Small fully connected policy network: Linear -> ReLU -> Linear.

    Args:
        obs_size: number of input features per (flattened) observation.
        hidden_size: width of the single hidden layer.
        n_actions: number of output units, one per action.

    The output is raw, un-normalised scores (logits); no softmax is applied
    inside the module.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        # Layers are created in input-to-output order and kept under the
        # attribute name `net`, so state-dict keys stay `net.0.*` / `net.2.*`.
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return action logits for a batch `x` whose last dimension is `obs_size`."""
        logits = self.net(x)
        return logits

# Create the Atari environment that the rest of the notebook interacts with.
env = gym.make('ALE/DonkeyKong-v5')
# Net consumes a flattened observation, so its input width is the total number
# of elements in one observation. Taking the product over *all* axes (instead
# of hard-coding exactly three) gives the same value for the 3-D frames used
# here and stays correct for observation spaces of any other rank.
obs_size = math.prod(env.observation_space.shape)
n_actions = env.action_space.n
print(obs_size, n_actions)

100800 18


In [11]:
# Reset the environment to obtain an initial observation and check its raw
# (unflattened) shape.
obs, info = env.reset()
print(obs.shape)

(210, 160, 3)


In [23]:
# --- Tunable settings -------------------------------------------------------
HIDDEN_SIZE = 128      # width of Net's hidden layer
BATCH_SIZE = 16        # presumably episodes per training batch -- not used in the cells shown; confirm
PERCENTILE = 70        # presumably the reward percentile passed to filter_batch -- confirm
LEARNING_RATE = 0.01   # Adam step size

# --- Model, loss and optimiser ----------------------------------------------
net = Net(obs_size=obs_size, hidden_size=HIDDEN_SIZE, n_actions=n_actions)
# CrossEntropyLoss takes raw logits, which is what Net.forward returns.
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [29]:
from collections import namedtuple
import numpy as np

# A finished episode: its total (undiscounted) reward and the list of its steps.
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
# One step of an episode: the observation that was fed to the network and the
# action that was sampled for it.
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action'])

def iterate_batches(env, net, batch_size):
    """Play episodes with `net` as a stochastic policy and yield them in batches.

    This is an infinite generator. Each `next()` call blocks until
    `batch_size` complete episodes have been played, so with long episodes a
    single batch can take a long time to produce.

    Args:
        env: environment whose `reset()` returns `(obs, info)` and whose
            `step(action)` returns
            `(obs, reward, terminated, truncated, info)`.
        net: callable mapping a `(1, obs_size)` float tensor to `(1, n_actions)`
            logits.
        batch_size: number of finished episodes per yielded batch.

    Yields:
        list[Episode]: `batch_size` episodes. Each `EpisodeStep.observation`
        is stored exactly as the environment returned it (not flattened).
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()[0]

    while True:
        # Flatten the observation into a single row for the fully connected net.
        obs_v = torch.tensor(np.asarray(obs), dtype=torch.float32).reshape(1, -1)
        # Sampling an action is pure inference, so skip autograd bookkeeping.
        # The context is closed again before any `yield`, so the caller's
        # grad mode is never affected while the generator is suspended.
        with torch.no_grad():
            act_probs = torch.softmax(net(obs_v), dim=1).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        is_done = terminated or truncated
        episode_reward += reward
        # Record the observation the action was chosen for, not the resulting one.
        episode_steps.append(EpisodeStep(observation=obs, action=action))

        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            # Start the next episode; its first observation replaces the
            # terminal observation returned by step().
            next_obs = env.reset()[0]
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

# Smoke test: pull one batch of two episodes from the generator. The call
# blocks until both episodes have run to termination/truncation, so it can
# take a long time. A generator is its own iterator, so next() applies directly.
a = iterate_batches(env, net, 2)
next(a)

KeyboardInterrupt: 

In [None]:
def filter_batch(batch, percentile):
    """Keep the steps of the best episodes in `batch` and pack them as tensors.

    Args:
        batch: sequence of episodes; each must unpack as `(reward, steps)` and
            expose a `.reward` attribute, and each step must expose
            `.observation` and `.action`.
        percentile: percentile (0-100) of the episode rewards used as cutoff.

    Returns:
        tuple `(train_obs_v, train_act_v, reward_bound, reward_mean)`:
        - train_obs_v: FloatTensor of the kept observations, stacked in their
          stored shape. No flattening is performed here, so multi-dimensional
          observations must be flattened by the caller before being passed to
          a fully connected network.
        - train_act_v: LongTensor of the matching actions.
        - reward_bound: the `percentile`-th percentile of the episode rewards.
        - reward_mean: mean episode reward over the whole batch, as a float.
    """
    episode_rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(episode_rewards, percentile)
    reward_mean = float(np.mean(episode_rewards))

    kept_obs, kept_act = [], []
    for total_reward, transitions in batch:
        # Episodes strictly below the cutoff are dropped; ties are kept.
        if not total_reward < reward_bound:
            kept_obs += [t.observation for t in transitions]
            kept_act += [t.action for t in transitions]

    train_obs_v = torch.FloatTensor(np.array(kept_obs))
    train_act_v = torch.LongTensor(np.array(kept_act))
    return train_obs_v, train_act_v, reward_bound, reward_mean




In [None]:

# Shut the environment down now that the notebook is done with it.
env.close()
