In [1]:
import gym
from gym import wrappers
from gym.spaces.utils import flatdim

import torch
from torch import nn
from torch.functional import F
from torch.utils.tensorboard import SummaryWriter
from torch.distributions.categorical import Categorical

import numpy as np
import cv2
from tqdm import tqdm

from copy import deepcopy
import mediapy
import collections

# Comment out for debugging
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# World Models
class SequenceModel(nn.Module):
    """Sequence-model component of the world model.

    ``self.model`` is currently an empty ``nn.Sequential`` placeholder; an
    empty ``nn.Sequential`` returns its input unchanged until layers are added.
    """

    def __init__(self, h_shape, z_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is and
                not otherwise used yet.
            z_shape: Shape setting associated with ``z_t``; stored as-is and
                not otherwise used yet.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.model = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.z_shape = z_shape

        self.model = nn.Sequential()

    def step(
        self, h_t: torch.Tensor, z_t: torch.Tensor, a_t: torch.Tensor
    ) -> torch.Tensor:
        """Run the network on ``h_t``, ``z_t`` and ``a_t``.

        ``nn.Sequential`` accepts exactly one input, so the three tensors are
        joined along their last dimension first; they must therefore agree on
        every other dimension.

        NOTE(review): joining on the last axis assumes all three inputs are
        feature vectors of a compatible dtype -- confirm once the real layers
        are added.

        Returns:
            The output of ``self.model`` applied to the concatenated input.
        """
        joined = torch.cat((h_t, z_t, a_t), dim=-1)
        return self.model(joined)

class Encoder(nn.Module):
    """Encoder component of the world model.

    ``self.encoder`` is currently an empty ``nn.Sequential`` placeholder; an
    empty ``nn.Sequential`` returns its input unchanged until layers are added.
    """

    def __init__(self, h_shape, x_shape, z_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is.
            x_shape: Shape setting associated with ``x_t``; stored as-is.
            z_shape: Shape setting for the encoder output (presumably ``z_t``
                -- verify); stored as-is.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.encoder = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.x_shape = x_shape
        self.z_shape = z_shape

        self.encoder = nn.Sequential()

    def encode(self, h_t: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """Run the encoder network on ``h_t`` and ``x_t``.

        ``nn.Sequential`` accepts exactly one input, so the two tensors are
        joined along their last dimension first; they must therefore agree on
        every other dimension.

        NOTE(review): this assumes ``x_t`` is already a flat feature vector;
        image-shaped observations would need flattening or a separate
        feature extractor before this call -- confirm.

        Returns:
            The output of ``self.encoder`` applied to the concatenated input.
        """
        joined = torch.cat((h_t, x_t), dim=-1)
        return self.encoder(joined)

class Decoder(nn.Module):
    """Decoder component of the world model.

    ``self.decoder`` is currently an empty ``nn.Sequential`` placeholder; an
    empty ``nn.Sequential`` returns its input unchanged until layers are added.

    The network attribute and its method were previously named ``encoder`` and
    ``encode``; both names remain available as aliases so existing callers
    keep working.
    """

    def __init__(self, h_shape, z_shape, x_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is.
            z_shape: Shape setting associated with ``z_t``; stored as-is.
            x_shape: Shape setting for the decoder output (presumably the
                reconstructed observation -- verify); stored as-is.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.decoder = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.z_shape = z_shape
        self.x_shape = x_shape

        self.decoder = nn.Sequential()

    @property
    def encoder(self) -> nn.Module:
        """Backward-compatible read-only alias for ``self.decoder``."""
        return self.decoder

    def decode(self, h_t: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        """Run the decoder network on ``h_t`` and ``z_t``.

        ``nn.Sequential`` accepts exactly one input, so the two tensors are
        joined along their last dimension first; they must therefore agree on
        every other dimension.

        Returns:
            The output of ``self.decoder`` applied to the concatenated input.
        """
        joined = torch.cat((h_t, z_t), dim=-1)
        return self.decoder(joined)

    def encode(self, h_t: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for :meth:`decode`."""
        return self.decode(h_t, z_t)

class DynamicsModel(nn.Module):
    """Dynamics component of the world model.

    ``self.dynamics`` is currently an empty ``nn.Sequential`` placeholder; an
    empty ``nn.Sequential`` returns its input unchanged until layers are added.
    """

    def __init__(self, h_shape, z_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is.
            z_shape: Shape setting for the output (presumably ``z_t`` --
                verify); stored as-is.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.dynamics = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.z_shape = z_shape

        self.dynamics = nn.Sequential()

    def step(self, h_t):
        """Return the output of ``self.dynamics`` applied to ``h_t``."""
        return self.dynamics(h_t)

class RewardPredictor(nn.Module):
    """Reward-prediction component of the world model.

    ``self.rewards`` is currently an empty ``nn.Sequential`` placeholder; an
    empty ``nn.Sequential`` returns its input unchanged until layers are added.
    """

    def __init__(self, h_shape, z_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is.
            z_shape: Shape setting associated with ``z_t``; stored as-is.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.rewards = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.z_shape = z_shape

        self.rewards = nn.Sequential()

    def reward(self, h_t: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        """Run the reward network on ``h_t`` and ``z_t``.

        ``nn.Sequential`` accepts exactly one input, so the two tensors are
        joined along their last dimension first; they must therefore agree on
        every other dimension.

        Returns:
            The output of ``self.rewards`` applied to the concatenated input.
        """
        joined = torch.cat((h_t, z_t), dim=-1)
        return self.rewards(joined)

class TerminationPredictor(nn.Module):  # called "continue predictor"
    """Termination-prediction component of the world model.

    ``self.terminator`` is currently an empty ``nn.Sequential`` placeholder;
    an empty ``nn.Sequential`` returns its input unchanged until layers are
    added.
    """

    def __init__(self, h_shape, z_shape):
        """Store the shape settings and create the placeholder network.

        Args:
            h_shape: Shape setting associated with ``h_t``; stored as-is.
            z_shape: Shape setting associated with ``z_t``; stored as-is.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.terminator = ...`` raises AttributeError.
        super().__init__()
        self.h_shape = h_shape
        self.z_shape = z_shape

        self.terminator = nn.Sequential()

    def is_terminated(self, h_t: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        """Run the termination network on ``h_t`` and ``z_t``.

        ``nn.Sequential`` accepts exactly one input, so the two tensors are
        joined along their last dimension first; they must therefore agree on
        every other dimension.

        Returns:
            The raw output of ``self.terminator`` applied to the concatenated
            input (no thresholding is applied here).
        """
        joined = torch.cat((h_t, z_t), dim=-1)
        return self.terminator(joined)

# Actor/Critic Models
class Agent(nn.Module):
    """Actor/critic pair.

    Both ``self.actor`` and ``self.critic`` are currently empty
    ``nn.Sequential`` placeholders; an empty ``nn.Sequential`` returns its
    input unchanged until layers are added.
    """

    def __init__(self, env):
        """Read the shape settings from ``env`` and create the networks.

        Args:
            env: Object exposing ``observation_shape`` and ``action_shape``
                attributes.
                NOTE(review): a plain ``gym.Env`` exposes
                ``observation_space`` / ``action_space`` instead -- confirm
                that the environment passed here provides these attributes.
        """
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.actor = ...`` raises AttributeError.
        super().__init__()
        self.obs_shape = env.observation_shape
        self.action_shape = env.action_shape

        self.actor = nn.Sequential()
        self.critic = nn.Sequential()

    def value(self, x):
        """Return the output of ``self.critic`` applied to ``x``."""
        return self.critic(x)

    def act(self, x):
        """Sample an action for ``x``.

        The output of ``self.actor`` is used as the unnormalised logits of a
        ``Categorical`` distribution, from which one sample is drawn.

        Returns:
            The sampled category index (or indices, one per batch element).
        """
        logits = self.actor(x)
        dist = Categorical(logits=logits)
        return dist.sample()
