In [2]:
!pip install chamferdist

Collecting chamferdist
  Downloading chamferdist-1.0.0.tar.gz (16 kB)
Building wheels for collected packages: chamferdist
  Building wheel for chamferdist (setup.py) ... [?25l[?25hdone
  Created wheel for chamferdist: filename=chamferdist-1.0.0-cp37-cp37m-linux_x86_64.whl size=5632650 sha256=d9341e3526f1384ad715c615d54bfccc733bfb8ad01732dd17e78ed65fc2457a
  Stored in directory: /root/.cache/pip/wheels/28/bb/d1/c789ecd6835e466e813f6e2c5e23bb1bbb2248e84586ba82d2
Successfully built chamferdist
Installing collected packages: chamferdist
Successfully installed chamferdist-1.0.0


In [3]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np

from chamferdist import ChamferDistance

%matplotlib inline
import matplotlib.pyplot as plt

import plotly.express as px
import plotly.graph_objects as go

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
class FC(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.layer1 = nn.Linear(
            in_features=kwargs["input_shape"], out_features=256
        )
        self.layer2 = nn.Linear(
            in_features=256, out_features=256
        )
        self.layer3 = nn.Linear(
            in_features=256, out_features=128
        )
        self.layer4 = nn.Linear(
            in_features=128, out_features=64
        )
        self.layer5 = nn.Linear(
            in_features=64, out_features=kwargs["output_shape"]
        )

    def forward(self, features):
        B = features.shape[0]
        features = features.reshape((B, -1))
        features = self.layer1(features)
        features = torch.relu(features)
        features = self.layer2(features)
        features = torch.relu(features)
        features = self.layer3(features)
        features = torch.relu(features)
        features = self.layer4(features)
        features = torch.relu(features)
        features = self.layer5(features)
        return features

In [6]:
class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat

def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

In [7]:
class AE_mlp_mlp(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.enc = FC(input_shape = kwargs["input_shape"], output_shape = 6)
        self.dec = FC(input_shape = 6, output_shape = kwargs["input_shape"])

    def forward(self, features):
        B, N, _ = features.shape
        features = self.enc(features)
        features = self.dec(features)
        features = features.reshape((B, N, -1))
        return features

In [8]:
class AE_pn_enc(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.pn = PointNetfeat()
        self.enc = FC(input_shape = 1024, output_shape = 6)

    def forward(self, features):
        B, N, _ = features.shape
        features = features.transpose(1,2)
        features = self.pn(features)[0]
        features = self.enc(features)
        return features

In [9]:
class AE_pn_mlp(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.pn = PointNetfeat()
        self.enc = FC(input_shape = 1024, output_shape = 6)
        self.dec = FC(input_shape = 6, output_shape = kwargs["output_shape"])

    def forward(self, features):
        B, N, _ = features.shape
        features = features.transpose(1,2)
        features = self.pn(features)[0]
        features = self.enc(features)
        features = self.dec(features)
        features = features.reshape((B, N, 3))
        return features

In [10]:
def sample_batch(X, num_train, batch_size):
    """
    Sample batch_size elements from the training data and their
    corresponding labels to use in this round of gradient descent.
    """
    idx = torch.randint(0, num_train, (batch_size,))
    X_batch = X[idx]
    return X_batch

def sample_batch(X, y, num_train, batch_size):
    """
    Sample batch_size elements from the training data and their
    corresponding labels to use in this round of gradient descent.
    """
    idx = torch.randint(0, num_train, (batch_size,))
    X_batch = X[idx]
    Y_batch = y[idx]
    return X_batch, Y_batch

In [12]:
# use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float
print("device:", device)

# load dataset
path = "/content/drive/MyDrive/research/bretl/"
pts = torch.from_numpy(np.load(path+"points_50.npy")).to(device, dtype=dtype)
a = torch.from_numpy(np.load(path+"a_50.npy")).to(device, dtype=dtype)

# small dataset
# pts = pts[:2]
# a = a[:2]

N, n, _ = pts.shape

# pts = pts[:, 0:-1, :] - pts[:, 1:, :]
# order = torch.arange(n-1).to(device, dtype=dtype).repeat(N, 1, 1).transpose(1,2)
# print(order.shape)
# pts = torch.cat((pts, order), 2)
# print(pts[0])

# model = AE_mlp_mlp(input_shape=60).to(device)
# model = AE_pn_mlp(output_shape=50*3).to(device)
model = AE_pn_enc().to(device)
# model = FC(input_shape = 6, output_shape = 50*3).to(device)
# model = FC(input_shape=50*3, output_shape=6).to(device)

# create an optimizer object
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# mean-squared error loss
criterion = nn.MSELoss()
# criterion = ChamferDistance()

# split train and test
train_pts = pts[0:N*4//5]
train_a = a[0:N*4//5]
test_pts = pts[N*4//5:N]
test_a = a[N*4//5:N]
num_train = train_pts.shape[0]
num_test = test_pts.shape[0]

device: cuda


In [None]:
epochs = 1000
batch_size = 128
for epoch in range(epochs):
    loss = 0
    # batch_features = sample_batch(train_pts, num_train, batch_size)
    batch_features, batch_a = sample_batch(train_pts, train_a, num_train, batch_size)
    
    # reset the gradients back to zero
    # PyTorch accumulates gradients on subsequent backward passes
    optimizer.zero_grad()
    
    # compute reconstructions
    outputs = model(batch_features)
    # outputs = model(batch_a)
    
    # compute training reconstruction loss
    train_loss = criterion(outputs, batch_a)
    # train_loss = criterion(outputs, batch_features.reshape(batch_size, -1))

    
    # compute accumulated gradients
    train_loss.backward()
    
    # perform parameter update based on current gradients
    optimizer.step()
    
    # add the mini-batch training loss to epoch loss
    loss = train_loss.item()

    # display the iteration training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

epoch : 1/1000, loss = 219.949753
epoch : 2/1000, loss = 199.421478
epoch : 3/1000, loss = 224.224884
epoch : 4/1000, loss = 196.517380
epoch : 5/1000, loss = 205.612976
epoch : 6/1000, loss = 213.918823
epoch : 7/1000, loss = 215.956467
epoch : 8/1000, loss = 204.093491
epoch : 9/1000, loss = 216.620773
epoch : 10/1000, loss = 209.036591
epoch : 11/1000, loss = 201.991989
epoch : 12/1000, loss = 224.080444
epoch : 13/1000, loss = 236.118011
epoch : 14/1000, loss = 190.665085
epoch : 15/1000, loss = 216.362640
epoch : 16/1000, loss = 199.730209
epoch : 17/1000, loss = 186.904083
epoch : 18/1000, loss = 209.304565
epoch : 19/1000, loss = 201.504089
epoch : 20/1000, loss = 213.481415
epoch : 21/1000, loss = 208.130722
epoch : 22/1000, loss = 205.917450
epoch : 23/1000, loss = 191.056641
epoch : 24/1000, loss = 202.256714
epoch : 25/1000, loss = 223.763519
epoch : 26/1000, loss = 189.094574
epoch : 27/1000, loss = 185.664490
epoch : 28/1000, loss = 193.923065
epoch : 29/1000, loss = 187.9

In [None]:
torch.save(model.state_dict(), path+"ae_pn_mlp_enc.pt")

In [13]:
dec = FC(input_shape = 6, output_shape = 50*3).to(device)
dec.load_state_dict(torch.load(path+"ae_pn_mlp_dec.pt"))

<All keys matched successfully>

In [14]:
enc = AE_pn_enc().to(device)
enc.load_state_dict(torch.load(path+"ae_pn_mlp_enc.pt"))

<All keys matched successfully>

In [15]:
model = AE_pn_mlp(output_shape=50*3).to(device)
model.load_state_dict(torch.load(path+"ae_pn_mlp.pt"))

<All keys matched successfully>

In [None]:
# eval
# outputs = dec(enc(test_pts))
outputs = model(test_pts)
print(criterion(outputs.reshape(num_test, -1), test_pts.reshape(num_test, -1)))

tensor(0.0033, device='cuda:0', grad_fn=<MseLossBackward0>)


In [17]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model.enc.register_forward_hook(get_activation('enc'))
outputs = model(test_pts)
code = activation['enc']
print(criterion(code, test_a))

tensor(241.8052, device='cuda:0')


In [20]:
# visualize code distribution
code_np = code.detach().cpu().numpy()
test_a_np = test_a.detach().cpu().numpy()
fig = go.Figure(data=[go.Scatter3d(x=code_np[:, 0], y=code_np[:, 1], z=code_np[:, 2],
                                   mode='markers'),
                      go.Scatter3d(x=test_a_np[:, 0], y=test_a_np[:, 1], z=test_a_np[:, 2],
                                   mode='markers')])
fig.show()

In [None]:
recon_np = recon.cpu().detach().numpy()
test_np = test_features.cpu().detach().numpy()
vis_idx = np.random.randint(0, num_test)
print(test_np[vis_idx])
print(recon_np[vis_idx])

fig = go.Figure(data=[go.Scatter3d(x=test_np[vis_idx, :, 0], y=test_np[vis_idx, :, 1], z=test_np[vis_idx, :, 2],
                                   mode='markers'),
                      go.Scatter3d(x=recon_np[vis_idx, :, 0], y=recon_np[vis_idx, :, 1], z=recon_np[vis_idx, :, 2],
                                   mode='markers')])
fig.show()

[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 2.04048567e-02  2.06336335e-04 -2.19152629e-04]
 [ 4.07966115e-02  8.68835370e-04 -6.88027008e-04]
 [ 6.11661933e-02  2.05242936e-03 -1.12375803e-03]
 [ 8.14952701e-02  3.82045819e-03 -1.24441064e-03]
 [ 1.01751700e-01  6.23260112e-03 -7.72650470e-04]
 [ 1.21872567e-01  9.34498571e-03  5.67190466e-04]
 [ 1.41755745e-01  1.32059315e-02  3.04487604e-03]
 [ 1.61249697e-01  1.78420991e-02  6.89851167e-03]
 [ 1.80152953e-01  2.32580174e-02  1.23336669e-02]
 [ 1.98214933e-01  2.94360593e-02  1.95229575e-02]
 [ 2.15156823e-01  3.63268256e-02  2.85756886e-02]
 [ 2.30630144e-01  4.38348241e-02  3.95287164e-02]
 [ 2.44296849e-01  5.18357828e-02  5.23456484e-02]
 [ 2.55862206e-01  6.01796247e-02  6.69073910e-02]
 [ 2.65074760e-01  6.86904639e-02  8.30121711e-02]
 [ 2.71726429e-01  7.71666095e-02  1.00375526e-01]
 [ 2.75657624e-01  8.53850171e-02  1.18638076e-01]
 [ 2.76852608e-01  9.31605324e-02  1.37451425e-01]
 [ 2.75419265e-01  1.00340202e-