NES Environment

In [1]:
!pip install gym-super-mario-bros==7.3.0
!pip install tqdm



In [2]:
import pickle
import os
from tqdm import tqdm

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

import gym
from gym.wrappers import FrameStack

import numpy as np
import torch
from torchvision import transforms as T
import matplotlib.pyplot as plt

In [3]:
%matplotlib inline

In [4]:
# Probe the raw environment before any preprocessing wrappers are applied:
# print the native observation shape and the size of the unrestricted NES
# action space (create_mario_env later restricts actions via JoypadSpace).
env_test = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
print(env_test.observation_space.shape)
print(env_test.action_space.n)

(240, 256, 3)
256


In [5]:
class SkipFrame(gym.Wrapper):
    """Repeat each action for ``skip`` frames and return only the last frame.

    Rewards collected on the repeated frames are summed. The repetition
    stops early as soon as the wrapped env reports ``done``. Uses the
    4-tuple ``(obs, reward, done, info)`` step API of the wrapped env.

    Args:
        env: Environment to wrap.
        skip: Number of frames each action is repeated for; must be >= 1.

    Raises:
        ValueError: if ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip):
        super().__init__(env)
        if skip < 1:
            # With zero iterations step() would return `obs` before it is
            # ever assigned (UnboundLocalError); fail early and clearly.
            raise ValueError("skip must be >= 1, got {}".format(skip))
        self._skip = skip

    def step(self, action):
        """Repeat ``action``, summing rewards; return the last transition."""
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info

In [6]:
class ResizeObservation(gym.ObservationWrapper):
    """Resize frames to ``shape`` and divide pixel values by 255.

    Observations must be tensors that ``T.Resize`` accepts. As produced by
    ``GrayScaleObservation`` in this file they carry a leading channel dim of
    size 1, which is squeezed away, so each output is a float ``[h, w]``
    tensor.

    Args:
        env: Environment to wrap.
        shape: Target size; an int means a square ``(shape, shape)``.
    """

    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)
        # The transform is stateless, so build it once instead of per frame.
        self._transform = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )
        obs_shape = self.shape + self.observation_space.shape[2:]
        # Normalize(0, 255) maps the [0, 255] input range to [0, 1] floats,
        # so declare that space rather than uint8 [0, 255].
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=obs_shape,
                                                dtype=np.float32)

    def observation(self, observation):
        """Resize and rescale one frame, dropping the leading size-1 dim."""
        return self._transform(observation).squeeze(0)

In [7]:
class GrayScaleObservation(gym.ObservationWrapper):
    """Turn ``[H, W, C]`` RGB frames into single-channel float tensors.

    Each observation is returned as a ``[1, H, W]`` float tensor. The declared
    ``observation_space`` is the 2-D ``(H, W)`` shape; ``ResizeObservation``
    builds its own space from that.
    """

    def __init__(self, env):
        super().__init__(env)
        height_width = self.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=height_width, dtype=np.uint8
        )

    def permute_orientation(self, observation):
        """Return an ``[H, W, C]`` array as a float ``[C, H, W]`` tensor."""
        channels_first = np.transpose(observation, (2, 0, 1)).copy()
        return torch.tensor(channels_first, dtype=torch.float)

    def observation(self, observation):
        """Convert one frame to grayscale, keeping a size-1 channel dim."""
        to_gray = T.Grayscale()
        return to_gray(self.permute_orientation(observation))

In [8]:
def create_mario_env(env_name, skip=4, shape=84, num_stack=4, movement=None):
    """Build a preprocessed Super Mario Bros environment.

    Pipeline: frame skipping -> grayscale -> resize/rescale -> frame
    stacking -> restricted joypad action set. The defaults reproduce the
    original fixed configuration (skip 4, 84x84 frames, 4 stacked frames,
    ``SIMPLE_MOVEMENT``).

    Args:
        env_name: gym_super_mario_bros environment id,
            e.g. ``'SuperMarioBros-1-1-v0'``.
        skip: Frames each action is repeated for.
        shape: Resize target (int for square frames, or ``(h, w)``).
        num_stack: Number of consecutive frames stacked per observation.
        movement: Button-combination list for ``JoypadSpace``; ``None`` means
            ``SIMPLE_MOVEMENT``.

    Returns:
        The wrapped environment. Downstream networks must accept the
        resulting ``(num_stack, h, w)`` observations.
    """
    if movement is None:
        movement = SIMPLE_MOVEMENT
    env = gym_super_mario_bros.make(env_name)
    env = SkipFrame(env, skip=skip)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=shape)
    env = FrameStack(env, num_stack=num_stack)
    return JoypadSpace(env, movement)

In [None]:
import numpy as np

class ReplayBuffer:
    """Fixed-capacity circular buffer of ``(s, a, r, s', terminal)`` transitions.

    Storage is preallocated with ``torch.empty`` (``max_size`` rows per
    field); once full, the oldest row is overwritten.

    Args:
        state_shape: Shape of one state.
        action_space: Unused; kept for call-site compatibility.
        batch_size: Number of transitions :meth:`sample` returns (fewer if
            the buffer holds fewer).
        max_size: Capacity in transitions.
        load: Restore the buffer from ``path`` instead of allocating it.
        path: Checkpoint root; files go to ``<path>buffer/``. May be ``None``
            when the buffer is never saved or loaded.
    """

    def __init__(self, state_shape, action_space, batch_size=32, max_size=10000,
                 load=False, path=None):
        # Previously `None + 'buffer/'` crashed even for in-memory use.
        self.path = None if path is None else path + 'buffer/'
        self.max_size = max_size
        self.batch_size = batch_size

        if load:
            self.load()
        else:
            self.next = 0   # row the next transition is written to
            self.size = 0   # number of valid rows

            self.states = torch.empty((max_size, *state_shape))
            self.actions = torch.empty((max_size, 1), dtype=torch.int64)
            self.rewards = torch.empty((max_size, 1))
            self.states_p = torch.empty((max_size, *state_shape))
            self.is_terminals = torch.empty((max_size, 1), dtype=torch.float)

    def __len__(self):
        return self.size

    def _require_path(self):
        """Raise if this buffer was created without a checkpoint path."""
        if self.path is None:
            raise ValueError("ReplayBuffer was created without a `path`; "
                             "cannot save or load")

    def store(self, state, action, reward, state_p, is_terminal):
        """Write one transition, overwriting the oldest when full.

        ``state``/``state_p`` may be any object exposing ``__array__``
        (e.g. numpy arrays or gym ``LazyFrames``).
        """
        state = state.__array__()
        state_p = state_p.__array__()

        self.states[self.next] = torch.tensor(state)
        self.actions[self.next] = action
        self.rewards[self.next] = reward
        self.states_p[self.next] = torch.tensor(state_p)
        self.is_terminals[self.next] = is_terminal

        self.size = min(self.size + 1, self.max_size)
        self.next = (self.next + 1) % self.max_size

    def sample(self):
        """Return a batch of distinct stored transitions sampled uniformly.

        Returns ``min(batch_size, len(self))`` rows instead of raising when the
        buffer holds fewer than ``batch_size`` transitions.
        """
        n = min(self.batch_size, self.size)
        indices = np.random.choice(self.size, size=n, replace=False)
        return (self.states[indices],
                self.actions[indices],
                self.rewards[indices],
                self.states_p[indices],
                self.is_terminals[indices])

    def clear(self):
        """Forget all transitions (storage is reallocated, same shapes)."""
        self.next = 0
        self.size = 0
        self.states = torch.empty_like(self.states)
        self.actions = torch.empty_like(self.actions)
        self.rewards = torch.empty_like(self.rewards)
        self.states_p = torch.empty_like(self.states_p)
        self.is_terminals = torch.empty_like(self.is_terminals)

    def load(self):
        """Restore counters and tensors from ``<path>buffer/``.

        NOTE: the pickle files are unpickled without validation; only load
        checkpoints you created yourself.
        """
        self._require_path()
        with open(self.path + "next.pkl", 'rb') as f:
            self.next = pickle.load(f)
        with open(self.path + "size.pkl", 'rb') as f:
            self.size = pickle.load(f)
        self.states = torch.load(self.path + "states.pt")
        self.actions = torch.load(self.path + "actions.pt")
        self.rewards = torch.load(self.path + "rewards.pt")
        self.states_p = torch.load(self.path + "states_p.pt")
        self.is_terminals = torch.load(self.path + "is_terminals.pt")
        # Keep the wrap-around index consistent with the loaded storage even
        # if a different max_size was passed to the constructor.
        self.max_size = self.states.shape[0]

    def save(self):
        """Write counters and tensors to ``<path>buffer/``."""
        self._require_path()
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        with open(self.path + "next.pkl", "wb") as f:
            pickle.dump(self.next, f)
        with open(self.path + "size.pkl", "wb") as f:
            pickle.dump(self.size, f)
        torch.save(self.states, self.path + "states.pt")
        torch.save(self.actions, self.path + "actions.pt")
        torch.save(self.rewards, self.path + "rewards.pt")
        torch.save(self.states_p, self.path + "states_p.pt")
        torch.save(self.is_terminals, self.path + "is_terminals.pt")

In [None]:
class QNetwork(torch.nn.Module):
    """Convolutional Q-network mapping a stacked-frame state to per-action Q-values.

    Layers: conv 32@8x8/4 -> conv 64@4x4/2 (``personalized``), then
    conv 64@3x3/1 -> flatten -> FC 512 -> FC ``actions_size`` (``shared``).
    Module names and indices are fixed so existing checkpoints still load.

    Args:
        input_shape: ``(channels, height, width)`` of one state. The
            flattened feature size is inferred from it (3136 for 84x84).
        actions_size: Number of discrete actions (output width).
        optimizer: Optimizer class, instantiated on this network's params.
        learning_rate: Optimizer learning rate.
    """

    def __init__(self, input_shape, actions_size,
                 optimizer=torch.optim.Adam, learning_rate=0.00025):
        super().__init__()
        self.personalized = torch.nn.Sequential(
            torch.nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            torch.nn.ReLU(),
        )
        last_conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # Infer the flattened size with a dummy pass instead of hard-coding
        # 3136, which only matches 84x84 inputs.
        with torch.no_grad():
            probe = torch.zeros(1, *input_shape)
            n_features = last_conv(self.personalized(probe)).numel()
        self.shared = torch.nn.Sequential(
            last_conv,
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(n_features, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, actions_size),
        )
        self.optimizer = optimizer(self.parameters(), lr=learning_rate)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def format_(self, states):
        """Return ``states`` as a tensor.

        Non-tensors become float32 tensors on this network's device (so numpy
        input also works when the model lives on a GPU). Tensors are passed
        through unchanged.
        """
        if not isinstance(states, torch.Tensor):
            device = next(self.parameters()).device
            states = torch.tensor(states, dtype=torch.float32, device=device)
        return states

    def forward(self, x):
        """Return Q-values of shape ``(batch, actions_size)``."""
        states = self.format_(x)
        out = self.personalized(states)
        out = self.shared(out)
        return out

    def update_netowrk(self, td_estimate, td_target):
        """Take one optimizer step on the Smooth-L1 TD loss.

        Returns:
            The scalar loss value as a Python float (useful for logging).
        """
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    update_network = update_netowrk

In [None]:
class Agent():
    """Double-DQN learner that trains on a single environment (one Mario level).

    Each agent owns its environment, replay buffer, online and target
    networks, and persists all of them under ``<path><id>/``.

    Args:
        id: Identifier used in log lines and in the checkpoint sub-directory.
        env_name: Name passed to ``env_fn`` to build this agent's environment.
        env_fn: Factory ``env_fn(env_name) -> env``. The env must expose
            ``action_space.n`` and ``observation_space.shape`` and use the
            4-tuple ``step`` API.
        Qnet: Q-network class, constructed as ``Qnet(state_shape, n_actions)``.
        buffer: Replay-buffer class, constructed as
            ``buffer(state_shape, n_actions, load=..., path=...)``.
        max_epsilon, min_epsilon, epsilon_decay: Exploration schedule.
            Epsilon starts at ``max_epsilon``, is multiplied by
            ``epsilon_decay`` after every gradient update, and is clipped at
            ``min_epsilon``.
        gamma: Discount factor.
        target_update_rate: Copy online -> target whenever ``step_count`` is a
            multiple of this (checked only once learning has started).
        min_buffer: Number of stored transitions required before learning.
        load: Restore counters, rewards, epsilon, networks (and the buffer)
            from ``path`` instead of starting fresh.
        path: Checkpoint root; must be a string ending in a separator
            (``None`` is not supported).
    """

    def __init__(self, id, env_name, env_fn, Qnet=QNetwork, buffer=ReplayBuffer,
                 max_epsilon=1, min_epsilon=0.05, epsilon_decay=0.99, gamma=0.9,
                 target_update_rate=2000, min_buffer=100,
                 load=False, path=None) -> None:
        self.id = id
        self.path = path + str(id) + "/"

        self.env = env_fn(env_name)
        self.env_fn = env_fn
        self.n_actions = self.env.action_space.n
        self.state_shape = self.env.observation_space.shape
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.min_buffer = min_buffer
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.target_update_rate = target_update_rate
        self.buffer = buffer(self.state_shape, self.n_actions,
                             load=load, path=self.path)

        self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device)
        self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device)

        if load:
            self.load()
        else:
            self.update_target_network()
            self.epsilon = max_epsilon
            self.step_count = 0
            self.episode_count = 0
            self.rewards = []

    def load(self):
        """Restore counters, episode rewards, epsilon and both networks.

        NOTE: the pickle files are unpickled without validation; only load
        checkpoints you created yourself.
        """
        with open(self.path + "step_count.pkl", 'rb') as f:
            self.step_count = pickle.load(f)
        with open(self.path + "episode_count.pkl", 'rb') as f:
            self.episode_count = pickle.load(f)
        with open(self.path + "rewards.pkl", 'rb') as f:
            self.rewards = pickle.load(f)
        with open(self.path + "epsilon.pkl", 'rb') as f:
            self.epsilon = pickle.load(f)
        self.online_net.load_state_dict(torch.load(self.path + "online_net.pt",
                                                   map_location=torch.device(self.device)))
        self.target_net.load_state_dict(torch.load(self.path + "target_net.pt",
                                                   map_location=torch.device(self.device)))

    def save(self):
        """Persist the buffer, counters, rewards, epsilon and both networks."""
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        self.buffer.save()
        with open(self.path + "step_count.pkl", "wb") as f:
            pickle.dump(self.step_count, f)
        with open(self.path + "episode_count.pkl", "wb") as f:
            pickle.dump(self.episode_count, f)
        with open(self.path + "rewards.pkl", "wb") as f:
            pickle.dump(self.rewards, f)
        with open(self.path + "epsilon.pkl", "wb") as f:
            pickle.dump(self.epsilon, f)
        torch.save(self.online_net.state_dict(), self.path + "online_net.pt")
        torch.save(self.target_net.state_dict(), self.path + "target_net.pt")

    def train(self, n_episodes):
        """Play ``n_episodes`` episodes, storing every transition.

        A gradient update happens on every step once the buffer holds
        ``min_buffer`` transitions. A summary line is printed at the end.
        """
        for _ in tqdm(range(n_episodes)):
            episode_reward = 0
            state = self.env.reset()

            while True:
                self.step_count += 1
                action = self.epsilonGreedyPolicy(state)
                state_p, reward, done, info = self.env.step(action)
                episode_reward += reward

                # A time-limit truncation is not a real terminal state: store
                # it as non-terminal so the TD target still bootstraps.
                is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']
                is_failure = done and not is_truncated
                self.buffer.store(state, action, reward, state_p, float(is_failure))

                if len(self.buffer) >= self.min_buffer:
                    self.update()
                    if self.step_count % self.target_update_rate == 0:
                        self.update_target_network()

                state = state_p
                if done:
                    self.episode_count += 1
                    self.rewards.append(episode_reward)
                    break

        # Guard: with no completed episode ever, rewards[-1] would IndexError.
        if self.rewards:
            print("Agent-{} Episode {} Step {} score = {}, average score = {}"\
                    .format(self.id, self.episode_count, self.step_count, self.rewards[-1], np.mean(self.rewards)))

    def get_score(self):
        """Return this agent's weight for federated averaging.

        Currently a constant 1, so all agents are weighted equally. A
        reward-based weight (mean of the last 5 episode rewards) was
        considered but is disabled.
        """
        return 1

    def update(self):
        """Run one Double-DQN gradient step on a sampled batch, then decay epsilon."""
        states, actions, rewards, states_p, is_terminals = (
            t.to(self.device) for t in self.buffer.sample())

        td_estimate = self.online_net(states).gather(1, actions)

        with torch.no_grad():
            # Double DQN: the online network *selects* the greedy action in
            # the NEXT state and the target network *evaluates* it. The
            # previous code selected it from the current `states`, which
            # is wrong.
            actions_p = self.online_net(states_p).argmax(dim=1, keepdim=True)
            q_state_p_action_p = self.target_net(states_p).gather(1, actions_p)
        td_target = rewards + (1 - is_terminals) * self.gamma * q_state_p_action_p

        self.online_net.update_netowrk(td_estimate, td_target)
        self.update_epsilon()

    def update_epsilon(self):
        """Decay epsilon multiplicatively, never below ``min_epsilon``."""
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon, self.min_epsilon)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def epsilonGreedyPolicy(self, state):
        """With probability epsilon pick a random action, else the online-greedy one."""
        if np.random.rand() < self.epsilon:
            action = np.random.randint(self.n_actions)
        else:
            state = state.__array__()
            state = torch.tensor(state).unsqueeze(0).to(self.device)
            with torch.no_grad():
                action = self.online_net(state).argmax().item()
        return action

In [None]:
class Mario(Agent):
    """Global (aggregated) agent: holds networks and evaluates them greedily.

    It deliberately does not run ``Agent.__init__``, so it has no replay
    buffer, epsilon or counters. It owns one environment per name in
    ``env_names`` plus an online and a target network, persisted under
    ``<path>global/``.
    """

    def __init__(self, env_names, env_fn, Qnet=QNetwork, load=False, path=None) -> None:
        self.path = path + "global/"
        self.envs = [env_fn(env_name) for env_name in env_names]
        reference_env = self.envs[0]
        self.n_actions = reference_env.action_space.n
        self.state_shape = reference_env.observation_space.shape
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device)
        self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device)

        if not load:
            self.update_target_network()
        else:
            self.load()

    def _checkpoints(self):
        """Pair each network with its checkpoint file name (online first)."""
        return ((self.online_net, "online_net.pt"),
                (self.target_net, "target_net.pt"))

    def load(self):
        """Restore both networks from ``<path>global/``."""
        location = torch.device(self.device)
        for net, filename in self._checkpoints():
            net.load_state_dict(torch.load(self.path + filename,
                                           map_location=location))

    def save(self):
        """Write both networks' state dicts to ``<path>global/``."""
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        for net, filename in self._checkpoints():
            torch.save(net.state_dict(), self.path + filename)

    def get_score(self):
        """Constant weight of 1 (placeholder, mirrors ``Agent.get_score``)."""
        return 1

    def test(self):
        """Play one greedy episode per environment.

        Returns:
            A float64 array with the total reward for each env, in order.
        """
        episode_totals = [self.evaluate(idx) for idx in range(len(self.envs))]
        return np.array(episode_totals, dtype=np.float64)

    def evaluate(self, i):
        """Play one episode in ``self.envs[i]`` and return its summed reward."""
        env = self.envs[i]
        total = 0
        state = env.reset()
        done = False
        while not done:
            chosen = self.greedyPolicy(state)
            state, reward, done, _ = env.step(chosen)
            total += reward
        return total

    def greedyPolicy(self, state):
        """Return the argmax-Q action under the *target* network."""
        batch = torch.tensor(state.__array__()).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.target_net(batch)
        return q_values.argmax().item()

