<a href="https://colab.research.google.com/github/amitshmidov/geometric_learning_final_project_torch/blob/main/geometric_learning_final_project_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Pinned to the release the recorded run of this cell downloaded (3.9.8) so a
# fresh runtime resolves the same trimesh; -q suppresses the progress-bar output.
!pip install -q trimesh==3.9.8

Collecting trimesh
[?25l  Downloading https://files.pythonhosted.org/packages/8e/eb/28cd472f53790ff5ba5d8e9138cc5f8cc0a5f761d1114454ff0d37006b7f/trimesh-3.9.8-py3-none-any.whl (629kB)
[K     |▌                               | 10kB 15.4MB/s eta 0:00:01[K     |█                               | 20kB 22.3MB/s eta 0:00:01[K     |█▋                              | 30kB 27.2MB/s eta 0:00:01[K     |██                              | 40kB 21.6MB/s eta 0:00:01[K     |██▋                             | 51kB 13.2MB/s eta 0:00:01[K     |███▏                            | 61kB 14.5MB/s eta 0:00:01[K     |███▋                            | 71kB 10.7MB/s eta 0:00:01[K     |████▏                           | 81kB 11.6MB/s eta 0:00:01[K     |████▊                           | 92kB 12.6MB/s eta 0:00:01[K     |█████▏                          | 102kB 12.7MB/s eta 0:00:01[K     |█████▊                          | 112kB 12.7MB/s eta 0:00:01[K     |██████▎                         | 122kB 12.7

In [4]:
import numpy as np
import pandas as pd
import trimesh
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad
import itertools

from trimesh.sample import sample_surface, sample_surface_even
from typing import List

In [5]:
# Lookup tables for the layer builders: each maps a short string key to the
# torch.nn layer *class* (not an instance) that should be constructed.

# Activation name -> activation layer class.
ACTIVATIONS = {
    'relu': nn.ReLU,
    'lrelu': nn.LeakyReLU,
}

# Dimensionality mode -> convolution layer class.
CONVS = {
    'one_dims': nn.Conv1d,
    'two_dims': nn.Conv2d,
}

# Dimensionality mode -> batch-normalisation layer class.
BATCHNORMS = {
    'one_dims': nn.BatchNorm1d,
    'two_dims': nn.BatchNorm2d,
}


def conv_block(
        channels: tuple,
        dims: bool = False,
        activation: str = 'relu',
        activation_params=None
) -> nn.Sequential:
    """
    Return a convolutional Sequential block of point-wise (kernel_size=1) layers.

    Each consecutive pair ``channels[i] -> channels[i + 1]`` contributes three
    modules, in order: convolution, batch normalisation, activation.

    :param channels: channel sizes; ``channels[0]`` is the input width and every
        later entry is the output width of one convolution. With fewer than two
        entries the returned block is empty.
    :param dims: dims=False for 1d operations (Conv1d/BatchNorm1d), True for 2d
        operations (Conv2d/BatchNorm2d).
    :param activation: type of activation, a key of ``ACTIVATIONS``.
    :param activation_params: keyword arguments passed to the activation
        constructor; ``None`` (default) means no extra arguments.
    :return: nn.Sequential object.
    :raises KeyError: if ``activation`` is not a key of ``ACTIVATIONS``.
    """
    # Resolve the default per call rather than sharing one mutable ``dict()``
    # instance between every call of the function.
    act_kwargs = {} if activation_params is None else activation_params

    mode = 'two_dims' if dims else 'one_dims'
    act = ACTIVATIONS[activation]
    conv = CONVS[mode]
    bn = BATCHNORMS[mode]

    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers.append(conv(in_channels=c_in, out_channels=c_out, kernel_size=1))
        layers.append(bn(num_features=c_out))
        layers.append(act(**act_kwargs))
    return nn.Sequential(*layers)


def linear_block(
        features: tuple,
        out_features: int,
        dropout: float = 0,
        activation: str = 'relu',
        activation_params=None
) -> nn.Sequential:
    """
    Return a linear Sequential block.

    Each consecutive pair ``features[i] -> features[i + 1]`` contributes, in
    order: Linear, BatchNorm1d, activation and (only when ``dropout > 0``)
    Dropout. A final ``Linear(features[-1], out_features)`` is appended with no
    normalisation, activation or dropout after it.

    :param features: features of each hidden layer; must contain at least one
        entry because ``features[-1]`` feeds the final layer.
    :param out_features: output features of the block.
    :param dropout: dropout probability applied after each hidden activation
        (0 to disable); must satisfy ``0 <= dropout < 1``.
    :param activation: type of activation, a key of ``ACTIVATIONS``.
    :param activation_params: keyword arguments passed to the activation
        constructor; ``None`` (default) means no extra arguments.
    :return: nn.Sequential object.
    :raises AssertionError: if ``dropout`` is outside ``[0, 1)``.
    :raises KeyError: if ``activation`` is not a key of ``ACTIVATIONS``.
    """
    assert 0 <= dropout < 1, 'dropout must be in the range [0, 1)'

    # Resolve the default per call rather than sharing one mutable ``dict()``
    # instance between every call of the function.
    act_kwargs = {} if activation_params is None else activation_params

    act = ACTIVATIONS[activation]
    layers = []
    for f_in, f_out in zip(features[:-1], features[1:]):
        layers.append(nn.Linear(in_features=f_in, out_features=f_out))
        layers.append(nn.BatchNorm1d(num_features=f_out))
        layers.append(act(**act_kwargs))
        if dropout > 0:
            layers.append(nn.Dropout(p=dropout))
    # Output projection: left bare so the caller decides what follows it.
    layers.append(nn.Linear(in_features=features[-1], out_features=out_features))
    return nn.Sequential(*layers)

In [6]:
def torch_(x):
    """
    Move a torch object to the globally configured device, then cast it to the
    globally configured type.

    Both settings are read from the module-level ``torch_configs`` mapping
    (keys ``'dev'`` and ``'type'``), which is defined outside this cell.

    :param x: any torch object exposing ``.to(...)`` and ``.type(...)``.
    :return: the result of ``x.to(<dev>).type(<type>)``.
    """
    # Device placement happens first; the type conversion is applied to its result.
    on_device = x.to(torch_configs['dev'])
    return on_device.type(torch_configs['type'])

In [7]:
class TNet(nn.Module):
    """
    Transform-prediction network.

    A stack of point-wise 1d convolutions encodes every point, the encoding is
    max-pooled over the points, and a linear decoder maps the pooled vector to
    ``num_features ** 2`` values. The flattened identity matrix is added to
    those values and the result is reshaped to one
    ``num_features x num_features`` matrix per batch element.
    """

    def __init__(
            self,
            num_points: int = 2000,
            num_features: int = 3,
            encoder: tuple = (64, 128, 1024),
            decoder: tuple = (1024, 512, 256),
            activation: str = 'relu',
            activation_params=None
    ):
        """
        :param num_points: number of points of each shape. Stored on the module;
            the forward pass pools over however many points the input has.
        :param num_features: number of features of each point.
        :param encoder: output channels of the point-wise conv layers.
        :param decoder: features of the hidden linear layers; ``decoder[0]`` must
            equal ``encoder[-1]``.
        :param activation: type of activation (forwarded to the block builders).
        :param activation_params: specific activation params; ``None`` (default)
            means no extra arguments.
        """
        super().__init__()

        assert encoder[-1] == decoder[0]

        # Resolve the default per call rather than sharing one mutable ``dict()``.
        act_kwargs = {} if activation_params is None else activation_params

        self.num_points = num_points
        self.dims = num_features

        # The identity offset is a constant, not something to differentiate.
        # As a buffer it follows the module through .to()/.type()/.cuda();
        # persistent=False keeps the state_dict keys unchanged.
        self.register_buffer('identity', torch.eye(self.dims).view(-1), persistent=False)

        channels = (num_features,) + tuple(encoder)
        self.encoder = conv_block(channels, activation=activation, activation_params=act_kwargs)

        self.decoder = linear_block(features=decoder, out_features=num_features ** 2, dropout=0,
                                    activation=activation, activation_params=act_kwargs)

    def forward(self, x):
        """
        :param x: tensor of shape (batch, num_features, n_points).
        :return: tensor of shape (batch, num_features, num_features).
        """
        x = self.encoder(x)
        # Global max over the point axis. Unlike a fixed kernel of
        # ``self.num_points`` this never drops trailing points and does not
        # fail when the input holds a different number of points.
        x = F.adaptive_max_pool1d(x, 1).squeeze(2)
        x = self.decoder(x)
        # Offset by the flattened identity (broadcast over the batch).
        x = x + self.identity
        return x.view(-1, self.dims, self.dims)

In [8]:
class Momenet(nn.Module):
    """
    Point-set classifier.

    Pipeline of the forward pass: spatial transform -> lifting -> mlp1 (2d
    point-wise convs) -> max over the last axis -> mlp2 (1d point-wise convs)
    -> max over the points -> mlp3 (linear classifier) -> log-softmax over the
    classes.
    """

    def __init__(
            self,
            num_features: int,
            out_classes: int,
            num_points: int,
            lifting: List[nn.Module],
            hiddent: int = 12,
            hidden1: tuple = (64,),
            hidden2: tuple = (64, 64, 128, 1024),
            hidden3: tuple = (1024, 512, 256),
            activation: str = 'relu',
            activation_params=None
    ):
        """
        :param num_features: #features of each point.
        :param out_classes: #classes.
        :param num_points: #points in each shape (forwarded to the spatial transform).
        :param lifting: callables applied in order to the transformed points. The
            last one must return a 4-d tensor with ``hiddent`` channels on axis 1,
            because its output feeds the 2d conv block and is then max-reduced
            over axis 3.
        :param hiddent: the dimension of the feature transform.
        :param hidden1: channels of the 1st conv block (called mlp1 in paper).
        :param hidden2: channels of the 2nd conv block (called mlp2 in paper).
        :param hidden3: features of the classifier (called mlp in paper);
            ``hidden3[0]`` must match ``hidden2[-1]``.
        :param activation: type of activation.
        :param activation_params: specific activation params; ``None`` (default)
            means no extra arguments.
        """
        super().__init__()

        # Resolve the default per call rather than sharing one mutable ``dict()``.
        act_kwargs = {} if activation_params is None else activation_params

        self.spatial_transform = TNet(num_points, num_features, activation=activation, activation_params=act_kwargs)

        # NOTE(review): kept as a plain Python list, so these modules are not
        # registered as sub-modules: .to(), .train()/.eval(), .parameters() and
        # state_dict() do not reach them. Wrap in nn.ModuleList if they hold state.
        self.lifts = lifting

        channels = (hiddent,) + tuple(hidden1)
        self.mlp1 = conv_block(channels, dims=True,
                               activation=activation, activation_params=act_kwargs)

        channels = (hidden1[-1],) + tuple(hidden2)
        self.mlp2 = conv_block(channels, dims=False,
                               activation=activation, activation_params=act_kwargs)

        self.mlp3 = linear_block(features=hidden3, out_features=out_classes, dropout=0.4,
                                 activation=activation, activation_params=act_kwargs)

        # The classifier output is (batch, out_classes): normalise over the class
        # axis. dim=0 would normalise over the batch and couple the samples.
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Feed-forward x in the network.
        :param x: A tensor with batch data, shape (batch, num_features, num_points).
        :return: tuple ``(log_probs, t)``: ``log_probs`` has shape
            (batch, out_classes) and holds per-sample log-probabilities; ``t`` is
            the matrix batch returned by the spatial transform.
        """
        num_points = x.shape[-1]

        # Predict one transform per sample and apply it to the points.
        t = self.spatial_transform(x)
        x = torch.bmm(t, x)

        # NOTE(review): the lifts run under no_grad, so their output carries no
        # autograd history and a loss on the class scores cannot back-propagate
        # into spatial_transform (it only gets gradient through the returned
        # `t`). Confirm this is intended.
        for lift in self.lifts:
            with torch.no_grad():
                x = lift(x)

        x = self.mlp1(x)

        # Collapse the last axis of the 4-d lifted features.
        x = x.max(dim=3).values

        x = self.mlp2(x)

        # Max over the points, leaving one feature vector per sample.
        x = F.max_pool1d(x, num_points).squeeze(2)

        x = self.mlp3(x)

        x = self.logsoftmax(x)

        return x, t