# Point Transformer denoising implementation

## Using PointTransformer for noise segmentation

# Imports

In [None]:
# Reload edited modules automatically and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Mesh I/O and point sampling: open3d (pinned) and igl (libigl bindings, from conda-forge).
!pip install -qq open3d==0.12.0
!conda install -c conda-forge -y igl >/dev/null

In [None]:
# Turn the /usr/local/cuda symlink target (e.g. ".../cuda-11.0") into a wheel tag such as
# "cu110" so the prebuilt torch-cluster wheel for torch 1.7.0 matching this CUDA is chosen.
CUDA, = !readlink /usr/local/cuda | sed -E 's/.*cuda-(\w+)\.(\w+)/cu\1\2/'
# Only torch-cluster is needed (for farthest point sampling, `fps`); the other PyG packages stay disabled.
# !pip install -qq torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
# !pip install -qq torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
!pip install -qq torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
# !pip install -qq torch-geometric

In [None]:
!pip install -Uqq pytorch-lightning
!pip install -Uqq wandb

In [None]:
# !pip install -qq torchtyping

In [None]:
from pathlib import Path
from functools import partial

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.auto import tqdm

import torch
from torch import nn
from torch import optim

from torchvision import transforms

from torch_cluster import fps

import pytorch_lightning as pl

from sklearn.metrics import confusion_matrix

import open3d as o3d
import igl

In [None]:
# Record the installed Lightning version: the Trainer/hook APIs used below differ between releases.
pl.__version__

In [None]:
import os

# W&B reads its API key from WANDB_API_KEY. On Kaggle, fetch it from the notebook's
# secret store so it is never hard-coded. Outside Kaggle `kaggle_secrets` is not
# installed; fall back to an already-exported WANDB_API_KEY (or `wandb login`).
try:
    from kaggle_secrets import UserSecretsClient
except ImportError:
    UserSecretsClient = None

if UserSecretsClient is not None:
    os.environ["WANDB_API_KEY"] = UserSecretsClient().get_secret("wandb")

In [None]:
# TODO: check if there's anything else to seed
seed = 42
rng = np.random.default_rng(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Data utilities

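Sampler:
- PointSampler: sample a point cloud uniformly from a mesh surface

Converters:
- PointCloudToTensor: point cloud to torch.Tensor

Transforms:
- PointCloudTranslate: translate point cloud by some value
- PointCloudRotate: can be used to rotate around any given axis
- PointCloudRotationalPerturbation: small random rotations around the x, y and z axes
- PointCloudJitter: clipped Gaussian $\mathcal{N}(0, \sigma^2)$ noise
- PointCloudDropout: replace a random subset of points with copies of one remaining point
- PointCloudNormalize: center the cloud at the origin and scale it into the unit sphere
- **(NOT USED)** PointCloudScale: scale points by some value.

Other:
- PointCloudShuffle: randomly permute the order of the points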

In [None]:
def get_rotation_matrix(angle, axis):
    """Return the (3, 3) float32 matrix rotating by `angle` radians about `axis`.

    Implements Rodrigues' formula
        R = cos(a) I + sin(a) [u]_x + (1 - cos(a)) u u^T,
    where u is `axis` normalised to unit length and [u]_x its cross-product matrix.
    The arithmetic is done in float64 NumPy; only the result is cast to float32.
    """
    unit = axis / np.linalg.norm(axis)
    ux, uy, uz = unit[0], unit[1], unit[2]

    skew = np.array([
        [0.0, -uz, uy],
        [uz, 0.0, -ux],
        [-uy, ux, 0.0],
    ])

    c = np.cos(angle)
    s = np.sin(angle)
    rodrigues = c * np.eye(3) + s * skew + (1.0 - c) * np.outer(unit, unit)

    return torch.from_numpy(rodrigues).float()

In [None]:
# Sampler
class PointSampler:
    """Turn a triangle mesh into a point cloud of `num_points` points via open3d's
    `sample_points_uniformly`."""

    def __init__(self, num_points: int):
        self.num_points = num_points

    def __call__(self, x: o3d.geometry.TriangleMesh) -> o3d.geometry.PointCloud:
        sampled_cloud = x.sample_points_uniformly(number_of_points=self.num_points)
        return sampled_cloud


# Converters
class PointCloudToTensor:
    """Copy an open3d point cloud's coordinates into a float32 torch tensor."""

    def __call__(self, x: o3d.geometry.PointCloud) -> torch.Tensor:
        xyz = np.asarray(x.points)
        return torch.from_numpy(xyz).float()

# Transforms
# If not stated otherwise, the below transformations work with normals too
class PointCloudScale:
    """Multiply xyz (columns 0:3) in place by one factor drawn from U(lo, hi).

    Columns beyond the first three (e.g. normals) are left unchanged. The factor is
    drawn from the global NumPy Generator `rng`.
    """

    def __init__(self, lo=0.8, hi=1.25):
        self.lo = lo
        self.hi = hi

    def __call__(self, points):
        factor = rng.uniform(self.lo, self.hi)
        points[:, 0:3] *= factor
        return points

    
class PointCloudTranslate:
    """Randomly translate xyz (columns 0:3) of a point tensor in place.

    An independent offset is drawn from U(-translate_range, translate_range) for each of
    x, y and z, so the cloud can move in any direction. A single shared scalar would only
    ever shift it along the (1, 1, 1) diagonal. Further columns (normals) are untouched,
    since normals are translation invariant. Offsets come from the global NumPy
    Generator `rng`.
    """

    def __init__(self, translate_range=0.1):
        self.translate_range = translate_range

    def __call__(self, points):
        offsets = rng.uniform(-self.translate_range, self.translate_range, size=3)
        # Convert explicitly so the in-place add is a torch op with the cloud's dtype/device.
        points[:, 0:3] += torch.from_numpy(offsets).to(dtype=points.dtype, device=points.device)
        return points


class PointCloudRotate:
    """Rotate a cloud by a random angle in [0, 2*pi) about a fixed `axis` (default: z).

    For tensors with more than 3 columns, columns 3: (normals) are rotated with the same
    matrix, in place; xyz-only tensors get a new rotated tensor. The angle is drawn from
    the global NumPy Generator `rng`.
    """

    def __init__(self, axis=np.array([0.0, 0.0, 1.0])):
        self.axis = axis

    def __call__(self, points):
        rot_t = get_rotation_matrix(rng.uniform() * 2 * np.pi, self.axis).t()

        if points.shape[1] <= 3:
            return torch.matmul(points, rot_t)

        xyz_rotated = torch.matmul(points[:, 0:3], rot_t)
        normals_rotated = torch.matmul(points[:, 3:], rot_t)
        points[:, 0:3] = xyz_rotated
        points[:, 3:] = normals_rotated
        return points


class PointCloudRotationalPerturbation:
    """Apply a small random rotation built from per-axis perturbations.

    Angles for x, y and z are drawn from N(0, angle_sigma^2) and clipped to
    [-angle_clip, angle_clip]. The combined rotation is Rz @ Ry @ Rx. For tensors with
    more than 3 columns, columns 3: (normals) are rotated with the same matrix, in place.
    """

    def __init__(self, angle_sigma=0.06, angle_clip=0.18):
        self.angle_sigma, self.angle_clip = angle_sigma, angle_clip

    def _get_angles(self):
        # `rng` is a NumPy Generator (np.random.default_rng), which has no legacy `randn`.
        # `standard_normal` is its equivalent.
        angles = np.clip(
            self.angle_sigma * rng.standard_normal(3), -self.angle_clip, self.angle_clip
        )

        return angles

    def __call__(self, points):
        angles = self._get_angles()
        Rx = get_rotation_matrix(angles[0], np.array([1.0, 0.0, 0.0]))
        Ry = get_rotation_matrix(angles[1], np.array([0.0, 1.0, 0.0]))
        Rz = get_rotation_matrix(angles[2], np.array([0.0, 0.0, 1.0]))

        # Combined rotation matrix
        rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx)

        has_normals = points.shape[1] > 3
        if not has_normals:
            return torch.matmul(points, rotation_matrix.t())
        else:
            pc_xyz = points[:, 0:3]
            pc_normals = points[:, 3:]
            points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
            points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())

            return points
        

class PointCloudJitter:
    """Add Gaussian noise N(0, std^2), clamped to [-clip, clip], to xyz in place.

    Only columns 0:3 are perturbed; further columns (normals) are untouched. The noise is
    drawn from torch's global RNG. Returns the same (modified) tensor.
    """

    def __init__(self, std=0.01, clip=0.05):
        self.std = std
        self.clip = clip

    def __call__(self, points):
        n_points = points.size(0)
        noise = torch.normal(mean=0.0, std=self.std, size=(n_points, 3))
        noise.clamp_(-self.clip, self.clip)
        points[:, 0:3] += noise
        return points


class PointCloudDropout:
    """Randomly "drop" points by overwriting them with a copy of one surviving point.

    The tensor keeps its shape, so batches still have a fixed point count. A dropout
    ratio is drawn from U(0, max_dropout_ratio); each point is then dropped
    independently with that probability. Randomness comes from the global NumPy
    Generator `rng` (kept, rather than torch's RNG, so draws follow the notebook seed);
    the masking itself is done in torch.
    """

    def __init__(self, max_dropout_ratio=0.875):
        assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
        self.max_dropout_ratio = max_dropout_ratio

    def __call__(self, points):
        num_points = points.shape[0]
        dropout_ratio = rng.random() * self.max_dropout_ratio  # in [0, max_dropout_ratio)
        drop_mask = torch.from_numpy(rng.random(num_points) <= dropout_ratio).to(points.device)
        keep_idx = torch.nonzero(~drop_mask, as_tuple=True)[0]

        # Nothing dropped -> nothing to do. Everything dropped -> there is no survivor to
        # copy from, so leave the cloud unchanged instead of crashing on an empty choice.
        if bool(drop_mask.any()) and keep_idx.numel() > 0:
            fill_idx = keep_idx[int(rng.integers(keep_idx.numel()))]
            points[drop_mask] = points[fill_idx].clone()  # set to the random point

        return points


class PointCloudNormalize:
    """Centre a cloud at the origin and scale it to fit inside the unit sphere.

    Only xyz (columns 0:3) are centred and scaled, so normals stored in columns 3: (if
    any) are preserved, consistent with the other transforms. The input tensor is not
    modified. If every point coincides (zero radius), the cloud is only centred instead
    of being divided by zero.
    """

    def __call__(self, x):
        xyz = x[:, 0:3] - torch.mean(x[:, 0:3], dim=0)
        radius = torch.max(xyz.norm(dim=1))
        if radius > 0:
            xyz = xyz / radius

        if x.shape[1] == 3:
            return xyz

        out = x.clone()
        out[:, 0:3] = xyz
        return out

    
# Other
class PointCloudShuffle:
    """Return the rows (points) in a uniformly random order, using torch's global RNG."""

    def __call__(self, points):
        order = torch.randperm(points.shape[0])
        shuffled = points[order, :]
        return shuffled

In [None]:
# Definitions are here, but actual noise is generated later.
# Point budget per sample: NUM_OBJECT points are sampled from the mesh surface, and the
# noise function appends NUM_NOISE noisy copies.
NUM_POINTS = 2048
NOISE_RATIO = 0.3
NUM_OBJECT = int(np.floor(NUM_POINTS * (1 - NOISE_RATIO)))
# Take the remainder so the two always add up to NUM_POINTS. Separate floor/ceil of two
# floating-point products can be off by one for some ratios.
NUM_NOISE = NUM_POINTS - NUM_OBJECT
print(NUM_OBJECT, NUM_NOISE)
assert 0 <= NUM_NOISE <= NUM_OBJECT, "noise copies are drawn from the object points"

# - Uncomment the transformations that you want to use
# - Be careful when using PointCloudDropout, e.g. make
#   sure to select a reasonable dropout ratio
# - Shuffle does not really do anything because PointTransformer
#   is invariant to reordering of the points
# - PointCloudTranslate is effectively undone by the final PointCloudNormalize,
#   which re-centres the cloud at the origin
train_transforms = transforms.Compose([
    PointSampler(NUM_OBJECT),
    PointCloudToTensor(),
#     PointCloudTranslate(),
#     PointCloudRotate(),
#     PointCloudRotationalPerturbation(),
#     PointCloudJitter(),
#     PointCloudDropout(), # dropout
#     PointCloudShuffle(), # shuffling points
    PointCloudNormalize(),
])

valid_transforms = transforms.Compose([
    PointSampler(NUM_OBJECT),
    PointCloudToTensor(),
    PointCloudNormalize(),
])

plotting_transforms = transforms.Compose([
    PointSampler(NUM_OBJECT),
    PointCloudToTensor(),
    PointCloudNormalize(),
])

# Load the data

In [None]:
def add_jitter_noise(pc, std=0.01, clip=0.001, num_noise=None):
    """Append jittered copies of randomly chosen points and label them as noise.

    Args:
        pc: (N, C) point tensor with xyz in columns 0:3; extra columns are copied unchanged.
        std: standard deviation of the Gaussian offset added to each copy's xyz.
        clip: copies whose offset has L2 norm <= clip lie too close to the surface to be
            told apart, so they are labelled 0 (object) instead of 1 (noise). Despite the
            name, the offset itself is not clipped.
        num_noise: number of copies to add; defaults to the global NUM_NOISE. Each source
            point is used at most once, so at most N copies are produced.

    Returns:
        (points, labels): the shuffled cloud with the copies appended, re-normalised with
        PointCloudNormalize, and its int32 per-point labels (0 = object, 1 = noise).
    """
    if num_noise is None:
        num_noise = NUM_NOISE

    source_idx = torch.randperm(len(pc))[:num_noise]
    noise = pc[source_idx].clone()

    offsets = torch.normal(mean=0.0, std=std, size=(noise.shape[0], 3))
    noise[..., :3] += offsets
    noise_labels = (offsets.norm(dim=1) > clip).to(torch.int32)

    labels = torch.cat([torch.zeros(len(pc), dtype=torch.int32), noise_labels], dim=0)
    points = torch.cat([pc, noise], dim=0)  # concat noise with pc

    # Shuffle so object and noise points are interleaved.
    order = torch.randperm(points.shape[0])
    points = points[order, :]
    labels = labels[order]

    # Renormalize: the copies can lie outside the original unit sphere.
    points = PointCloudNormalize()(points)

    return points, labels

In [None]:
# Test

# tns = torch.randn((NUM_OBJECT, 3))
# pc, y = partial(add_jitter_noise, std=0.15, clip=0.07)(tns)
# pc.shape, y.shape, pc.mean(dim=0), pc.norm(dim=-1).max(dim=0)[0]

In [None]:
from typing import Tuple, List


class ModelNet40Dataset(torch.utils.data.Dataset):
    """ModelNet40 meshes turned into noisy point clouds for per-point noise segmentation.

    All meshes of the requested split are read into memory once, as open3d meshes. Each
    ``__getitem__`` samples a fresh point cloud through `transforms` and then adds noise
    with `noise_function`.

    Expected layout: ``root/<class>/<split>/*.off``.

    Args:
        root: dataset root directory.
        split: sub-folder name inside each class folder, e.g. "train" or "test".
        transforms: mesh -> point-tensor pipeline.
        noise_function: callable ``pc -> (noisy_pc, per_point_labels)``.
        subset_ratio: if given, keep only the first ``int(n_classes * subset_ratio)``
            classes in sorted order (handy for quick prototyping).
    """

    ext = ".off"

    def __init__(
        self,
        root: str,
        split: str,
        transforms: transforms.Compose,
        noise_function,
        subset_ratio=None,
    ) -> None:

        self.root = Path(root)
        self.split = split
        self.transforms = transforms
        self.noise_function = noise_function

        # Sort *before* subsetting: Path.iterdir() order is filesystem dependent, so
        # slicing the raw listing would pick different classes on different machines.
        # Only directories are classes; stray files (READMEs, .DS_Store, ...) are skipped.
        dirs = sorted(item.name for item in self.root.iterdir() if item.is_dir())
        if subset_ratio:
            dirs = dirs[:int(len(dirs) * subset_ratio)]
        self.classes = dirs
        self.idx2class = dict(enumerate(self.classes))
        self.class2idx = {c: i for i, c in self.idx2class.items()}

        # List files and their labels
        self.meshes, self.labels = self.get_meshes_and_labels()

    def get_meshes_and_labels(self) -> Tuple[List["o3d.geometry.TriangleMesh"], List[int]]:
        """Read every mesh of ``self.split`` and return ``(meshes, class_indices)``."""
        meshes, labels = [], []

        for label, class_name in tqdm(list(enumerate(self.classes))):
            class_dir = self.root / class_name / self.split
            # Sorted for a deterministic sample order across filesystems.
            mesh_files = sorted(class_dir.glob(f"*{self.ext}"))
            for mesh_file in tqdm(mesh_files, desc=class_name, leave=False):
                vertices, triangles = igl.read_triangle_mesh(str(mesh_file))
                meshes.append(o3d.geometry.TriangleMesh(
                    o3d.utility.Vector3dVector(vertices),
                    o3d.utility.Vector3iVector(triangles.astype(np.int32)),
                ))
                labels.append(label)

        return meshes, labels

    def __len__(self) -> int:
        return len(self.meshes)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return ``(points, per_point_labels, class_index)`` for mesh `idx`."""
        mesh, category = self.meshes[idx], self.labels[idx]
        x = self.transforms(mesh)
        x, y = self.noise_function(x)

        return x, y, category

In [None]:
import multiprocessing as mp

def _seed_worker_rng(worker_id):
    """DataLoader worker_init_fn: give every worker process its own NumPy `rng`.

    The point-cloud transforms draw from the module-level Generator `rng`. Forked
    workers each start with an identical copy of it, and PyTorch reseeds only its own
    RNG per worker. Inside a worker, `torch.initial_seed()` is already distinct
    (base_seed + worker_id), so the NumPy seed is derived from it.
    """
    global rng
    rng = np.random.default_rng(torch.initial_seed())


class ModelNet40DataModule(pl.LightningDataModule):
    """LightningDataModule serving noisy ModelNet40 point clouds.

    Args:
        data_dir: root folder with one sub-folder per ModelNet40 class.
        batch_size: batch size of both dataloaders.
        num_workers: DataLoader worker processes (defaults to all CPUs).
        noise_function: callable ``pc -> (noisy_pc, labels)`` applied to every sample.
    """

    def __init__(
        self, data_dir="/kaggle/input/modelnet40/ModelNet40", batch_size=16,
        num_workers=mp.cpu_count(),
        noise_function=partial(add_jitter_noise, std=0.15, clip=0.07)
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.noise_function = noise_function
        # Filled by setup(); kept so repeated setup() calls can reuse the loaded meshes.
        self.train_dset = None
        self.val_dset = None
        self._subset_ratio = None

    def setup(self, num_points=1024, subset_ratio=None, stage=None):
        """Build the train/test datasets (only for stage "fit" or None).

        Args:
            num_points: unused; the point count is fixed by the global transforms
                (PointSampler(NUM_OBJECT)) plus the noise function. Kept for compatibility.
            subset_ratio: fraction of classes to load (see ModelNet40Dataset). None
                reuses the ratio of an earlier call, or loads all classes on the first call.
            stage: Lightning stage.

        Lightning itself calls ``setup(stage="fit")`` inside ``trainer.fit`` and the LR
        finder. When the datasets were already built with the same subset, that call is a
        no-op. This avoids re-reading every mesh from disk and keeps a subset chosen in a
        manual call.
        """
        if stage not in ("fit", None):
            return
        if subset_ratio is None:
            subset_ratio = self._subset_ratio
        if self.train_dset is not None and subset_ratio == self._subset_ratio:
            return
        self._subset_ratio = subset_ratio

        self.train_dset = ModelNet40Dataset(
            root=self.data_dir,
            split="train",
            transforms=train_transforms,
            noise_function=self.noise_function,
            subset_ratio=subset_ratio,
        )
        self.val_dset = ModelNet40Dataset(
            root=self.data_dir,
            split="test",
            transforms=valid_transforms,
            noise_function=self.noise_function,
            subset_ratio=subset_ratio,
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dset, batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=True,
            worker_init_fn=_seed_worker_rng,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
            pin_memory=True,
            worker_init_fn=_seed_worker_rng,
        )

In [None]:
# Batch size (also logged as a hyperparameter by SegmentationTask.on_train_start).
BS = 32
# Stronger noise than the datamodule default: std=0.25 offsets, and copies displaced by
# <= 0.05 are labelled as object rather than noise.
dm = ModelNet40DataModule(batch_size=BS, noise_function=partial(add_jitter_noise, std=0.25, clip=0.05))
# dm.setup(num_points=NUM_POINTS, subset_ratio=0.1) # use this for prototyping as it loads faster
# Reads every mesh of both splits into memory (slow); `num_points` is not used by setup.
dm.setup(num_points=NUM_POINTS)

In [None]:
# Number of meshes in the train and test splits.
len(dm.train_dset), len(dm.val_dset)

In [None]:
# dm.noise_function = dm.train_dset.noise_function = dm.val_dset.noise_function = partial(add_jitter_noise, std=0.125, clip=0.05)

## Plot a few samples to sanity-check the point clouds and noise labels

In [None]:
# Sorted class names; position i is the category index returned by the dataset.
dm.train_dset.classes

In [None]:
import plotly.express as px


DISTINCT_LABELS = ["object", "noise"]

def pc_show(item, is_noise, category):
    """Interactive plotly 3-D scatter of a point cloud, coloured by label.

    Args:
        item: (N, >=3) CPU tensor or array of points; only the first three columns are plotted.
        is_noise: (N,) 0/1 labels, mapped through DISTINCT_LABELS (0 = object, 1 = noise).
        category: class name, printed above the figure.
    """
    # Convert to NumPy so pandas receives plain float columns.
    coords = np.asarray(item)
    labels = [DISTINCT_LABELS[point] for point in is_noise.tolist()]

    df = pd.DataFrame(dict(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        is_noise=labels,
        size=[15] * len(labels),
    ))

    print(category)

    color_discrete_map = dict(zip(DISTINCT_LABELS, px.colors.sequential.Turbo))

    fig = px.scatter_3d(
        df, x="x", y="y", z="z", color="is_noise", size="size",
        # Must be > 0: opacity=0 renders every marker fully transparent.
        opacity=0.8,
        size_max=15,
        color_discrete_map=color_discrete_map,
        category_orders=dict(is_noise=DISTINCT_LABELS)
    )
    fig.show()


In [None]:
DISTINCT_LABELS = ["object", "noise"]

# Static fallback, defined because the plotly version does not always render.
def pc_show_matplotlib(item, is_noise, category, with_noise=True):
    """Matplotlib 3-D scatter of a point cloud, one marker style per label.

    Args:
        item: (N, >=3) CPU tensor of points; only xyz are plotted.
        is_noise: (N,) CPU tensor of 0/1 labels (0 = object, 1 = noise).
        category: class name, printed and used as the plot title.
        with_noise: also draw points labelled 1; if False only object points are drawn.
    """
    print(category)

    fig = plt.figure(figsize=(30, 30))
    ax = fig.add_subplot(projection="3d")

    coords = item.numpy()
    labels = is_noise.numpy()

    # Circles for object points, triangles for noise; labelled so the legend explains them.
    markers = ["o", "^"] if with_noise else ["o"]
    for label, marker in enumerate(markers):
        selected = np.flatnonzero(labels == label)
        ax.scatter(
            coords[selected, 0], coords[selected, 1], coords[selected, 2],
            marker=marker, label=DISTINCT_LABELS[label],
        )

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_title(category)
    ax.legend()

    plt.show()


In [None]:
# Draw one random training sample; show the full noisy cloud, then only its noise-labelled points.
idx = rng.integers(len(dm.train_dset))
x, y, category = dm.train_dset[idx]
# Shapes, class index and number of points labelled as noise.
display(x.shape, y.shape, category, y.sum())
pc_show_matplotlib(x, y, dm.train_dset.classes[category], True)
pc_show_matplotlib(x[y == 1], y[y == 1], dm.train_dset.classes[category], True)

In [None]:
# Count noise-labelled points lying within `threshold` of at least one object point
# (visually hard to distinguish), then plot the object together with only those points.
threshold = 0.05
noise = x[y == 1]
noise_labels = y[y == 1]
# Distance from every noise point to its nearest object point, compared with the threshold.
indices = (noise[:, None, :] - x[y == 0][None, ...]).norm(dim=-1).min(dim=-1)[0] <= threshold
print(indices.shape, indices.sum())
x1 = torch.cat([x[y == 0], noise[indices]], dim=0)
y1 = torch.cat([y[y == 0], noise_labels[indices]], dim=0)
pc_show(x1, y1, dm.train_dset.classes[category])

In [None]:
# Interactive view of the full noisy sample.
pc_show(x, y, dm.train_dset.classes[category])

In [None]:
# Sanity-check one collated training batch (points, per-point labels, class indices).
x, y, cats = next(iter(dm.train_dataloader()))
x.shape, y.shape, len(cats)

# Model

In [None]:
def get_neighbours(features, idx):
    """Gather, for every batch element, the feature rows selected by `idx`.

    Input:
        features: input points data, [B, N, C]
        idx: neighbour index data, [B, N, K] (indices into the N axis of `features`)
    Return:
        new_points: indexed points data, [B, N, K, C]
    """
    batch_size = idx.shape[0]
    num_channels = features.shape[-1]

    # Flatten the neighbour axes and broadcast each index over the channel axis so one
    # gather along dim 1 picks whole feature rows.
    flat_idx = idx.reshape(batch_size, -1, 1).expand(-1, -1, num_channels)
    gathered = torch.gather(features, 1, flat_idx)

    return gathered.reshape(*idx.shape, num_channels)

In [None]:
class TransitionDownBlock(nn.Module):
    """Point Transformer "transition down": subsample points and pool their neighbourhoods.

    Farthest point sampling (torch_cluster.fps) keeps roughly ``N * sampling_ratio``
    points per cloud. Input features pass through a pointwise Linear+BN+ReLU, and each
    sampled point max-pools the transformed features of its `num_neighbours` nearest
    input points.
    """

    def __init__(self, in_dims, out_dims, num_neighbours=16, sampling_ratio=0.25):
        super().__init__()

        self.num_neighbours = num_neighbours
        self.sampling_ratio = sampling_ratio

        self.mlp = nn.Sequential(
            # kernel_size=1 Conv1d over [B, C, N] is a per-point Linear layer.
            nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_dims),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, N, in_dims) point features.
            pos: (B, N, 3) point coordinates.
        Returns:
            (features, pos_sampled): (B, M, out_dims) pooled features and the (B, M, 3)
            coordinates of the M sampled points.
        """
        batch_size, num_input, coord_dims = pos.shape

        # fps works on a flat [B * N, 3] point list plus a batch-id vector.
        flat_pos = pos.reshape(-1, coord_dims)
        # Create the ids on pos's device instead of hard-coding .cuda().
        batch = torch.arange(batch_size, device=pos.device).repeat_interleave(num_input)

        # Get indices of sampled points in flat_pos
        indices = fps(flat_pos, batch, ratio=self.sampling_ratio, random_start=True)

        # Assumes fps keeps the same number of points per cloud (all clouds have N points)
        # and returns them grouped by batch, so the result reshapes to [B, M, 3].
        pos_sampled = flat_pos[indices].reshape(batch_size, -1, coord_dims)

        # KNN: nearest input points of every sampled point
        rel_dist = (pos_sampled[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
        num_points = x.shape[1]
        _, neighbour_indices = rel_dist.topk(min(num_points, self.num_neighbours), largest=False)

        # MLP: transform input features
        x = self.mlp(x.transpose(1, 2)).transpose(1, 2)  # [B, N, out_dims]

        # Gather neighbour features and max-pool over the neighbours
        x_sampled = get_neighbours(x, neighbour_indices)  # [B, M, k, out_dims]
        x_sampled = torch.max(x_sampled, dim=2)[0]  # [B, M, out_dims]

        return x_sampled, pos_sampled

In [None]:
# Test

# features = torch.randn((16, 16, 256))
# pos = torch.randn((16, 16, 3))
# lateral_pos = torch.randn((16, 64, 3))

# rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
# weights, neighbour_indices = rel_dist.topk(3, largest=False)

# res = get_neighbours(features, neighbour_indices)
# res.shape

In [None]:
# Test

# features = torch.arange(2*8*2).reshape(2, 8, 2)
# features
# pos = torch.randn((2, 8, 3))
# lateral_pos = torch.randn((2, 32, 3))

# print(features[0])
# rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
# weights, neighbour_indices = rel_dist.topk(3, largest=False)
# print(neighbour_indices[0, :3])

# res = get_neighbours(features, neighbour_indices)
# res.shape, res[0, :3]

In [None]:
class TransitionUpBlock(nn.Module):
    """Point Transformer "transition up": upsample coarse features onto a denser point set.

    Coarse features are projected to `out_dims` and interpolated onto `lateral_pos` by
    inverse-distance weighting of their (up to) 3 nearest coarse points. The result is
    added to the projected skip-connection (lateral) features.
    """

    def __init__(self, in_dims, out_dims):
        super().__init__()

        self.up_mlp = nn.Sequential(
            nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_dims),
            nn.ReLU()
        )
        self.lateral_mlp = nn.Sequential(
            nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_dims),
            nn.ReLU()
        )

    def forward(self, features, pos, lateral_features, lateral_pos):
        """
        Args:
            features: (B, N, in_dims) coarse features.
            pos: (B, N, 3) coarse coordinates.
            lateral_features: (B, M, out_dims) skip-connection features.
            lateral_pos: (B, M, 3) dense coordinates.
        Returns:
            (out, lateral_pos): (B, M, out_dims) fused features and the unchanged lateral_pos.
        N is smaller than M because this module upsamples features.
        """
        features = self.up_mlp(features.transpose(1, 2)).transpose(1, 2)

        # Nearest coarse points of every dense point. Use fewer than 3 when the coarse
        # set is that small; topk(3) would otherwise raise.
        num_interp = min(3, pos.shape[1])
        rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
        dists, neighbour_indices = rel_dist.topk(num_interp, largest=False)

        # Normalised inverse-distance weights; epsilon guards coincident points.
        weights = 1.0 / (dists + 1e-8)
        weights = weights / torch.sum(weights, dim=2, keepdim=True)  # [B, M, k]

        # Weighted sum of the neighbours' features
        interpolated_features = get_neighbours(features, neighbour_indices)  # [B, M, k, C]
        interpolated_features = torch.sum(interpolated_features * weights[..., None], dim=-2)

        lateral_features = self.lateral_mlp(lateral_features.transpose(1, 2)).transpose(1, 2)

        # Skip connection: add the interpolated coarse features to the lateral features
        out = interpolated_features + lateral_features

        return out, lateral_pos

In [None]:
# Test

# features = torch.randn((16, 16, 256))
# pos = torch.randn((16, 16, 3))
# lateral_features = torch.randn((16, 64, 128))
# lateral_pos = torch.randn((16, 64, 3))

# up = TransitionUpBlock(in_channels=256, out_channels=128)

# with torch.no_grad():
#     feat, pos = up(features, pos, lateral_features, lateral_pos)
#     print(feat.shape)
#     print(pos.shape)

In [None]:
class PointTransformerLayer(nn.Module):
    """Vector self-attention over each point's k nearest neighbours (Point Transformer).

    For point i and neighbour j:
        w_ij = softmax_j(gamma(phi(x_i) - psi(x_j) + delta_ij))
        y_i  = sum_j w_ij * (alpha(x_j) + delta_ij),  delta_ij = theta(p_i - p_j)
    The softmax is taken per channel over the neighbours.

    Args:
        dim: feature width (input and output).
        num_neighbors: k of the kNN.
        pos_mlp_hidden_dim: hidden width of the position-encoding MLP theta.
        attn_mlp_hidden_mult: hidden-width multiplier of the attention MLP gamma.
        dropout: dropout rate applied to the position encoding delta.
    """

    def __init__(
        self,
        dim,
        num_neighbors,
        pos_mlp_hidden_dim=64,
        attn_mlp_hidden_mult=4,  # This comes from initial Transformer paper
        dropout=0,
    ):
        super().__init__()
        self.num_neighbors = num_neighbors

        # phi, psi, alpha
        self.to_queries = nn.Linear(dim, dim, bias=False)
        self.to_keys = nn.Linear(dim, dim, bias=False)
        self.to_values = nn.Linear(dim, dim, bias=False)

        # theta: relative position encoding
        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, dim)
        )

        # gamma: attention-weight MLP
        attn_hidden = dim * attn_mlp_hidden_mult
        self.to_attn_weights = nn.Sequential(
            nn.Linear(dim, attn_hidden),
            nn.ReLU(),
            nn.Linear(attn_hidden, dim),
        )

        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, dim) point features.
            pos: (B, N, 3) point coordinates.
        Returns:
            (B, N, dim) aggregated features.
        """
        k = min(self.num_neighbors, x.shape[1])

        # kNN in coordinate space.
        offsets = pos[:, :, None, :] - pos[:, None, :, :]
        knn_idx = offsets.norm(dim=-1).topk(k, largest=False).indices  # [B, N, k]

        queries = self.to_queries(x)                         # [B, N, dim]
        keys = get_neighbours(self.to_keys(x), knn_idx)      # [B, N, k, dim]
        values = get_neighbours(self.to_values(x), knn_idx)  # [B, N, k, dim]

        # delta_ij = theta(p_i - p_j)
        delta = self.pos_mlp(pos[:, :, None, :] - get_neighbours(pos, knn_idx))
        delta = self.drop(delta)

        # Subtraction relation plus position encoding, normalised over the neighbours.
        attn = self.to_attn_weights(queries[:, :, None, :] - keys + delta).softmax(dim=-2)

        return (attn * (values + delta)).sum(dim=-2)

In [None]:
class PointTransformerBlock(nn.Module):
    """Residual Point Transformer block: Linear -> PointTransformerLayer -> Linear -> Dropout, plus skip.

    Args:
        dim: input/output feature width.
        hidden_dim: width inside the attention layer.
        pos_mlp_hidden_dim, attn_mlp_hidden_mult: forwarded to PointTransformerLayer.
        dropouts: None or [attention-layer dropout, output dropout].
        num_neighbours: k of the attention layer's kNN.
    """

    def __init__(self, dim, hidden_dim, pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, dropouts=None, num_neighbours=None):
        super().__init__()
        attn_dropout = dropouts[0] if dropouts else 0
        out_dropout = dropouts[1] if dropouts else 0

        self.fc_in = nn.Linear(dim, hidden_dim)
        self.point_transformer_layer = PointTransformerLayer(
            dim=hidden_dim,
            num_neighbors=num_neighbours,
            pos_mlp_hidden_dim=pos_mlp_hidden_dim,
            attn_mlp_hidden_mult=attn_mlp_hidden_mult,
            dropout=attn_dropout,
        )
        self.fc_out = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(out_dropout)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        hidden = self.point_transformer_layer(self.fc_in(x), pos)
        hidden = self.drop(self.fc_out(hidden))
        return hidden + x  # residual connection

In [1]:
from itertools import tee


def pairwise(iterable):
    """Return consecutive overlapping pairs as a list: s -> [(s0, s1), (s1, s2), (s2, s3), ...].

    PointTransformerSegmentator uses this to turn its list of layer widths into
    (in_dims, out_dims) pairs. It must therefore be live code, not commented out, or
    building the model raises NameError on a fresh kernel.
    """
    a, b = tee(iterable)
    next(b, None)
    return list(zip(a, b))

In [None]:
class PointTransformerSegmentator(nn.Module):
    """U-Net-style Point Transformer producing one output vector per input point.

    ``layer_config = [in_dims, *inner_dims, out_dims]``, e.g. ``[3, 32, 64, 128, 256, 512, 1]``:

    - encoder: an MLP lifts the inputs to ``inner_dims[0]``, then one TransitionDownBlock
      per consecutive pair of ``inner_dims``. Every stage is followed by a
      PointTransformerBlock of that stage's width.
    - decoder: a Linear at the bottleneck, then one TransitionUpBlock per consecutive pair
      of the reversed ``inner_dims``. Each fuses the encoder output at the same resolution
      (skip connection) and is again followed by a PointTransformerBlock.
    - head: dropout, then a 2-layer MLP mapping ``inner_dims[0] -> out_dims``.

    Args:
        layer_config: channel widths as described above.
        num_neighbours: k of the kNN used by TransitionDownBlock and the attention layers.
        sampling_ratio: fraction of points kept by each TransitionDownBlock.
        dropouts: None, or three rates ``[attention position-encoding dropout,
            transformer-block output dropout, head dropout]``.

    Requires ``pairwise`` (consecutive pairs of a sequence, as a list) to be defined beforehand.
    """
    def __init__(self, layer_config, num_neighbours=16, sampling_ratio=0.25, dropouts=None):
        super().__init__()
        assert dropouts is None or len(dropouts) == 3
        
        in_dims, *inner_dims, out_dims = layer_config
        # (w0, w1), (w1, w2), ... for the down path; the same on reversed widths for the up path.
        encoder_dims = pairwise(inner_dims)
        decoder_dims = pairwise(inner_dims[::-1])
        
        # NOTE: `in_dims`/`out_dims` bound by the comprehensions below are local to them
        # (Python 3), so the outer `out_dims` used by fc_out is still layer_config[-1].
        self.encoder_transitions = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_dims, inner_dims[0]),
                nn.ReLU(),
                nn.Linear(inner_dims[0], inner_dims[0]),
            ),
        ] + [
            TransitionDownBlock(
                in_dims=in_dims,
                out_dims=out_dims,
                num_neighbours=num_neighbours,
                sampling_ratio=sampling_ratio,
            )
            for in_dims, out_dims in encoder_dims
        ])
        # One transformer per encoder stage; the first two dropout rates go to the blocks.
        self.encoder_transformers = nn.ModuleList([
            PointTransformerBlock(
                dim=dims,
                hidden_dim=dims,
                pos_mlp_hidden_dim=dims,
                attn_mlp_hidden_mult=4,
                num_neighbours=num_neighbours,
                dropouts=dropouts[:-1] if dropouts else None,
            )
            for dims in inner_dims
        ])
        
        # Bottleneck Linear, then one upsampling block per reversed width pair.
        self.decoder_transitions = nn.ModuleList([
            nn.Linear(inner_dims[-1], inner_dims[-1]),
        ] + [
            TransitionUpBlock(
                in_dims=in_dims,
                out_dims=out_dims,
            )
            for in_dims, out_dims in decoder_dims
        ])
        self.decoder_transformers = nn.ModuleList([
            PointTransformerBlock(
                dim=dims,
                hidden_dim=dims,
                pos_mlp_hidden_dim=dims,
                attn_mlp_hidden_mult=4,
                num_neighbours=num_neighbours,
                dropouts=dropouts[:-1] if dropouts else None,
            )
            for dims in inner_dims[::-1]
        ])
        
        # The last dropout rate is applied once, before the head.
        self.drop = nn.Dropout(dropouts[-1] if dropouts else 0)
        
        self.fc_out = nn.Sequential(
            nn.Linear(inner_dims[0], inner_dims[0]),
            nn.ReLU(),
            nn.Linear(inner_dims[0], out_dims),
        )
    
    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pos: (B, N, 3) point coordinates. They are also the initial features, so
                in_dims must be 3; the position MLPs use nn.Linear(3, ...).
        Returns:
            (B, N, out_dims) per-point outputs (raw scores, no activation applied).
        """
        lateral_features = []
        lateral_pos = []

        for i, (trans_down, transformer) in enumerate(zip(self.encoder_transitions, self.encoder_transformers)):
            if i == 0:
                # Stage 0 keeps all points: the MLP only lifts the coordinates to features.
                features = trans_down(pos)
                features = transformer(features, pos)
            else:
                features, pos = trans_down(features, pos)
                features = transformer(features, pos)

            # Keep every stage except the bottleneck for the decoder's skip connections.
            if i < len(self.encoder_transitions) - 1:
                lateral_features.append(features)
                lateral_pos.append(pos)

        for i, (trans_up, transformer) in enumerate(zip(self.decoder_transitions, self.decoder_transformers)):
            if i == 0:
                # Bottleneck: Linear + transformer at the coarsest resolution.
                features = trans_up(features)
                features = transformer(features, pos)
            else:
                # lateral_*[-i] is the encoder stage at the resolution this step upsamples to.
                features, pos = trans_up(features, pos, lateral_features[-i], lateral_pos[-i])
                features = transformer(features, pos)

        out = self.drop(features)

        out = self.fc_out(out)
        
        return out

# Training

In [None]:
from dataclasses import dataclass

@dataclass
class CMCounts:
    """Confusion-matrix counts for binary 0/1 labels (positive class = 1), with metrics.

    Instances add up (``a + b`` and ``sum([...])``), so per-batch counts can be aggregated
    over an epoch. A metric whose denominator is zero (e.g. no positive targets and no
    positive predictions) evaluates to 0.0 instead of raising ZeroDivisionError. This
    matches scikit-learn's ``zero_division=0`` convention.
    """
    tp: int = 0
    fp: int = 0
    fn: int = 0
    tn: int = 0

    @classmethod
    def from_tensors(cls, target, preds):
        """Count TP/FP/FN/TN from two same-shaped integer tensors of 0/1 labels.

        Note the order: ground truth first, then predictions.
        """
        tp, fp, fn, tn = torch.dstack((
            (preds & target) > 0,   # predicted 1, actually 1
            preds > target,         # predicted 1, actually 0
            preds < target,         # predicted 0, actually 1
            (preds | target) == 0,  # predicted 0, actually 0
        )).sum((0, 1))
        return cls(*[x.item() for x in [tp, fp, fn, tn]])

    @staticmethod
    def _safe_div(num, den):
        """num / den, or 0.0 when the denominator is zero."""
        return num / den if den else 0.0

    @property
    def f1(self):
        return self._safe_div(self.tp, self.tp + 0.5 * (self.fp + self.fn))

    @property
    def f2(self):
        return self.f_beta(2)

    def f_beta(self, beta):
        """F-beta score; beta > 1 weights recall (missed noise) more than precision."""
        b2 = beta ** 2
        return self._safe_div((1 + b2) * self.tp, (1 + b2) * self.tp + b2 * self.fn + self.fp)

    @property
    def accuracy(self):
        return self._safe_div(self.tp + self.tn, self.tp + self.fp + self.fn + self.tn)

    def __add__(self, other):
        return CMCounts(
            tp=self.tp + other.tp,
            fp=self.fp + other.fp,
            fn=self.fn + other.fn,
            tn=self.tn + other.tn,
        )

    def __radd__(self, other):
        # Lets built-in sum() start from its default 0.
        return self if other == 0 else self + other

In [None]:
class SegmentationTask(pl.LightningModule):
    """LightningModule for per-point binary segmentation (object = 0, noise = 1).

    The wrapped model must return one logit per point, shaped (B, N, 1) or (B, N). Training
    uses BCEWithLogitsLoss, AdamW(weight_decay=1e-2) and a per-step OneCycleLR schedule.
    Per-epoch F1/F2 of the noise class are aggregated from per-batch confusion counts.

    Args:
        model: network mapping a (B, N, 3) point batch to per-point logits.
        max_lr: peak learning rate of OneCycleLR.
        epochs: number of epochs the schedule spans.
        steps_per_epoch: optimizer steps per epoch (needed by OneCycleLR up front).
        num_classes: currently unused; the loss is binary.
    """

    def __init__(self, model, max_lr, epochs, steps_per_epoch, num_classes=2):
        super().__init__()
        self.save_hyperparameters("max_lr", "epochs")
        self.steps_per_epoch = steps_per_epoch
        self.model = model
        self.loss = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, prefix):
        """Loss and confusion counts for one batch; logs ``<prefix>_loss`` per epoch."""
        x, y, _ = batch
        # Drop only the trailing channel dim. A bare .squeeze() would also drop the batch
        # dim for a single-sample batch (possible in validation, drop_last=False), breaking
        # the shape match with y's (B, N).
        logits = self.model(x).squeeze(-1)
        loss = self.loss(logits, y.float())

        y_preds = (logits >= 0.0).long()  # logit >= 0 <=> sigmoid >= 0.5

        cm_counts = CMCounts.from_tensors(y.flatten(), y_preds.flatten())

        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True)

        return {"loss": loss, "cm_counts": cm_counts}

    # TODO: this is hard-coded (reads the notebook globals BS and NUM_POINTS), should be automated
    def on_train_start(self):
        self.logger.log_hyperparams({
            "bs": BS, "num_points": NUM_POINTS,
            "max_lr": self.hparams.max_lr, "epochs": self.hparams.epochs,
            "optimizer": "AdamW(wd=1e-2)", "scheduler": "OneCycleLR",
        })

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    # Lightning 1.x epoch-end hooks (removed in 2.0): `outs` holds every step's return dict.
    def training_epoch_end(self, outs):
        cm_count_total = sum(map(lambda x: x["cm_counts"], outs))
        self.log("train_f1", cm_count_total.f1)
        self.log("train_f2", cm_count_total.f2)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def validation_epoch_end(self, outs):
        cm_count_total = sum(map(lambda x: x["cm_counts"], outs))
        self.log("valid_f1", cm_count_total.f1)
        self.log("valid_f2", cm_count_total.f2)

    def configure_optimizers(self):
        # The lr passed here is overridden by OneCycleLR's schedule (peak: hparams.max_lr).
        optimizer = optim.AdamW(self.model.parameters(), weight_decay=1e-2)
        lr_scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.max_lr,
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.hparams.epochs
        )
        # Step the schedule every optimizer step, not every epoch.
        lr_dict = {"scheduler": lr_scheduler, "interval": "step"}
        return [optimizer], [lr_dict]

In [None]:
# Widths: 3 input coordinates -> encoder/decoder widths 32..512 -> 1 logit per point.
layer_config = [3, 32, 64, 128, 256, 512, 1]

model = PointTransformerSegmentator(
    layer_config=layer_config,
    num_neighbours=16,
    sampling_ratio=0.25,
    dropouts=[0.0, 0.0, 0.0]
)

EPOCHS = 50
LR = 1e-3  # peak LR of the OneCycleLR schedule

segmentator = SegmentationTask(
    model,
    max_lr=LR,
    epochs=EPOCHS,
    # OneCycleLR needs the number of optimizer steps per epoch up front.
    steps_per_epoch=len(dm.train_dataloader()),
    num_classes=1  # not used by SegmentationTask; the loss is binary BCE
)

In [None]:
from datetime import datetime
from pytz import timezone

# Make sure to change wandb parameters or get rid of it altogether
# NOTE: "project-name", "company-name" and "Continent/City" are placeholders. "Continent/City"
# is not a real tz name, so pytz will presumably raise UnknownTimeZoneError until it is replaced
# (e.g. "Europe/London").
wandb_logger = pl.loggers.WandbLogger(
    project="project-name",
    entity="company-name",
    name=str(datetime.now(tz=timezone("Continent/City")))  # run name = current local timestamp
)
# Select path where you want to log to
# Metrics go both to CSV (read back by LOG_PATH in the Evaluation section) and to wandb.
trainer = pl.Trainer(gpus=1, max_epochs=EPOCHS, logger=[pl.loggers.CSVLogger("path/to/logs"), wandb_logger])

In [None]:
# Learning-rate range test between 1e-6 and 1e-1; the plot shows loss vs. lr and
# suggest=True asks Lightning to mark its suggested value.
# The result is not applied to `segmentator` automatically.
lr_finder = trainer.tuner.lr_find(segmentator, datamodule=dm, min_lr=1e-6, max_lr=1e-1)
fig = lr_finder.plot(suggest=True)
fig.show()

In [None]:
# Use if you want the lr to be selected automatically. configure_optimizers reads
# hparams.max_lr, so set that on the task before calling trainer.fit:
# segmentator.hparams.max_lr = lr_finder.suggestion()
lr_finder.suggestion()

In [None]:
# Train for EPOCHS epochs (max_epochs on the trainer) using the datamodule's loaders.
trainer.fit(segmentator, datamodule=dm)

In [None]:
# Put weight save path here
# Writes a Lightning checkpoint of the model currently attached to `trainer` (`segmentator` after fit).
trainer.save_checkpoint("path/to/weights.ckpt")

In [None]:
import wandb
# Close the wandb run opened by WandbLogger; otherwise later runs in this kernel may log into it.
wandb.finish() # Do not forget to finish wandb session

# Evaluation

In [None]:
def plot_metric(data, metric_label, is_log_scale=False, quiet=False):
    """Plot train and valid curves of one logged metric against epoch.

    Args:
        data: DataFrame of CSV logs with an ``epoch`` column and
            ``train_<metric_label>`` / ``valid_<metric_label>`` columns.
        metric_label: metric suffix, e.g. ``"loss"`` or ``"f1"``.
        is_log_scale: if True, use a log-scaled y axis.
        quiet: if False, also print the last non-null train/valid value.
    """
    train_col = f"train_{metric_label}"
    valid_col = f"valid_{metric_label}"

    if not quiet:
        # Read from `data` (previously the global `logs`) so any frame can be plotted.
        for col in (train_col, valid_col):
            values = data[col].dropna().values
            # A metric with no logged values would otherwise raise IndexError here.
            last = f"{values[-1]:.2f}" if len(values) else "n/a"
            print(f"{col}: {last}")

    sns.lineplot(data=data, x="epoch", y=train_col, color="red", label="train")
    sns.lineplot(data=data, x="epoch", y=valid_col, color="blue", label="valid")
    if is_log_scale:
        plt.yscale("log")
    plt.title(f"{'log ' if is_log_scale else ''}{metric_label}")
    plt.legend()
    plt.show()

In [None]:
# Put path to logs here.
# CSVLogger nests each run as <save_dir>/<name>/version_N/metrics.csv; with the
# logger configured above the <name> part appears to be "default" (it only worked
# for us with "default" appended), so check this matches your Lightning version.
LOG_PATH = Path("path/to/logs/default")
log_files = list(LOG_PATH.glob("**/*.csv"))
if not log_files:
    raise FileNotFoundError(f"No CSV log files found under {LOG_PATH}")
# Most recent run = highest N in the parent directory name "version_N".
log_file = sorted(log_files, key=lambda x: int(x.parent.stem.split("_")[1]))[-1]
print(log_file)

logs = pd.read_csv(log_file)
display(logs)

plot_metric(logs, "loss")
plot_metric(logs, "loss", is_log_scale=True, quiet=True)
plot_metric(logs, "f1")
plot_metric(logs, "f2")

In [None]:
from collections import defaultdict

def evaluate(model, dl):
    """Compute overall and per-category confusion-matrix metrics on a dataloader.

    Moves ``model`` to CUDA in eval mode, runs it over every batch of ``dl``,
    thresholds the raw logits at 0 (no sigmoid is applied), and accumulates
    ``CMCounts`` per object category. It then prints overall F1/F2/accuracy,
    displays a per-category table, and draws bar plots of per-category F1 and
    accuracy.

    Args:
        model: module mapping a batch of points to per-point logits.
        dl: DataLoader yielding ``(x, y, cats)``; ``dl.dataset.classes`` maps
            category ids to names.
    """
    model.cuda()
    model.eval()

    instances = dl.dataset.classes
    cm_count_dict = defaultdict(CMCounts)
    with torch.no_grad():
        for x, y, cats in tqdm(dl, total=len(dl)):
            x, y = x.cuda(), y.cuda()
            # Lay logits out exactly like the labels. A bare .squeeze() also drops
            # the batch dimension when a batch holds a single sample, which breaks
            # the per-category mask indexing below.
            logits = model(x).reshape(y.shape)
            segm_preds = (logits >= 0.0).long()  # logit >= 0 <=> sigmoid >= 0.5

            cats = np.asarray(cats)
            for i, cat in enumerate(instances):
                mask = cats == i
                if not mask.any():
                    # Skip categories absent from this batch. A category never seen
                    # would otherwise get all-zero counts and divide by zero below.
                    continue
                cm_count_dict[cat] += CMCounts.from_tensors(
                    y[mask].flatten(), segm_preds[mask].flatten()
                )

    cm_count_total = sum(cm_count_dict.values())
    print(f"F1: {cm_count_total.f1}")
    print(f"F2: {cm_count_total.f2}")
    print(f"Acc: {cm_count_total.accuracy}")

    def _share(counts, field):
        # Fraction of this category's points that fall in one confusion-matrix cell.
        total = counts.tp + counts.fp + counts.fn + counts.tn
        return getattr(counts, field) / total

    dict_for_df = {
        f"{field}%": {key: _share(value, field) for key, value in cm_count_dict.items()}
        for field in ("tp", "fp", "fn", "tn")
    }
    dict_for_df["f1"] = {key: value.f1 for key, value in cm_count_dict.items()}
    dict_for_df["acc"] = {key: value.accuracy for key, value in cm_count_dict.items()}

    df = pd.DataFrame(dict_for_df)
    display(df)
    df.plot.bar(y="f1", rot=90)
    df.plot.bar(y="acc", rot=90)

In [None]:
# Per-category metrics on the training split (leaves `segmentator` on CUDA in eval mode).
evaluate(segmentator, dm.train_dataloader())

In [None]:
# Per-category metrics on the validation split.
evaluate(segmentator, dm.val_dataloader())

## Plotting

In [None]:
import plotly.express as px

def pc_show_ground_truth_vs_prediction(item, ground_truth, prediction, category):
    """Show two interactive plotly 3D scatters of a point cloud: ground truth, then prediction.

    Args:
        item: per-point coordinates; columns 0, 1, 2 are used as x, y, z.
        ground_truth: per-point label indices into ``DISTINCT_LABELS``.
        prediction: per-point predicted indices (ints or bools) into ``DISTINCT_LABELS``.
        category: object category name, printed above the plots.
    """
    x, y, z = [item[:, i] for i in range(3)]
    labels = [DISTINCT_LABELS[point.item()] for point in ground_truth]
    labels_pred = [DISTINCT_LABELS[point.item()] for point in prediction]

    df = pd.DataFrame(dict(
        x=x,
        y=y,
        z=z,
        ground_truth=labels,
        predicted=labels_pred,
        size=[15] * len(labels),
    ))

    print(category)

    # Presumably DISTINCT_LABELS has exactly two entries (index 0 / index 1) — TODO confirm.
    color_discrete_map = dict(zip(DISTINCT_LABELS, ["blue", "red"]))

    for color_column in ("ground_truth", "predicted"):
        fig = px.scatter_3d(
            df, x="x", y="y", z="z", color=color_column, size="size",
            # NOTE(review): opacity=0.0 normally means fully transparent markers — confirm intended.
            opacity=0.0,
            size_max=15,
            color_discrete_map=color_discrete_map,
            # category_orders is keyed by column name. The former key "is_noise"
            # matched no column, so the legend order was silently ignored.
            category_orders={color_column: DISTINCT_LABELS},
        )
        fig.show()

In [None]:
def pc_show_error(item, ground_truth, prediction, category):
    """Show an interactive plotly 3D scatter colouring each point by label/prediction agreement.

    Each point falls into one of four groups: agree or disagree, crossed with
    ground truth object (label 0) or noise (any other label).

    Args:
        item: per-point coordinates; columns 0, 1, 2 are used as x, y, z.
        ground_truth: per-point ground-truth labels.
        prediction: per-point predictions, comparable elementwise with ``ground_truth``.
        category: object category name, printed above the plot.
    """
    x, y, z = [item[:, i] for i in range(3)]

    lbls = ["Agree, was object", "Agree, was noise", "Disagree, was object", "Disagree, was noise"]
    are_agreeing = ground_truth == prediction
    labels = [
        f"{'Agree' if agree else 'Disagree'}, was {'object' if truth == 0 else 'noise'}"
        for truth, agree in zip(ground_truth, are_agreeing)
    ]

    df = pd.DataFrame(dict(
        x=x,
        y=y,
        z=z,
        predictions=labels,
        size=[15] * len(labels),
    ))

    print(category)

    # Key colours by the agreement labels actually stored in the "predictions" column.
    # Zipping with DISTINCT_LABELS (the per-point class names) produced keys that never matched.
    color_discrete_map = dict(zip(lbls, ["green", "blue", "red", "yellow"]))

    fig = px.scatter_3d(
        df, x="x", y="y", z="z", color="predictions", size="size",
        # NOTE(review): opacity=0.0 normally means fully transparent markers — confirm intended.
        opacity=0.0,
        size_max=15,
        color_discrete_map=color_discrete_map,
        # category_orders must be keyed by the colour column's name.
        category_orders={"predictions": lbls},
    )
    fig.show()
    

def pc_show_error_matplotlib(item, ground_truth, prediction, category):
    """Draw a static matplotlib 3D scatter of label/prediction agreement.

    Points are split into four groups: agree or disagree, crossed with ground
    truth object (label 0) or noise (any other label). Each group is drawn
    in its own colour with a legend entry.

    Args:
        item: per-point coordinates; columns 0, 1, 2 are used as x, y, z.
        ground_truth: per-point ground-truth labels.
        prediction: per-point predictions, comparable elementwise with ``ground_truth``.
        category: object category name, printed before plotting.
    """
    print(category)

    fig = plt.figure(figsize=(30, 30))
    ax = fig.add_subplot(projection="3d")

    agree_mask = ground_truth == prediction
    point_labels = np.array([
        ("Agree" if agrees else "Disagree") + (", was object" if truth == 0 else ", was noise")
        for truth, agrees in zip(ground_truth, agree_mask)
    ])

    groups = (
        ("Agree, was object", "green"),
        ("Agree, was noise", "blue"),
        ("Disagree, was object", "red"),
        ("Disagree, was noise", "orange"),
    )
    for group_label, colour in groups:
        (idx,) = np.nonzero(point_labels == group_label)
        ax.scatter(item[idx, 0], item[idx, 1], item[idx, 2], color=colour, s=50, label=group_label)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.legend()

    plt.show()

In [None]:
# Category names of the validation set; the position in this list is the category id.
instances = dm.val_dset.classes
instances

In [None]:
ins_identifier = instances.index("airplane")
# Keep every validation sample whose third element (category id) is "airplane".
models = [sample for sample in dm.val_dset if sample[2] == ins_identifier]
len(models)

In [None]:
import random  # NOTE(review): no longer used in this cell; drop it if no later cell needs it

# Draw the sample with the seeded numpy generator `rng` from the setup cell so the
# chosen example is reproducible on Run All (the stdlib `random` module is never seeded).
obj = models[rng.integers(len(models))]

segmentator.eval()
segmentator.cuda()
with torch.no_grad():
    # Threshold raw logits at 0 (sigmoid >= 0.5): bool per point, True = label 1.
    prediction = segmentator(obj[0][None, ...].cuda()).squeeze() >= 0

cat = dm.val_dset.classes[obj[2]]
pc_show_matplotlib(obj[0], obj[1], cat, with_noise=True) # ground truth
pc_show_matplotlib(obj[0], prediction.cpu(), cat, with_noise=True) # prediction
pc_show_error_matplotlib(obj[0], obj[1], prediction.cpu(), cat)

In [None]:
# Pass predictions on CPU, as the other plotting calls do. The helper calls .item()
# once per point, which on a CUDA tensor means a device round-trip for every point.
pc_show_ground_truth_vs_prediction(obj[0], obj[1], prediction.cpu(), cat)
pc_show_error(obj[0], obj[1], prediction.cpu(), cat)