In [None]:
class Federator:
    """Coordinates one local ``Agent`` per Mario level plus a global ``Mario``.

    Each training run lets every local agent train for ``update_rate``
    episodes. The local networks are then combined into the global networks
    by score-weighted parameter averaging and broadcast back to every agent.
    Finally the global agent is evaluated on every level.

    Args:
        env_fn: Environment factory passed to every agent.
        update_rate: Episodes each local agent trains per run.
        path: Checkpoint root directory (must end with a separator).
        load: Restore everything from ``path`` instead of starting fresh.
    """

    def __init__(self, env_fn, update_rate, path="./Mario/", load=False) -> None:
        self.path = path
        self.envs = [
                'SuperMarioBros-1-1-v0',
                'SuperMarioBros-1-2-v0',
                'SuperMarioBros-1-3-v0',
                'SuperMarioBros-1-4-v0'
        ]
        self.global_agent = Mario(self.envs, env_fn, load=load, path=self.path)

        self.update_rate = update_rate
        # One local agent per level (previously a hard-coded 4).
        self.n_agents = len(self.envs)
        self.agents = [Agent(i, env_name, env_fn, load=load, path=self.path)
                       for i, env_name in enumerate(self.envs)]

        if load:
            self.load()
        else:
            self.set_local_networks()
            # One entry per training run: the per-level evaluation rewards.
            self.rewards = []

    def load(self):
        """Restore the evaluation-reward history (pickle; trusted files only)."""
        with open(self.path + "rewards.pkl", 'rb') as f:
            self.rewards = pickle.load(f)

    def save(self):
        """Persist reward history, the global agent and every local agent."""
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        with open(self.path + "rewards.pkl", "wb") as f:
            pickle.dump(self.rewards, f)
        self.global_agent.save()
        for agent in self.agents:
            agent.save()
        print("All Saved to " + self.path)

    def train(self, n_runs):
        """Run ``n_runs`` rounds of local training, aggregation and evaluation.

        Returns:
            ``(n_runs, n_levels)`` array of the global agent's evaluation
            rewards. Each row is also appended to ``self.rewards`` so it is
            persisted by :meth:`save`.
        """
        rewards = np.zeros((n_runs, len(self.envs)))
        for i in range(n_runs):
            print("Iteration: {}".format(i+1))
            scores = []
            for agent in self.agents:
                agent.train(self.update_rate)
                scores.append(agent.get_score())
            self.aggregate_networks(scores)
            self.set_local_networks()
            rewards[i] = self.global_agent.test()
            # Previously these results were only printed and then discarded,
            # so save() always wrote an empty/stale history.
            self.rewards.append(rewards[i].copy())
            print(rewards[i])
        self.save()
        return rewards

    @staticmethod
    def _weighted_average(template, state_dicts, scores, total):
        """Overwrite each entry of ``template`` with the score-weighted mean.

        For every key: ``sum_i scores[i] * state_dicts[i][key] / total``.
        Returns ``template``.
        """
        for key in template:
            acc = torch.zeros_like(template[key])
            for weight, sd in zip(scores, state_dicts):
                acc += weight * sd[key]
            template[key] = acc / total
        return template

    def aggregate_networks(self, scores):
        """Set the global nets to the score-weighted average of the local nets.

        Args:
            scores: One non-negative weight per local agent, in agent order.

        Raises:
            ValueError: if the number of scores does not match the number of
                agents, or the scores sum to zero (division by zero).
        """
        if len(scores) != len(self.agents):
            raise ValueError("expected {} scores, got {}".format(
                len(self.agents), len(scores)))
        total = sum(scores)
        if total == 0:
            raise ValueError("scores must not sum to zero")

        sd_online = self._weighted_average(
            self.global_agent.online_net.state_dict(),
            [agent.online_net.state_dict() for agent in self.agents],
            scores, total)
        sd_target = self._weighted_average(
            self.global_agent.target_net.state_dict(),
            [agent.target_net.state_dict() for agent in self.agents],
            scores, total)

        self.global_agent.online_net.load_state_dict(sd_online)
        self.global_agent.target_net.load_state_dict(sd_target)

    def set_local_networks(self):
        """Copy the global online/target weights into every local agent."""
        for agent in self.agents:
            agent.online_net.load_state_dict(
                self.global_agent.online_net.state_dict())
            agent.target_net.load_state_dict(
                self.global_agent.target_net.state_dict())

In [None]:
# Resume from ./Mario/ when a checkpoint exists there; otherwise start fresh.
# Unconditional load=True fails with FileNotFoundError on a new directory.
# Federator.save() writes rewards.pkl, so its presence marks a checkpoint.
resume = os.path.exists("./Mario/rewards.pkl")
if not resume:
    print("No checkpoint found in ./Mario/, starting from scratch")
agent = Federator(create_mario_env, 200, load=resume)
agent.train(5)

In [None]:
! cp -r ./Mario/ /content/drive/Shareddrives/Sam/