In [1]:
from dataloader import AtariDataset
import gym
import torch.nn as nn
import torch
import numpy as np
import random
import tqdm
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim import optimizer
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
import cv2

## SEEDING

In [2]:
def reseed(seed):
    """Seed PyTorch, Python's `random` module and NumPy's global RNG with one value.

    Args:
        seed (int): Value handed unchanged to each library's seeding function.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)


# Global seed shared by the whole notebook (also passed to the environment below).
seed = 42
reseed(seed)


## LOAD DATA

In [3]:
# Build the offline dataset from the "atari_v1" recordings. The meaning of the
# second argument (2) is defined by dataloader.AtariDataset — presumably a number
# of demonstrations, given the two lengths printed below; TODO confirm.
dataloader = AtariDataset("atari_v1", 2)
(observations, actions, rewards,
 next_observations, dones) = dataloader.compile_data()

2
[1960, 1870]


## MAKE ENVIRONMENT

In [4]:
def make_env(env_id, seed=25, repeat_action_probability=0.15, frameskip=1):
    """Create a grayscale Atari environment and seed it together with its spaces.

    Args:
        env_id (str): Gym environment id, e.g. "SpaceInvaders-v0".
        seed (int): Seed applied to the environment, its action space and its
            observation space.
        repeat_action_probability (float): Sticky-action probability forwarded to
            ``gym.make`` (default 0.15, the value previously hard-coded).
        frameskip (int): Frameskip forwarded to ``gym.make`` (default 1, the value
            previously hard-coded).

    Returns:
        The created environment, configured with ``obs_type='grayscale'`` and
        ``render_mode='rgb_array'``.
    """
    env = gym.make(
        env_id,
        obs_type='grayscale',
        render_mode='rgb_array',
        repeat_action_probability=repeat_action_probability,
        frameskip=frameskip,
    )
    # Old-style seeding API, consistent with the 4-tuple env.step() used in this notebook.
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env
# Create the environment used by every experiment below, seeded with the
# notebook-wide seed, and report its action count and observation shape.
env = make_env("SpaceInvaders-v0", seed=seed)
print(env.action_space.n)
print(env.observation_space.shape)


# Prefer the first GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)

6
(210, 160)
Device:  cuda:0


In [5]:
def visualize(learner, env, video_name="test"):
    """Roll out ``learner`` for a single episode of ``env`` and save it as an .avi video.

    Args:
        learner: Policy exposing ``get_action(obs_batch)``. It is called with a float32
            tensor of shape ``(1, *obs.shape)`` on the global ``device``, and its return
            value is passed directly to ``env.step``.
        env: Environment using the 4-tuple ``step`` API whose
            ``render(mode='rgb_array')`` returns RGB frames of width 160 and height 210.
        video_name (str): Output file name. ".avi" is appended unless it is already
            present, so "bc_learner" and "bc_learner.avi" both write "bc_learner.avi".

    Side effects:
        Writes the video file, prints progress and the episode's total reward, and
        calls ``env.close()`` when finished.
    """
    print("Visualizing")

    # Accept names with or without the extension so callers can't end up with "x.avi.avi".
    video_path = video_name if video_name.endswith(".avi") else f"{video_name}.avi"

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    # VideoWriter takes frameSize as (width, height): 160 columns x 210 rows.
    video = cv2.VideoWriter(video_path, fourcc, 24, (160, 210), isColor=True)
    obs = env.reset()
    done = False
    total_reward = 0
    try:
        while not done:
            # Build the (1, H, W) float32 batch straight from the array; torch.Tensor([ndarray])
            # takes PyTorch's slow list-of-ndarrays path (and warns about it).
            obs_batch = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                action = learner.get_action(obs_batch)
            obs, reward, done, info = env.step(action)

            total_reward += reward

            if done:
                break

            frame = env.render(mode='rgb_array')
            # The environment renders RGB but OpenCV writes BGR; without this conversion
            # the red and blue channels are swapped in the saved video.
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Finalize the file (and close the env) even if the rollout raises.
        video.release()
        env.close()

    print(f"Video saved as {video_path}")
    print("Reward: " + str(total_reward))

## TRAIN DQN (TEST)

In [6]:
import dqn
from dqn import DQN

# Input size is the pixel count of one 210 x 160 grayscale frame
# (presumably flattened inside DQN — TODO confirm against dqn.py).
INPUT_SHAPE = 210 * 160
ACTION_SIZE = env.action_space.n

dqn_learner = DQN(INPUT_SHAPE, ACTION_SIZE)

# Train on the offline transitions loaded above and save the weights.
dqn.train(
    dqn_learner,
    env,
    observations=observations,
    actions=actions,
    rewards=rewards,
    next_observations=next_observations,
    dones=dones,
    save_path='models/dqn_test.pth',
)

## Train Behavioral Cloning (BC)

In [7]:
import bc
from bc import SpaceInvLearner

learner = SpaceInvLearner(env)

# Behavioral cloning: fit the learner to the recorded (observation, action) pairs
# for 100 epochs, writing the weights to the checkpoint path.
bc.train(
    learner=learner,
    observations=observations,
    actions=actions,
    checkpoint_path="models/bc_learner.pth",
    num_epochs=100,
)

Training the learner
Training for 100 epochs


  1%|          | 1/100 [00:00<01:25,  1.16it/s]

Epoch 0, Loss: 0.35799575961034036


  2%|▏         | 2/100 [00:01<01:11,  1.36it/s]

Epoch 1, Loss: 0.14913281434191192


  3%|▎         | 3/100 [00:02<01:06,  1.47it/s]

Epoch 2, Loss: 0.13065702621260075


  4%|▍         | 4/100 [00:02<01:03,  1.52it/s]

Epoch 3, Loss: 0.12364452182898229


  5%|▌         | 5/100 [00:03<01:00,  1.57it/s]

Epoch 4, Loss: 0.11846086039708538


  6%|▌         | 6/100 [00:03<00:57,  1.63it/s]

Epoch 5, Loss: 0.1151481900499993


  7%|▋         | 7/100 [00:04<01:00,  1.53it/s]

Epoch 6, Loss: 0.11232827894086833


  8%|▊         | 8/100 [00:05<01:00,  1.53it/s]

Epoch 7, Loss: 0.11060364753680935


  9%|▉         | 9/100 [00:05<00:57,  1.58it/s]

Epoch 8, Loss: 0.10857107292501059


 10%|█         | 10/100 [00:06<00:55,  1.61it/s]

Epoch 9, Loss: 0.10744364134400648


 11%|█         | 11/100 [00:07<00:54,  1.63it/s]

Epoch 10, Loss: 0.10564369472251546


 12%|█▏        | 12/100 [00:07<00:53,  1.65it/s]

Epoch 11, Loss: 0.10416183612722962


 13%|█▎        | 13/100 [00:08<00:52,  1.67it/s]

Epoch 12, Loss: 0.10359780067965349


 14%|█▍        | 14/100 [00:08<00:51,  1.67it/s]

Epoch 13, Loss: 0.10200488259329735


 15%|█▌        | 15/100 [00:09<00:50,  1.68it/s]

Epoch 14, Loss: 0.10070980265673575


 16%|█▌        | 16/100 [00:10<00:50,  1.67it/s]

Epoch 15, Loss: 0.09960412943064618


 17%|█▋        | 17/100 [00:10<00:49,  1.68it/s]

Epoch 16, Loss: 0.09837009319869261


 18%|█▊        | 18/100 [00:11<00:48,  1.68it/s]

Epoch 17, Loss: 0.09669802224363894


 19%|█▉        | 19/100 [00:11<00:48,  1.67it/s]

Epoch 18, Loss: 0.09353708332842846


 20%|██        | 20/100 [00:12<00:48,  1.66it/s]

Epoch 19, Loss: 0.0838081727667643


 21%|██        | 21/100 [00:13<00:47,  1.65it/s]

Epoch 20, Loss: 0.0597658437352296


 22%|██▏       | 22/100 [00:13<00:46,  1.67it/s]

Epoch 21, Loss: 0.049506035034323735


 23%|██▎       | 23/100 [00:14<00:47,  1.61it/s]

Epoch 22, Loss: 0.04454924708920433


 24%|██▍       | 24/100 [00:14<00:46,  1.63it/s]

Epoch 23, Loss: 0.04170043615266938


 25%|██▌       | 25/100 [00:15<00:46,  1.63it/s]

Epoch 24, Loss: 0.03933251316639278


 26%|██▌       | 26/100 [00:16<00:45,  1.64it/s]

Epoch 25, Loss: 0.03740198344410293


 27%|██▋       | 27/100 [00:16<00:44,  1.65it/s]

Epoch 26, Loss: 0.03652795396709255


 28%|██▊       | 28/100 [00:17<00:43,  1.64it/s]

Epoch 27, Loss: 0.035168285302019875


 29%|██▉       | 29/100 [00:17<00:42,  1.65it/s]

Epoch 28, Loss: 0.03395282345774049


 30%|███       | 30/100 [00:18<00:42,  1.66it/s]

Epoch 29, Loss: 0.03234519446980732


 31%|███       | 31/100 [00:19<00:41,  1.66it/s]

Epoch 30, Loss: 0.03183009851019175


 32%|███▏      | 32/100 [00:19<00:40,  1.67it/s]

Epoch 31, Loss: 0.03105476875765605


 33%|███▎      | 33/100 [00:20<00:39,  1.68it/s]

Epoch 32, Loss: 0.03004294068076353


 34%|███▍      | 34/100 [00:20<00:39,  1.69it/s]

Epoch 33, Loss: 0.029202318772142136


 35%|███▌      | 35/100 [00:21<00:38,  1.70it/s]

Epoch 34, Loss: 0.028505068126049842


 36%|███▌      | 36/100 [00:22<00:37,  1.69it/s]

Epoch 35, Loss: 0.028205564419958494


 37%|███▋      | 37/100 [00:22<00:37,  1.69it/s]

Epoch 36, Loss: 0.027626734127059883


 38%|███▊      | 38/100 [00:23<00:36,  1.68it/s]

Epoch 37, Loss: 0.026910375068583594


 39%|███▉      | 39/100 [00:23<00:36,  1.67it/s]

Epoch 38, Loss: 0.026609358736283534


 40%|████      | 40/100 [00:24<00:37,  1.60it/s]

Epoch 39, Loss: 0.025787833594599276


 41%|████      | 41/100 [00:25<00:36,  1.60it/s]

Epoch 40, Loss: 0.0257436256807175


 42%|████▏     | 42/100 [00:25<00:35,  1.62it/s]

Epoch 41, Loss: 0.025072366367062483


 43%|████▎     | 43/100 [00:26<00:35,  1.62it/s]

Epoch 42, Loss: 0.025244241303976973


 44%|████▍     | 44/100 [00:27<00:34,  1.62it/s]

Epoch 43, Loss: 0.024756831244699386


 45%|████▌     | 45/100 [00:27<00:34,  1.61it/s]

Epoch 44, Loss: 0.02419775157723253


 46%|████▌     | 46/100 [00:28<00:33,  1.62it/s]

Epoch 45, Loss: 0.024564701469509197


 47%|████▋     | 47/100 [00:28<00:32,  1.62it/s]

Epoch 46, Loss: 0.023505786809439665


 48%|████▊     | 48/100 [00:29<00:32,  1.62it/s]

Epoch 47, Loss: 0.023070025148757542


 49%|████▉     | 49/100 [00:30<00:31,  1.64it/s]

Epoch 48, Loss: 0.02307693907185045


 50%|█████     | 50/100 [00:30<00:30,  1.64it/s]

Epoch 49, Loss: 0.02235032345741781


 51%|█████     | 51/100 [00:31<00:29,  1.65it/s]

Epoch 50, Loss: 0.022437851291552163


 52%|█████▏    | 52/100 [00:31<00:28,  1.68it/s]

Epoch 51, Loss: 0.021905360526219917


 53%|█████▎    | 53/100 [00:32<00:28,  1.67it/s]

Epoch 52, Loss: 0.02204972719204285


 54%|█████▍    | 54/100 [00:33<00:27,  1.67it/s]

Epoch 53, Loss: 0.021332435095575725


 55%|█████▌    | 55/100 [00:33<00:26,  1.67it/s]

Epoch 54, Loss: 0.020812811328941618


 56%|█████▌    | 56/100 [00:34<00:27,  1.59it/s]

Epoch 55, Loss: 0.020811368789447306


 57%|█████▋    | 57/100 [00:34<00:26,  1.60it/s]

Epoch 56, Loss: 0.02015131635667643


 58%|█████▊    | 58/100 [00:35<00:26,  1.59it/s]

Epoch 57, Loss: 0.020698997480751703


 59%|█████▉    | 59/100 [00:36<00:25,  1.61it/s]

Epoch 58, Loss: 0.020320390635799403


 60%|██████    | 60/100 [00:36<00:24,  1.64it/s]

Epoch 59, Loss: 0.019969034064135843


 61%|██████    | 61/100 [00:37<00:24,  1.62it/s]

Epoch 60, Loss: 0.019635761878461302


 62%|██████▏   | 62/100 [00:38<00:23,  1.60it/s]

Epoch 61, Loss: 0.020111310181285275


 63%|██████▎   | 63/100 [00:38<00:22,  1.63it/s]

Epoch 62, Loss: 0.019557769415269224


 64%|██████▍   | 64/100 [00:39<00:22,  1.63it/s]

Epoch 63, Loss: 0.019302255151452296


 65%|██████▌   | 65/100 [00:39<00:21,  1.64it/s]

Epoch 64, Loss: 0.01901480450065081


 66%|██████▌   | 66/100 [00:40<00:20,  1.67it/s]

Epoch 65, Loss: 0.01835345470166806


 67%|██████▋   | 67/100 [00:41<00:19,  1.69it/s]

Epoch 66, Loss: 0.01833423877436326


 68%|██████▊   | 68/100 [00:41<00:18,  1.70it/s]

Epoch 67, Loss: 0.01838170492965366


 69%|██████▉   | 69/100 [00:42<00:18,  1.71it/s]

Epoch 68, Loss: 0.018758844989314634


 70%|███████   | 70/100 [00:42<00:17,  1.72it/s]

Epoch 69, Loss: 0.0181914026050424


 71%|███████   | 71/100 [00:43<00:16,  1.72it/s]

Epoch 70, Loss: 0.018100746012523653


 72%|███████▏  | 72/100 [00:44<00:17,  1.64it/s]

Epoch 71, Loss: 0.017685385381540054


 73%|███████▎  | 73/100 [00:44<00:16,  1.68it/s]

Epoch 72, Loss: 0.017909554412446436


 74%|███████▍  | 74/100 [00:45<00:15,  1.72it/s]

Epoch 73, Loss: 0.017390799825548805


 75%|███████▌  | 75/100 [00:45<00:14,  1.75it/s]

Epoch 74, Loss: 0.017354778728633117


 76%|███████▌  | 76/100 [00:46<00:13,  1.76it/s]

Epoch 75, Loss: 0.01732678308299864


 77%|███████▋  | 77/100 [00:46<00:13,  1.77it/s]

Epoch 76, Loss: 0.01705556586249497


 78%|███████▊  | 78/100 [00:47<00:12,  1.77it/s]

Epoch 77, Loss: 0.01731205303753781


 79%|███████▉  | 79/100 [00:47<00:11,  1.78it/s]

Epoch 78, Loss: 0.017215256136282667


 80%|████████  | 80/100 [00:48<00:11,  1.78it/s]

Epoch 79, Loss: 0.01692278580551541


 81%|████████  | 81/100 [00:49<00:10,  1.78it/s]

Epoch 80, Loss: 0.016766428225387663


 82%|████████▏ | 82/100 [00:49<00:10,  1.79it/s]

Epoch 81, Loss: 0.01646935903403167


 83%|████████▎ | 83/100 [00:50<00:09,  1.79it/s]

Epoch 82, Loss: 0.015952971124913076


 84%|████████▍ | 84/100 [00:50<00:08,  1.80it/s]

Epoch 83, Loss: 0.016287554568854164


 85%|████████▌ | 85/100 [00:51<00:08,  1.79it/s]

Epoch 84, Loss: 0.01601384115887048


 86%|████████▌ | 86/100 [00:51<00:07,  1.78it/s]

Epoch 85, Loss: 0.016444325799820424


 87%|████████▋ | 87/100 [00:52<00:07,  1.78it/s]

Epoch 86, Loss: 0.016110178445879707


 88%|████████▊ | 88/100 [00:52<00:06,  1.79it/s]

Epoch 87, Loss: 0.015772204789763104


 89%|████████▉ | 89/100 [00:53<00:06,  1.71it/s]

Epoch 88, Loss: 0.015210556597409275


 90%|█████████ | 90/100 [00:54<00:05,  1.73it/s]

Epoch 89, Loss: 0.01524099208192553


 91%|█████████ | 91/100 [00:54<00:05,  1.74it/s]

Epoch 90, Loss: 0.015399384919639424


 92%|█████████▏| 92/100 [00:55<00:04,  1.75it/s]

Epoch 91, Loss: 0.015010314700353302


 93%|█████████▎| 93/100 [00:55<00:03,  1.76it/s]

Epoch 92, Loss: 0.014920353126080097


 94%|█████████▍| 94/100 [00:56<00:03,  1.76it/s]

Epoch 93, Loss: 0.014943296457609093


 95%|█████████▌| 95/100 [00:56<00:02,  1.77it/s]

Epoch 94, Loss: 0.014855898805164624


 96%|█████████▌| 96/100 [00:57<00:02,  1.77it/s]

Epoch 95, Loss: 0.014958414973592243


 97%|█████████▋| 97/100 [00:58<00:01,  1.76it/s]

Epoch 96, Loss: 0.01493444322770882


 98%|█████████▊| 98/100 [00:58<00:01,  1.73it/s]

Epoch 97, Loss: 0.014467970398983194


 99%|█████████▉| 99/100 [00:59<00:00,  1.67it/s]

Epoch 98, Loss: 0.0142362427439155


100%|██████████| 100/100 [00:59<00:00,  1.67it/s]

Epoch 99, Loss: 0.015023289815952325





SpaceInvLearner(
  (fc1): Linear(in_features=33600, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc_out): Linear(in_features=256, out_features=6, bias=True)
)

In [8]:
# Evaluate the behavior-cloned policy: reload its checkpoint, report the average
# undiscounted return over N_EVAL_EPISODES full episodes, then record one episode.
N_EVAL_EPISODES = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
learner.load_state_dict(torch.load("models/bc_learner.pth", map_location=device), strict=True)

total_learner_reward = 0
for _ in range(N_EVAL_EPISODES):
    obs = env.reset()
    done = False
    while not done:
        # Build the (1, H, W) float32 batch straight from the array; torch.Tensor([ndarray])
        # takes PyTorch's slow list-of-ndarrays path.
        obs_batch = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            action = learner.get_action(obs_batch)
        obs, reward, done, info = env.step(action)
        total_learner_reward += reward

print(total_learner_reward / N_EVAL_EPISODES)

visualize(learner, env, "bc_learner")

Visualizing


  action = learner.get_action(torch.Tensor([obs]).to(device))


Video saved as bc_learner.avi
Reward: 80.0


## LOAD PPO EXPERT

In [9]:
from expert.ppo import PPOAgent, ActorCnn, CriticCnn

# PPO expert hyperparameters. INPUT_SHAPE is the input shape given to the expert's
# actor/critic CNNs, which differs from the raw 210 x 160 observations used above.
INPUT_SHAPE = (4, 84, 84)
ACTION_SIZE = env.action_space.n
SEED = 0
GAMMA = 0.99           # discount factor
ALPHA = 0.00001        # actor learning rate
BETA = 0.00001         # critic learning rate
TAU = 0.95             # forwarded to PPOAgent; its meaning is defined in expert.ppo (TODO confirm)
BATCH_SIZE = 64
PPO_EPOCH = 10
CLIP_PARAM = 0.2
UPDATE_EVERY = 1000    # how often to update the network

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Networks are built actor-first, then critic, matching the original argument order.
actor_net = ActorCnn(INPUT_SHAPE, ACTION_SIZE)
critic_net = CriticCnn(INPUT_SHAPE)
agent = PPOAgent(
    INPUT_SHAPE, ACTION_SIZE, SEED, device,
    GAMMA, ALPHA, BETA, TAU,
    UPDATE_EVERY, BATCH_SIZE, PPO_EPOCH, CLIP_PARAM,
    actor_net, critic_net,
)
agent.load_model("models/expert_actor.pth", device)

## Train with DAgger

In [10]:
import dagger

# Run DAgger with the BC-initialized `learner` and the PPO expert `agent`. The empty
# lists are the starting observation/action data (the loop itself lives in
# dagger.interact); the learner checkpoint is written to models/DAgger.pth.
dagger.interact(
    env, learner, agent,
    observations=[], actions=[],
    checkpoint_path="models/DAgger.pth",
    seed=seed,
    num_epochs=40,
    tqdm_disable=True,
)

After interaction 0, reward = 80.0
Training the learner
Training for 40 epochs
Epoch 0, Loss: 0.2285896447943706
Epoch 1, Loss: 0.16880559284313051
Epoch 2, Loss: 0.15796418084817773
Epoch 3, Loss: 0.15377965417562747
Epoch 4, Loss: 0.1483202627476524
Epoch 5, Loss: 0.14525875747203826
Epoch 6, Loss: 0.14375673450675666
Epoch 7, Loss: 0.14360926916786268
Epoch 8, Loss: 0.14317754314226264
Epoch 9, Loss: 0.14113674181349137
Epoch 10, Loss: 0.13915853751640694
Epoch 11, Loss: 0.13925927380720773
Epoch 12, Loss: 0.13707467609760807
Epoch 13, Loss: 0.13588255976929384
Epoch 14, Loss: 0.13434492244439966
Epoch 15, Loss: 0.13368839709197775
Epoch 16, Loss: 0.1337788970447054
Epoch 17, Loss: 0.13292337802110935
Epoch 18, Loss: 0.1343384880061243
Epoch 19, Loss: 0.13505621949831645
Epoch 20, Loss: 0.13200805456030604
Epoch 21, Loss: 0.13258715593347362
Epoch 22, Loss: 0.1319811149555094
Epoch 23, Loss: 0.13122459562385783
Epoch 24, Loss: 0.12987088315627154
Epoch 25, Loss: 0.12836752119017583


KeyboardInterrupt: 

In [None]:
# Evaluate the DAgger-trained policy: reload its checkpoint, report the average
# undiscounted return over N_EVAL_EPISODES full episodes, then record one episode.
N_EVAL_EPISODES = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
learner.load_state_dict(torch.load("models/DAgger.pth", map_location=device), strict=True)

total_learner_reward = 0
for _ in range(N_EVAL_EPISODES):
    obs = env.reset()
    done = False
    while not done:
        # Build the (1, H, W) float32 batch straight from the array; torch.Tensor([ndarray])
        # takes PyTorch's slow list-of-ndarrays path.
        obs_batch = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            action = learner.get_action(obs_batch)
        obs, reward, done, info = env.step(action)
        total_learner_reward += reward

print(total_learner_reward / N_EVAL_EPISODES)

# visualize() appends ".avi" itself, so pass the bare name; the previous
# "dagger_learner.avi" argument produced "dagger_learner.avi.avi".
visualize(learner, env, "dagger_learner")

118.25
