In [None]:
from dataloader import AtariDataset
import gym
import torch.nn as nn
import torch
import numpy as np
import random
import tqdm
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim import optimizer
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay

## SEEDING

In [None]:
def reseed(seed):
    """Seed the global RNGs of `random`, NumPy and PyTorch with `seed`.

    Args:
        seed: Value handed unchanged to `random.seed`, `np.random.seed`
            and `torch.manual_seed`.
    """
    # The three generators are independent, so the order of seeding is irrelevant.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


# Fix the notebook-wide seed before any data loading or training happens.
reseed(42)

## LOAD DATA

In [None]:
# Build the dataset wrapper from the "atari_v1" identifier
# (presumably a data directory — confirm against AtariDataset).
dataloader = AtariDataset("atari_v1")
# compile_data() yields a pair that is unpacked into the observations and the
# matching actions; both are passed to bc.train in the training cell.
observations, actions = dataloader.compile_data()


## MAKE ENVIRONMENT

In [None]:
def make_env(env_id, seed=25):
    """Create a Gym environment with grayscale observations and seed it.

    Args:
        env_id: Environment identifier forwarded to `gym.make`.
        seed: Seed applied to the environment and to both of its spaces.

    Returns:
        The environment returned by `gym.make`, built with
        `obs_type='grayscale'` and `render_mode=None`.
    """
    environment = gym.make(env_id, obs_type='grayscale', render_mode=None)
    # Seed the environment itself first, then its action and observation spaces.
    environment.seed(seed)
    for space in (environment.action_space, environment.observation_space):
        space.seed(seed)
    return environment
# Instantiate the Space Invaders environment, then print `action_space.n`
# followed by the observation-space shape, one value per line.
env = make_env("ALE/SpaceInvaders-v5")
print(env.action_space.n, env.observation_space.shape, sep="\n")



## TRAIN BC

In [None]:
from model import SpaceInvLearner
import bc

# The learner is constructed from the environment created earlier in the notebook.
learner = SpaceInvLearner(env)

# Train the learner on the compiled (observation, action) data for 10 epochs.
# NOTE(review): checkpoint_path is presumably where bc.train saves the model,
# and "models/" may need to exist beforehand — confirm in bc.py.
bc.train(
    learner=learner,
    observations=observations,
    actions=actions,
    num_epochs=10,
    checkpoint_path="models/bc_learner.pth",
)

In [None]:
# Roll out the trained learner for a single episode and report the summed reward.
total_learner_reward = 0
done = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `reset()` returning only the observation and `step()` returning
# a 4-tuple matches the pre-0.26 Gym API — confirm the installed gym version.
obs = env.reset()
# Pure inference: no gradients are needed anywhere in the rollout.
with torch.no_grad():
    while not done:
        # Add the batch axis in NumPy and convert once. Building the tensor from
        # a Python list of ndarrays (`torch.Tensor([obs])`) is a slow path that
        # PyTorch warns about; this yields the same float32 batch on `device`.
        obs_batch = torch.tensor(np.asarray(obs)[None], dtype=torch.float32, device=device)
        action = learner.get_action(obs_batch)
        obs, reward, done, info = env.step(action)
        total_learner_reward += reward
        # No explicit break needed: the loop condition ends the episode on `done`.

print(total_learner_reward)