In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class TransformationNet(nn.Module):
    """T-Net: predicts one ``output_dim x output_dim`` matrix per sample.

    Shared 1x1 convolutions are applied to every point, the result is
    max-pooled over the point axis into a single 1024-d descriptor per sample,
    and three fully connected layers regress the matrix entries. The identity
    matrix is added to the regressed entries, so an all-zero ``fc_3`` output
    yields the identity transform.

    Args:
        input_dim: channels per input point (``x.shape[2]``).
        output_dim: side length of the predicted square matrix.
    """

    def __init__(self, input_dim, output_dim):
        super(TransformationNet, self).__init__()
        self.output_dim = output_dim

        self.conv_1 = nn.Conv1d(input_dim, 64, 1)
        self.conv_2 = nn.Conv1d(64, 128, 1)
        self.conv_3 = nn.Conv1d(128, 1024, 1)

        self.bn_1 = nn.BatchNorm1d(64)
        self.bn_2 = nn.BatchNorm1d(128)
        self.bn_3 = nn.BatchNorm1d(1024)
        self.bn_4 = nn.BatchNorm1d(512)
        self.bn_5 = nn.BatchNorm1d(256)

        self.fc_1 = nn.Linear(1024, 512)
        self.fc_2 = nn.Linear(512, 256)
        self.fc_3 = nn.Linear(256, self.output_dim*self.output_dim)

    def forward(self, x):
        """Return a ``(batch, output_dim, output_dim)`` transform for ``x``.

        Args:
            x: point tensor of shape ``(batch, num_points, input_dim)``.
        """
        # Conv1d expects channels first: (batch, input_dim, num_points).
        x = x.transpose(2, 1)
        x = F.relu(self.bn_1(self.conv_1(x)))
        x = F.relu(self.bn_2(self.conv_2(x)))
        x = F.relu(self.bn_3(self.conv_3(x)))

        # Global max over the point axis -> (batch, 1024). Equivalent to
        # MaxPool1d(num_points) without building a new module on every call.
        x = x.max(dim=2).values

        x = F.relu(self.bn_4(self.fc_1(x)))
        x = F.relu(self.bn_5(self.fc_2(x)))
        x = self.fc_3(x)

        # Create the identity on the input's own device and dtype. The old
        # `torch.cuda.is_available()` check moved it to the GPU whenever one
        # existed, even for a CPU model and CPU inputs, causing a
        # device-mismatch error in the addition below.
        identity_matrix = torch.eye(self.output_dim, device=x.device, dtype=x.dtype)
        return x.view(-1, self.output_dim, self.output_dim) + identity_matrix

