# our network

In [1]:
import torch
import torch.nn as nn

In [2]:
# LSTM : Warper



In [37]:
# MLP : SDF Decoder

class MLP(nn.Module):
    def __init__(self, layers, wn_layers, weight_norm, dropout_layers, dropout_prob):
        """
            Initialize the MLP that extracts the template SDF
        """
        
        # layers: number of layers + size of each layer
        # [256, 256, 256, 256, 256] 
        # => 5 layers with 256 neurons each
        
        # wn_layers: layer indices in which normalization is used
        # [0, 1, 2, 3, 4]
        
        # weight_norm: bool
        # which normalization to use????? TODO
        
        # dropout_layers: layer indices in which dropout is used
        # [0, 1, 2, 3, 4]
        
        # dropout_prob: probability for dropout
        # 0.05
        
        super(MLP, self).__init__()
        
        
        self.numlayers = len(layers)+1
        print("numlayers", self.numlayers)
        
        in_dim = 3
        out_dim = 1
        
        self.layers = nn.ModuleList()
        for i in range(self.numlayers):
            print(i)
            # layer input feature count (for first layer: in_dim)
            in_features = in_dim if (i == 0) else layers[i-1]
            # layer output feature count (for last layer: out_dim)
            out_features = out_dim if (i == (len(layers))) else layers[i]
            
            # fully connected layer
            layer = nn.Linear(in_features, out_features)
            
            modules = [layer]
            if i in wn_layers:
                # weight normalization layer
                if False:  ##########weight_norm: ########## TODO
                    layer = nn.utils.weight_norm(layer) # ????????
                    modules[0] = layer
                else:
                    modules.append(nn.LayerNorm(out_features))
            
            # activation (tanh if last layer, else relu)
            activation = nn.Tanh() if (i == (len(layers))) else nn.ReLU()
            modules.append(activation)
            
            # dropout
            if i in dropout_layers: ### TODO only if training??
                modules.append(nn.Dropout(dropout_prob))
            
            
            # if last layer (and tanh) => another tanh??
            if (i == (len(layers))): #### TODO and tanh
                modules.append(nn.Tanh())
            
            sequential = nn.Sequential(*modules)
            self.layers.append(sequential)
            
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
        return x

In [38]:
dims = [256,256,256,256,256]
dropout =[0,1,2,3,4]
dropout_prob = 0.05
norm_layer = [0,1,2,3,4]
xyz_in_all = False
weight_norm = True

mlp = MLP(dims, norm_layer, weight_norm, dropout, dropout_prob)

numlayers 6
0
1
2
3
4
5


In [40]:
mlp

MLP(
  (layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=3, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Dropout(p=0.05, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Dropout(p=0.05, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Dropout(p=0.05, inplace=False)
    )
    (3): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Dropout(p=0.05, inplace=False)
    )
    (4): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((25

In [39]:
x_test = torch.ones(3,3)
print('input shape: ', x_test.shape, '| dtype: ', x_test.dtype)

output = mlp(x_test)
print('output shape: ', output.shape, '| dtype: ', output.dtype)
print(output)

input shape:  torch.Size([3, 3]) | dtype:  torch.float32
output shape:  torch.Size([3, 1]) | dtype:  torch.float32
tensor([[-0.2518],
        [-0.5215],
        [-0.2699]], grad_fn=<TanhBackward>)


In [None]:
# Decoder : warper + sdf_decoder

In [None]:
# collect data for training/validation

In [47]:
# loss functions

def curriculum_training_loss(data, e, l):
    return 1 ####### TODO
def rec_loss_with_warping_steps(s, K, N, e, l):
    loss = 0
    for k in range(1, K+1): # for each shape
        for i in range(1, N+1): # for each SDF sample (of one shape)
            data = 1 ### TODO # warped point AND ground truth sdf value
            loss += curriculum_training_loss(data, e,l)
    return loss
            

def reconstruction_loss():
    # progressive reconstruction loss
    #warping steps
    rec_loss = 0
    for s in (2,4,6,8):
        rec_loss += rec_loss_with_warping_steps(s, 1, 1, 1, 1) ##### TODO
    return rec_loss

def regularization_loss():
    # point-wise regularization
    #### TODO
    #point pair regularization
    #### TODO
    return 1
    
def loss():
    return reconstruction_loss() + regularization_loss()

In [None]:
# Training

In [None]:
# Test

In [9]:
x_test = torch.randn(5, 6)

In [46]:
loss()

5