# In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

# In [3]:
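# Things that didn't seem to be supported in PyTorch when this was written:
# - padding="same" (emulated below with padding=2 for the 5x5, stride-1
#   convolutions; newer PyTorch releases accept padding="same" directly)
# - spatial softmax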
class AutoEncoder_Dynamics(nn.Module):
    """Image autoencoder with a latent dynamics model and an environment encoder.

    Sub-networks (all ``nn.Sequential``):

    * ``encoder``: three 5x5 convolutions (padding=2 keeps H x W unchanged)
      followed by fully connected layers, mapping a ``1 x img_res x img_res``
      image to a ``z_dim`` latent vector.
    * ``dynamics``: MLP mapping the concatenation ``[z_t, u_t]`` to a predicted
      next latent ``z_hat_tplus``.
    * ``decoder``: MLP mapping a latent vector to a 512-d feature.
    * ``environment``: convolution + MLP mapping an "empty environment" image
      to a 512-d feature.
    * ``last_layer``: linear map from the concatenated decoder and environment
      features to a flattened ``img_res * img_res`` image.

    Args:
        img_res: Side length of the square, single-channel input images.
        z_dim: Latent dimension.
        u_dim: Control dimension.
    """

    def __init__(self, img_res, z_dim, u_dim):
        super(AutoEncoder_Dynamics, self).__init__()

        self.img_res = img_res
        self.x_dim = img_res * img_res  # pixels per (single-channel) image
        self.z_dim = z_dim
        self.u_dim = u_dim

        self.encoder = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, padding=2)),  # kernel_size different than original
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(8, 8, 5, padding=2)),
            ('relu2', nn.ReLU()),
            ('conv3', nn.Conv2d(8, 8, 5, padding=2)),
            ('relu3', nn.ReLU()),
            # TODO: implement a spatial softmax and insert it here.
            # Flatten [N, 8, H, W] -> [N, 8*H*W] so the linear layers can consume it.
            ('flat1', nn.Flatten()),
            ('fc1', nn.Linear(8 * self.x_dim, 256)),
            ('relu4', nn.ReLU()),
            ('dropout1', nn.Dropout(p=0.5)),
            ('fc2', nn.Linear(256, 256)),
            ('relu5', nn.ReLU()),
            ('fc3', nn.Linear(256, self.z_dim)),
        ]))
        self.dynamics = nn.Sequential(OrderedDict([
            ('d_fc1', nn.Linear(self.z_dim + self.u_dim, 128)),
            ('d_relu1', nn.ReLU()),
            ('d_dropout1', nn.Dropout(p=0.5)),
            ('d_fc2', nn.Linear(128, 128)),
            ('d_relu2', nn.ReLU()),
            ('d_dropout2', nn.Dropout(p=0.5)),
            ('d_fc3', nn.Linear(128, 128)),
            ('d_relu3', nn.ReLU()),
            ('d_fc4', nn.Linear(128, self.z_dim)),
            # NOTE(review): this final ReLU forces predicted latents to be
            # non-negative while the encoder's fc3 output is unbounded;
            # confirm this is intended.
            ('d_relu4', nn.ReLU()),
        ]))
        self.decoder = nn.Sequential(OrderedDict([
            ('dec_fc1', nn.Linear(self.z_dim, 512)),
            ('dec_relu1', nn.ReLU()),
            ('dec_dropout1', nn.Dropout(p=0.5)),
            ('dec_fc2', nn.Linear(512, 512)),
            ('dec_relu2', nn.ReLU()),
            ('dec_dropout2', nn.Dropout(p=0.5)),
            ('dec_fc3', nn.Linear(512, 512)),
            ('dec_relu3', nn.ReLU()),
            ('dec_dropout3', nn.Dropout(p=0.5)),
            ('dec_fc4', nn.Linear(512, 512)),
            ('dec_relu4', nn.ReLU()),
        ]))
        self.environment = nn.Sequential(OrderedDict([
            ('env_conv1', nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, padding=2)),  # kernel_size different than original
            # Names must be unique: a repeated OrderedDict key silently
            # replaces the earlier entry and drops a layer.
            ('env_relu_conv1', nn.ReLU()),
            ('env_flat1', nn.Flatten()),
            ('env_fc1', nn.Linear(4 * self.x_dim, 512)),
            ('env_relu1', nn.ReLU()),
            ('env_dropout1', nn.Dropout(p=0.5)),
            ('env_fc2', nn.Linear(512, 512)),
            ('env_relu2', nn.ReLU()),
            ('env_dropout2', nn.Dropout(p=0.5)),
            ('env_fc3', nn.Linear(512, 512)),
            ('env_relu3', nn.ReLU()),
            ('env_dropout3', nn.Dropout(p=0.5)),
            ('env_fc4', nn.Linear(512, 512)),
            ('env_relu4', nn.ReLU()),
        ]))
        self.last_layer = nn.Linear(512 + 512, self.x_dim)

    def forward(self, x_t, x_tplus, x_empty, u_t):
        '''Encode, predict, reconstruct and locally linearize the dynamics.

        Args:
            x_t, x_tplus, x_empty: [N, img_res*img_res] tensors (anything
                reshapeable to [N, 1, img_res, img_res]; single channel).
            u_t: [N, u_dim] control tensor. It is not modified.

        Returns:
            Tuple ``(x_hat_full, z_full, z_hat_tplus, A, B, c, G)``:

            * x_hat_full: [2N, img_res*img_res]. The first N rows are decoded
              from z_t, the last N from the predicted z_hat_tplus; both are
              combined with x_empty features.
            * z_full: [2N, z_dim] encodings of x_t (first N) and x_tplus.
            * z_hat_tplus: [N, z_dim] dynamics prediction.
            * A: [N, z_dim, z_dim] per-sample Jacobian d z_hat_tplus / d z_t.
            * B: [N, z_dim, u_dim] per-sample Jacobian d z_hat_tplus / d u_t.
            * c: [N, z_dim, 1] offset z_hat_tplus - A z_t - B u_t.
            * G: [N, z_dim, z_dim] product A B B^T A^T.

        Requires autograd to be enabled (not inside ``torch.no_grad()``),
        because A and B are computed with ``torch.autograd.grad``.
        '''
        batch_size = x_t.size(0)

        x_full = torch.cat((x_t, x_tplus), dim=0)
        input_enc = torch.reshape(x_full, [-1, 1, self.img_res, self.img_res])
        z_full = self.encoder(input_enc)

        z_t = z_full[:batch_size, :]
        # Differentiating w.r.t. u_t needs it in the graph. Detach first so the
        # caller's tensor flag is not flipped.
        if not u_t.requires_grad:
            u_t = u_t.detach().requires_grad_(True)
        # No identity op is needed after torch.cat: it returns a new tensor
        # that autograd tracks like any other.
        input_dyn = torch.cat((z_t, u_t), dim=1)
        z_hat_tplus = self.dynamics(input_dyn)

        input_dec = torch.cat((z_t, z_hat_tplus), dim=0)
        output_dec = self.decoder(input_dec)

        x_empty_full = torch.cat((x_empty, x_empty), dim=0)
        input_env = torch.reshape(x_empty_full, [-1, 1, self.img_res, self.img_res])
        output_env = self.environment(input_env)

        input_last = torch.cat((output_dec, output_env), dim=1)
        x_hat_full = self.last_layer(input_last)

        A = self._batch_jacobian(z_hat_tplus, z_t)  # N x D_z x D_z
        B = self._batch_jacobian(z_hat_tplus, u_t)  # N x D_z x D_c
        c = (self.__expand_dims(z_hat_tplus)
             - torch.bmm(A, self.__expand_dims(z_t))
             - torch.bmm(B, self.__expand_dims(u_t)))
        AT = torch.transpose(A, 1, 2)  # keep batch dim 0, transpose dims 1 and 2
        BT = torch.transpose(B, 1, 2)

        G = torch.bmm(A, torch.bmm(B, torch.bmm(BT, AT)))
        return x_hat_full, z_full, z_hat_tplus, A, B, c, G

    @staticmethod
    def _batch_jacobian(outputs, inputs):
        """Return the per-sample Jacobian d outputs / d inputs as [N, D_out, D_in].

        Row i is the gradient of ``outputs[:, i]`` (summed over the batch)
        w.r.t. ``inputs``. This equals the per-sample Jacobian because samples
        do not interact in the layers used by ``dynamics`` (Linear, ReLU,
        Dropout). ``create_graph=True`` keeps the result differentiable so it
        can be used in a training loss; it also retains the graph for the
        next row.
        """
        rows = []
        for i in range(outputs.size(1)):
            out_i = outputs[:, i]
            (row,) = torch.autograd.grad(
                out_i, inputs,
                grad_outputs=torch.ones_like(out_i),  # same device/dtype as outputs
                create_graph=True,
            )
            rows.append(row)
        return torch.stack(rows, dim=1)

    @staticmethod
    def __expand_dims(tensor):
        """Return ``tensor`` with a trailing singleton dim ([N, D] -> [N, D, 1]).

        Out-of-place on purpose: an in-place ``unsqueeze_`` would reshape the
        caller's tensors (e.g. ``u_t``).
        """
        return tensor.unsqueeze(-1)
        
class CollisionChecker(nn.Module):
    """Placeholder for a learned collision-checking network.

    Neither the network structure nor the forward pass is implemented yet.
    """

    def __init__(self):
        super(CollisionChecker, self).__init__()
        # TODO: Initialize network structure.

    def forward(self, inputs):
        """Run the collision checker on ``inputs``.

        Raises:
            NotImplementedError: always, until the forward pass is written.
                Raising (instead of ``pass``) keeps an accidental call from
                silently returning None.
        """
        # TODO: Define the forward pass.
        raise NotImplementedError("CollisionChecker.forward is not implemented yet")