class BasePointNet(nn.Module):
    """PointNet backbone producing aligned per-point features and a global descriptor.

    Input points are multiplied by a learned ``point_dimension``-square
    transform. They then pass through two shared 1x1 conv layers (64 ch), are
    re-aligned by a learned 64x64 feature transform, pass through three more
    conv layers (up to 1024 ch), and are max-pooled over the point axis.

    Args:
        point_dimension: channels per input point (``x.shape[2]``).
        return_local_features: if True, ``forward`` returns per-point features
            of width 1024 + 64. These are the global descriptor tiled to every
            point, concatenated with the aligned 64-d local features. Otherwise
            only the ``(batch, 1024)`` global descriptor is returned.
    """

    def __init__(self, point_dimension, return_local_features=False):
        super(BasePointNet, self).__init__()
        self.return_local_features = return_local_features
        # Learned alignment matrices for the raw points and for the 64-d features.
        self.input_transform = TransformationNet(input_dim=point_dimension, output_dim=point_dimension)
        self.feature_transform = TransformationNet(input_dim=64, output_dim=64)

        # Shared per-point MLP expressed as kernel-size-1 convolutions.
        self.conv_1 = nn.Conv1d(point_dimension, 64, 1)
        self.conv_2 = nn.Conv1d(64, 64, 1)
        self.conv_3 = nn.Conv1d(64, 64, 1)
        self.conv_4 = nn.Conv1d(64, 128, 1)
        self.conv_5 = nn.Conv1d(128, 1024, 1)

        self.bn_1 = nn.BatchNorm1d(64)
        self.bn_2 = nn.BatchNorm1d(64)
        self.bn_3 = nn.BatchNorm1d(64)
        self.bn_4 = nn.BatchNorm1d(128)
        self.bn_5 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, num_points, point_dimension)``.

        Returns:
            ``(features, feature_transform)``. ``features`` is ``(batch, 1024)``,
            or ``(batch, num_points, 1088)`` when ``return_local_features`` is
            set. ``feature_transform`` is the ``(batch, 64, 64)`` matrix.
        """
        n_pts = x.shape[1]

        # (B, N, D) @ (B, D, D) -> aligned points (B, N, D).
        aligned = torch.bmm(x, self.input_transform(x))

        h = aligned.transpose(2, 1)  # channels-first for Conv1d
        for conv, bn in ((self.conv_1, self.bn_1), (self.conv_2, self.bn_2)):
            h = F.relu(bn(conv(h)))

        point_feats = h.transpose(2, 1)  # back to (B, N, 64)
        feat_tf = self.feature_transform(point_feats)
        local_point_features = torch.bmm(point_feats, feat_tf)

        h = local_point_features.transpose(2, 1)
        for conv, bn in ((self.conv_3, self.bn_3), (self.conv_4, self.bn_4), (self.conv_5, self.bn_5)):
            h = F.relu(bn(conv(h)))
        # One pooling window spanning every point -> (B, 1024).
        global_feat = F.max_pool1d(h, kernel_size=n_pts).view(-1, 1024)

        if not self.return_local_features:
            return global_feat, feat_tf

        tiled = global_feat.view(-1, 1024, 1).repeat(1, 1, n_pts).transpose(2, 1)
        return torch.cat([tiled, local_point_features], 2), feat_tf

class FeaturesPointNet(nn.Module):
    """Compress a BasePointNet global descriptor into a 256-d feature vector.

    Args:
        dropout: drop probability applied to the final 256-d features.
        point_dimension: channels per input point, forwarded to BasePointNet.
    """

    def __init__(self, dropout, point_dimension):
        super(FeaturesPointNet, self).__init__()
        self.base_pointnet = BasePointNet(return_local_features=False, point_dimension=point_dimension)

        # 1024 -> 512 -> 256. A classification head (fc_3 + log_softmax) was
        # commented out earlier; this module only returns the 256-d features.
        self.fc_1 = nn.Linear(1024, 512)
        self.fc_2 = nn.Linear(512, 256)

        self.bn_1 = nn.BatchNorm1d(512)
        self.bn_2 = nn.BatchNorm1d(256)

        self.dropout_1 = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(features, feature_transform)`` with ``features`` of shape ``(batch, 256)``.

        ``feature_transform`` is passed through unchanged from BasePointNet.
        """
        global_feat, feature_transform = self.base_pointnet(x)
        hidden = F.relu(self.bn_1(self.fc_1(global_feat)))
        features = self.dropout_1(F.relu(self.bn_2(self.fc_2(hidden))))
        return features, feature_transform


class Weights(nn.Module):
    """Non-negative scalar score for two point clouds plus three extra points.

    Each cloud is encoded by its own FeaturesPointNet (256 features each). The
    two encodings are concatenated with the extra points ``v``, ``p1``, ``p2``
    and mapped to one value by a Linear layer followed by ReLU.

    Args:
        dropout: dropout probability forwarded to both FeaturesPointNet encoders.
        point_dimension: channels per point. The head's input width is
            ``2 * 256 + 3 * point_dimension`` (518 for the default of 2).
    """

    def __init__(self, dropout=0.3, point_dimension=2):
        super(Weights, self).__init__()
        self.point1 = FeaturesPointNet(dropout, point_dimension)
        self.point2 = FeaturesPointNet(dropout, point_dimension)

        # 256 features per encoder + the v, p1, p2 rows (point_dimension each).
        # This was previously hard-coded to 518, which only fit point_dimension == 2.
        self.fc_1 = nn.Linear(2 * 256 + 3 * point_dimension, 1)

    def forward(self, x):
        """Score a packed batch.

        Args:
            x: tensor of shape ``(batch, 2 * n + 3, point_dimension)``. Rows
                ``[0, n)`` are cloud 1 and rows ``[n, 2n)`` are cloud 2. The
                last three rows are ``v``, ``p1``, ``p2``.

        Returns:
            Tensor of shape ``(batch, 1)`` with non-negative values (ReLU output).

        Raises:
            AssertionError: if ``x.shape[1]`` is not ``2 * n + 3`` with ``n >= 1``.
        """
        num_rows = x.shape[1]
        n = (num_rows - 3) // 2
        # Validate the packing before slicing. With an even row count, the last
        # row would otherwise be silently ignored.
        assert n >= 1 and 2 * n + 3 == num_rows, (
            "expected x.shape[1] == 2*n + 3 with n >= 1, got %d" % num_rows)

        cloud1 = x[:, :n]
        cloud2 = x[:, n:2 * n]
        v, p1, p2 = x[:, 2 * n], x[:, 2 * n + 1], x[:, 2 * n + 2]

        # The feature transforms are not used by this head.
        feat1, _ = self.point1(cloud1)
        feat2, _ = self.point2(cloud2)

        x = torch.cat((feat1, feat2, v, p1, p2), dim=1)
        return F.relu(self.fc_1(x))


