In [None]:
import torch
import torch.nn as nn
from utils import *
# Summarize only tensors with more than 2000 elements, and wrap printed rows at 200 characters.
torch.set_printoptions(linewidth=200, threshold=2000)
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SequenceModel(nn.Module):
    """Recurrent core: advance a GRU state on the concatenation [representation, action].

    ``nn.GRU`` is created with its defaults (one layer, ``batch_first=False``).
    So inputs are laid out as (seq, batch, features) and ``recurrentState`` as
    (num_layers, batch, recurrentStateSize).

    Args:
        representationSize: Size of the last dimension of ``representation``.
        recurrentStateSize: Hidden size of the GRU.
        actionSize: Size of the last dimension of ``action``.

    NOTE(review): the Encoder/DynamicsPredictor outputs are (batch, n, n); they
    presumably need flattening to (..., n*n) before being passed here.
    """

    def __init__(self, representationSize, recurrentStateSize, actionSize):
        super().__init__()
        gruInputSize = representationSize + actionSize
        self.recurrent = nn.GRU(gruInputSize, recurrentStateSize)

    def forward(self, representation, recurrentState, action):
        """Return the ``(output, h_n)`` tuple produced by ``nn.GRU``."""
        gruInput = torch.cat([representation, action], dim=-1)
        return self.recurrent(gruInput, recurrentState)
    
class DynamicsPredictor(nn.Module):
    """Predict a categorical latent representation from a flat input vector.

    The MLP emits ``representationSize`` groups of ``representationSize``
    logits. Each group is collapsed to a one-hot vector at its arg-max.
    A straight-through estimator makes the forward value exactly that
    one-hot tensor while still letting gradients reach the MLP.

    Args:
        inputSize: Size of the last dimension of the input.
        representationSize: Number of categorical groups, and classes per group.
    """

    def __init__(self, inputSize, representationSize=16):
        super().__init__()
        self.representationSize = representationSize
        self.mlp = sequentialModel1D(inputSize, [256, 256], representationSize**2)

    def forward(self, x):
        """Return a float one-hot tensor of shape (-1, representationSize, representationSize).

        Assumes the MLP's last output dimension is ``representationSize**2``
        (as configured in ``__init__``). All leading dimensions are folded into
        the first one by ``view``.
        """
        size = self.representationSize
        logits = self.mlp(x).view(-1, size, size)
        indices = logits.argmax(dim=-1)
        probabilities = torch.softmax(logits, dim=-1)
        oneHot = nn.functional.one_hot(indices, num_classes=size).to(probabilities.dtype)
        # Straight-through estimator. (p - p.detach()) is exactly 0 in the forward
        # pass, so the value equals oneHot. Backward, gradients flow through the
        # softmax into the MLP. A bare argmax + one_hot yields an integer tensor
        # with no gradient, which left the MLP untrainable.
        return oneHot + (probabilities - probabilities.detach())

# TODO: rather than feeding all inputs to the MLP at once, convolve the raw image inputs and concatenate the result with the recurrent state.
# TODO: the encoder should also apply symlog to its inputs, as intended.
class Encoder(nn.Module):
    """Encode a flat input vector into a categorical latent representation.

    The MLP emits ``representationSize`` groups of ``representationSize``
    logits. Each group is collapsed to a one-hot vector at its arg-max.
    A straight-through estimator makes the forward value exactly that
    one-hot tensor while still letting gradients reach the MLP.

    Args:
        inputSize: Size of the last dimension of the input.
        representationSize: Number of categorical groups, and classes per group.
    """

    def __init__(self, inputSize, representationSize=16):
        super().__init__()
        self.representationSize = representationSize
        self.mlp = sequentialModel1D(inputSize, [256, 256], representationSize**2)

    def forward(self, x):
        """Return a float one-hot tensor of shape (-1, representationSize, representationSize).

        Assumes the MLP's last output dimension is ``representationSize**2``
        (as configured in ``__init__``). All leading dimensions are folded into
        the first one by ``view``.
        """
        size = self.representationSize
        logits = self.mlp(x).view(-1, size, size)
        indices = logits.argmax(dim=-1)
        probabilities = torch.softmax(logits, dim=-1)
        oneHot = nn.functional.one_hot(indices, num_classes=size).to(probabilities.dtype)
        # Straight-through estimator. (p - p.detach()) is exactly 0 in the forward
        # pass, so the value equals oneHot. Backward, gradients flow through the
        # softmax into the MLP. A bare argmax + one_hot yields an integer tensor
        # with no gradient, which left the encoder untrainable.
        return oneHot + (probabilities - probabilities.detach())

# TODO: replace the MLP with transposed convolutions (ConvTranspose2d) to upsample back to the image shape.
# TODO: the decoder should apply symexp to its outputs.
class Decoder(nn.Module):
    """Reconstruct an observation of shape ``targetShape`` from a flat feature vector.

    Args:
        inputSize: Size of the last dimension of the input features.
        targetShape: Shape of one reconstructed observation. Either a sequence
            of ints or a single int for a 1-D target.
    """

    def __init__(self, inputSize, targetShape):
        super().__init__()
        try:
            shape = tuple(int(dim) for dim in targetShape)
        except TypeError:  # a bare integer size
            shape = (int(targetShape),)
        self.targetShape = shape
        # Plain Python int for the layer width (np.prod would return a numpy scalar).
        self.outputSize = torch.Size(shape).numel()
        self.mlp = sequentialModel1D(inputSize, [256, 256], self.outputSize)

    def forward(self, x):
        """Return continuous reconstructions of shape (-1, *targetShape)."""
        x = self.mlp(x)
        # The previous body copied the Encoder's one-hot logic and read
        # self.representationSize, which Decoder never sets (AttributeError).
        # A decoder must return real-valued outputs in the target shape.
        return x.reshape(-1, *self.targetShape)

# Should be symexp twohot loss
class RewardPredictor(nn.Module):
    """Predict a single scalar reward from a feature vector.

    The network comes from ``sequentialModel1D`` (presumably provided by
    ``utils``). It is configured with hidden widths [256, 256] and one output unit.
    """

    def __init__(self, inputSize):
        super().__init__()
        hiddenSizes = [256, 256]
        outputSize = 1
        self.mlp = sequentialModel1D(inputSize, hiddenSizes, outputSize)

    def forward(self, x):
        """Return the raw network output for ``x``."""
        return self.mlp(x)
    
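# TODO: the continue flag is binary, so this head presumably needs a Bernoulli (binary cross-entropy) loss on its logit rather than symexp twohot — confirm.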
class ContinuePredictor(nn.Module):
    """Predict a single scalar "episode continues" score from a feature vector.

    The network comes from ``sequentialModel1D`` (presumably provided by
    ``utils``). It is configured with hidden widths [256, 256] and one output unit.
    No activation is applied here, so the output is presumably a logit.
    """

    def __init__(self, inputSize):
        super().__init__()
        self.mlp = sequentialModel1D(inputSize, [256, 256], 1)

    def forward(self, x):
        """Return the raw network output for ``x``."""
        prediction = self.mlp(x)
        return prediction
    
# Bounds for the policy's log standard deviation.
LOG_STD_MAX = 2
LOG_STD_MIN = -5
class Actior(nn.Module):
    """Gaussian policy with tanh squashing and an affine rescale to the action range.

    Args:
        inputSize: Size of the last dimension of the input features (fed to both heads).
        actionSize: Number of continuous action dimensions.
        continuousActionScale: Multiplies the tanh-squashed sample (half-width of
            the action range). Scalar or per-dimension values. Defaults to 1.0.
        continuousActionBias: Added after scaling (center of the action range).
            Scalar or per-dimension values. Defaults to 0.0.
    """

    def __init__(self, inputSize, actionSize, continuousActionScale=1.0, continuousActionBias=0.0):
        super().__init__()
        self.mlpMean = sequentialModel1D(inputSize, [256, 256], actionSize)
        self.mlpLogStd = nn.Linear(inputSize, actionSize)
        # forward() used these attributes but nothing ever defined them. Buffers
        # move with .to(device) and are not trained.
        self.register_buffer("continuousActionScale", torch.as_tensor(continuousActionScale, dtype=torch.float32))
        self.register_buffer("continuousActionBias", torch.as_tensor(continuousActionBias, dtype=torch.float32))

    def forward(self, x):
        """Sample a reparameterized action in [bias - scale, bias + scale]."""
        # Both heads read x directly. The former `x = self.mlp(x)` referenced a
        # module that does not exist, and both heads are sized for inputSize.
        actionMean = self.mlpMean(x)
        # Squash to [-1, 1] first so the affine map below really lands in
        # [LOG_STD_MIN, LOG_STD_MAX]. Without tanh the bounds were not enforced.
        actionLogStd = torch.tanh(self.mlpLogStd(x))
        actionLogStd = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (actionLogStd + 1)
        actionStd = actionLogStd.exp()
        distribution = torch.distributions.Normal(actionMean, actionStd)
        actionSample = distribution.rsample()  # reparameterized, so differentiable w.r.t. mean/std
        actionSampleTanh = torch.tanh(actionSample)
        action = actionSampleTanh * self.continuousActionScale + self.continuousActionBias
        return action
    
class Critic(nn.Module):
    """Estimate a single scalar value from a feature vector.

    The network comes from ``sequentialModel1D`` (presumably provided by
    ``utils``). It is configured with hidden widths [256, 256] and one output unit.
    """

    def __init__(self, inputSize):
        super().__init__()
        valueHeadWidths = [256, 256]
        self.mlp = sequentialModel1D(inputSize, valueHeadWidths, 1)

    def forward(self, x):
        """Return the raw network output for ``x``."""
        value = self.mlp(x)
        return value

In [None]:
# Open the Unity environment from a local build (Windows-style backslash path), seeded with 1.
# UnityInterface is presumably provided by `utils` — TODO confirm.
env = UnityInterface("Builds\\Windows\\Crawlerx01\\UnityEnvironment", seed=1)
print(f"{env.getSpecs()}")
# Drive the first behavior name reported by the environment.
behavior = env.getBehaviorNames()[0]

# Number of episodes for the loop that follows.
# NOTE(review): that loop is written `for episode in totalEpisodes:`, which raises
# TypeError (an int is not iterable); it presumably wants `range(totalEpisodes)`.
totalEpisodes = 10
for episode in totalEpisodes:
    decisionSteps, terminalSteps = env.getSteps(behavior)

    if len(decisionSteps) > 0:
        
        
    if len(terminalSteps) > 0: