In [1]:
import torch
import torch.nn as nn
import torch.functional as F
from utils import *

# Neural Networks
class SequenceModel(nn.Module):
    """Recurrent transition model built on a single GRU cell.

    Args:
        representationSize: width of the representation vector fed to the cell.
        recurrentStateSize: width of the GRU hidden (recurrent) state.
        actionSize: width of the action vector fed to the cell.
    """

    def __init__(self, representationSize, recurrentStateSize, actionSize):
        super().__init__()
        gruInputSize = representationSize + actionSize
        self.recurrent = nn.GRUCell(gruInputSize, recurrentStateSize)

    def forward(self, representation, recurrentState, action):
        """Advance the recurrent state by one step.

        ``representation`` and ``action`` are concatenated along their last axis
        and passed to the GRU cell together with the previous ``recurrentState``.

        Returns:
            The updated recurrent state produced by the GRU cell.
        """
        gruInput = torch.cat([representation, action], dim=-1)
        return self.recurrent(gruInput, recurrentState)

class PriorNet(nn.Module):
    """MLP that maps an input vector to a hard categorical (one-hot) representation.

    The MLP emits ``representationSize ** 2`` values, which are regrouped into
    ``representationSize`` groups of ``representationSize`` logits. Each group is
    replaced by a one-hot vector marking its largest logit.

    Args:
        inputSize: width of the input vector fed to the MLP.
        representationSize: number of groups, and number of classes per group.
        hiddenSizes: hidden-layer widths passed to ``sequentialModel1D``;
            ``None`` keeps the original ``[256, 256]``.
    """

    def __init__(self, inputSize, representationSize=16, hiddenSizes=None):
        super().__init__()
        self.representationSize = representationSize
        if hiddenSizes is None:
            hiddenSizes = [256, 256]
        self.mlp = sequentialModel1D(inputSize, list(hiddenSizes), representationSize**2)

    def forward(self, x):
        """Return an int64 one-hot tensor of shape (-1, representationSize, representationSize).

        The leading dimension is whatever remains after regrouping the MLP output,
        e.g. the batch size for a 2-D input.
        """
        logits = self.mlp(x)
        logits = logits.reshape(-1, self.representationSize, self.representationSize)
        # First maximal index per group (same tie-breaking as torch.max(dim=-1)).
        indices = torch.argmax(logits, dim=-1)
        # one_hot lives in torch.nn.functional; the notebook's `F` alias is bound to
        # torch.functional, which does not provide it.
        # NOTE(review): argmax/one_hot are not differentiable, so no gradient reaches
        # self.mlp through this output; a straight-through estimator may be intended.
        return nn.functional.one_hot(indices, num_classes=self.representationSize)
    
class PosteriorNet(nn.Module):
    """MLP that maps an input vector to a hard categorical (one-hot) representation.

    The MLP emits ``representationSize ** 2`` values, which are regrouped into
    ``representationSize`` groups of ``representationSize`` logits. Each group is
    replaced by a one-hot vector marking its largest logit.

    Args:
        inputSize: width of the input vector fed to the MLP.
        representationSize: number of groups, and number of classes per group.
        hiddenSizes: hidden-layer widths passed to ``sequentialModel1D``;
            ``None`` keeps the original ``[256, 256]``.
    """

    def __init__(self, inputSize, representationSize=16, hiddenSizes=None):
        super().__init__()
        self.representationSize = representationSize
        if hiddenSizes is None:
            hiddenSizes = [256, 256]
        self.mlp = sequentialModel1D(inputSize, list(hiddenSizes), representationSize**2)

    def forward(self, x):
        """Return an int64 one-hot tensor of shape (-1, representationSize, representationSize).

        The leading dimension is whatever remains after regrouping the MLP output,
        e.g. the batch size for a 2-D input.
        """
        logits = self.mlp(x)
        logits = logits.reshape(-1, self.representationSize, self.representationSize)
        # First maximal index per group (same tie-breaking as torch.max(dim=-1)).
        indices = torch.argmax(logits, dim=-1)
        # one_hot lives in torch.nn.functional; the notebook's `F` alias is bound to
        # torch.functional, which does not provide it.
        # NOTE(review): argmax/one_hot are not differentiable, so no gradient reaches
        # self.mlp through this output; a straight-through estimator may be intended.
        return nn.functional.one_hot(indices, num_classes=self.representationSize)

class ConvEncoder(nn.Module):
    """Four-layer strided CNN that encodes an image into a flat feature vector.

    Each ``Conv2d(kernel_size=4, stride=2, padding=1)`` maps a spatial size ``s``
    to ``s // 2``. The flattened feature map therefore has
    ``256 * (h // 16) * (w // 16)`` elements before the final Linear projection.

    Args:
        inputShape: image shape as ``(c, h, w)``. A container whose first element
            is ``(c, h, w)`` (e.g. ``[(c, h, w)]``) is also accepted.
        outputSize: width of the output feature vector.
    """

    def __init__(self, inputShape, outputSize):
        super(ConvEncoder, self).__init__()
        # Unpacking inputShape[0] directly fails for a flat (c, h, w) tuple (its first
        # element is an int), so only descend one level when the first item is itself
        # a shape.
        shape = inputShape[0] if hasattr(inputShape[0], "__len__") else inputShape
        c, h, w = shape
        self.convolutionalNet = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=4, stride=2, padding=1),  # Output: (32, h/2, w/2)
            nn.Tanh(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # Output: (64, h/4, w/4)
            nn.Tanh(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # Output: (128, h/8, w/8)
            nn.Tanh(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # Output: (256, h/16, w/16)
            nn.Tanh(),
            nn.Flatten(),
            nn.Linear(256 * (h // 16) * (w // 16), outputSize),
            nn.Tanh(),
        )

    def forward(self, obs):
        """Encode a batch of images of shape (N, c, h, w) into (N, outputSize) features in [-1, 1]."""
        # Divides by 255; presumably obs holds raw 0-255 pixel values — TODO confirm.
        return self.convolutionalNet(obs / 255.0)

class ConvDecoder(nn.Module):
    """Decode a flat feature vector into an image with a four-layer transposed CNN.

    A Linear layer produces a ``(256, h // 16, w // 16)`` seed feature map. Each
    ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` doubles its spatial
    size, so the decoded image is ``(c, 16 * (h // 16), 16 * (w // 16))``. This
    equals ``(c, h, w)`` only when h and w are multiples of 16.

    Args:
        inputSize: width of the input feature vector.
        outputShape: target image shape ``(c, h, w)``.
    """

    def __init__(self, inputSize, outputShape):
        super(ConvDecoder, self).__init__()
        self.outputShape = outputShape
        c, h, w = outputShape
        # Shape of the seed feature map reshaped from the fc output in forward().
        self.featureShape = (256, h // 16, w // 16)
        self.fc = nn.Sequential(
            nn.Linear(inputSize, 256 * (h // 16) * (w // 16)),
            nn.Tanh(),
        )
        self.deconvolutionalNet = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Output: (128, h/8, w/8)
            nn.Tanh(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Output: (64, h/4, w/4)
            nn.Tanh(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # Output: (32, h/2, w/2)
            nn.Tanh(),
            nn.ConvTranspose2d(32, c, kernel_size=4, stride=2, padding=1),  # Output: (c, h, w)
            nn.Sigmoid(),  # Output pixel values between 0 and 1
        )

    def forward(self, x):
        """Decode ``x`` of shape (..., inputSize) into images of shape (..., c, H, W).

        All leading dimensions are flattened into one batch axis for the
        convolutions and restored afterwards. A plain (N, inputSize) input yields
        (N, c, H, W).
        """
        leadingShape = x.shape[:-1]
        # numel() of an empty Size is 1, so a single unbatched vector also works.
        batchSize = int(leadingShape.numel())
        features = self.fc(x).reshape(batchSize, *self.featureShape)
        images = self.deconvolutionalNet(features)
        return images.reshape(*leadingShape, *images.shape[1:])


In [1]:
import torch
import torch.nn as nn
import torch.functional as F
from utils import *
import pickle

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load episode files produced by a trusted source.
EPISODE_PATH = 'episode_0_data.pkl'

with open(EPISODE_PATH, 'rb') as f:
    data = pickle.load(f)

# Extract the 'observations' and 'actions' entries from the loaded episode.
observations, actions = data['observations'], data['actions']



[array([-0.35423192,  0.79713292,  0.18345659]),
 array([0.75210184, 3.07461858, 0.11055832]),
 array([-0.46157223,  2.63887912,  0.16531258]),
 array([0.75784761, 2.70852089, 0.00635755]),
 array([0.54102767, 2.24330783, 0.10689338]),
 array([0.01920862, 3.00895542, 0.16531166]),
 array([-0.44042847,  4.15858865,  0.08983862]),
 array([-0.4306961 ,  0.5545434 ,  0.15010509]),
 array([0.17087352, 0.7632152 , 0.13232936]),
 array([0.30589247, 3.78691018, 0.10369731]),
 array([-0.07548559,  2.66578197,  0.1701715 ]),
 array([-0.38932347,  0.21948475,  0.09831232]),
 array([0.92900676, 0.15320848, 0.13708094]),
 array([0.35096866, 2.02336371, 0.06343058]),
 array([-0.72261387,  1.84531897,  0.03441065]),
 array([0.10123158, 2.99547464, 0.16735274]),
 array([0.27051264, 2.91375309, 0.05943009]),
 array([-0.12070028,  0.22517685,  0.14118423]),
 array([-0.48510417,  0.55613317,  0.17738602]),
 array([0.28777924, 0.84633544, 0.03769388]),
 array([-0.54765284,  2.27790877,  0.01479488]),
 arr