In [8]:
class FeaturesPointNet2(nn.Module):
    """Small CNN encoder: channels-last single-channel image -> 100 features per spatial cell.

    Input is ``(batch, H, W, 1)`` (channels-last; ``conv_1`` takes 1 channel).
    The pipeline is conv(3) -> ReLU -> maxpool(3, s2) -> conv(3, s2) -> ReLU ->
    maxpool(2, s2), followed by three Linear layers on the 160 channels. Output
    is channels-last ``(batch, H', W', 100)``.
    """

    def __init__(self):
        super(FeaturesPointNet2, self).__init__()
        self.conv_1 = nn.Conv2d(1, 64, 3)
        self.conv_2 = nn.Conv2d(64, 160, 3, stride=2)
        self.linear_1 = nn.Linear(160, 1000)
        self.linear_2 = nn.Linear(1000, 1000)
        self.linear_3 = nn.Linear(1000, 100)
        self.maxpool_1 = nn.MaxPool2d(3, stride=2)
        self.maxpool_2 = nn.MaxPool2d(2, stride=2)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.maxpool_3 = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        # (B, H, W, C) -> (B, C, H, W) for Conv2d.
        feat = x.permute(0, 3, 1, 2)
        feat = self.maxpool_1(F.relu(self.conv_1(feat)))
        feat = self.maxpool_2(F.relu(self.conv_2(feat)))
        # Back to channels-last so the Linear layers act on the 160 channels.
        feat = feat.permute(0, 2, 3, 1)
        # NOTE(review): there is no nonlinearity between the three Linear layers,
        # so together they are a single affine map 160 -> 100. Confirm whether
        # activations were intended.
        for layer in (self.linear_1, self.linear_2, self.linear_3):
            feat = layer(feat)
        return feat
    

class Weights2(nn.Module):
    """Non-negative scalar score from two image encodings plus three extra vectors.

    ``forward`` takes a triple ``(x1, x2, x3)``. ``x1`` and ``x2`` go through
    separate FeaturesPointNet2 encoders, and 100 features are read from spatial
    cell (0, 0) of each. ``x3[:, 0]``, ``x3[:, 1]``, ``x3[:, 2]`` supply ``v``,
    ``p1``, ``p2``. The concatenation (206 wide) goes through one Linear + ReLU.
    """

    def __init__(self):
        super(Weights2, self).__init__()
        self.point1 = FeaturesPointNet2()
        self.point2 = FeaturesPointNet2()
        # 100 + 100 encoder features + v, p1, p2 (2 values each for 206 to match).
        self.fc_1 = nn.Linear(206, 1)

    def forward(self, x):
        img1, img2, extras = x
        # NOTE(review): only the top-left spatial cell of each encoding is used.
        # This presumably relies on inputs sized so the encoder output is 1x1.
        enc1 = self.point1(img1)[:, 0, 0, :]
        enc2 = self.point2(img2)[:, 0, 0, :]
        pieces = [enc1, enc2] + [extras[:, k] for k in range(3)]
        return F.relu(self.fc_1(torch.cat(pieces, dim=1)))