# Point transformer denoising benchmark implementation

## Evaluating Point Transformer on the PointCleanNet outlier-removal benchmark

# Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
!pip install -qq open3d==0.12.0

In [None]:
CUDA, = !readlink /usr/local/cuda | sed -E 's/.*cuda-(\w+)\.(\w+)/cu\1\2/'
# !pip install -qq torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
# !pip install -qq torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
!pip install -qq torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+{CUDA}.html
# !pip install -qq torch-geometric

In [None]:
!pip install -Uqq pytorch-lightning
# !pip install -Uqq wandb

In [None]:
# !pip install -qq torchtyping

In [None]:
from pathlib import Path
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.auto import tqdm

import torch
from torch import nn
from torch import optim

from torchvision import transforms

from torch_cluster import fps

import pytorch_lightning as pl

import open3d as o3d

In [None]:
# Show the installed PyTorch Lightning version.
# NOTE(review): SegmentationTask below defines training_epoch_end / validation_epoch_end;
# confirm this Lightning version still supports those hooks.
pl.__version__

In [None]:
# Seed every RNG this notebook draws from so that runs are reproducible.
# TODO: check if there's anything else to seed
import random

seed = 42
random.seed(seed)  # Python's `random` (used by the commented-out plotting test cell)
rng = np.random.default_rng(seed)  # numpy Generator shared by the transforms and split_object
torch.manual_seed(seed)  # torch CPU RNG (weight init, torch.normal jitter, randperm shuffle)
torch.cuda.manual_seed_all(seed)
# Full cuDNN determinism costs speed; enable these if bit-exact reruns are required.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Data utilities

In [None]:
# Put path to data here
PATH = Path("path/to/pointcleannetoutlierstestset/pointCleanNetOutliersTestSet/")

# One ".outliers" ground-truth label file per object. glob's order depends on the
# filesystem, so sort it: this makes `subset_ratio` slicing and the per-object RNG
# draws in split_object reproducible across machines.
ground_truth_filenames = sorted(Path(fn) for fn in glob.glob(str(PATH / "*.outliers")))
len(ground_truth_filenames)

Point-cloud transforms

Converters:
- PointCloudToTensor: point cloud to torch.Tensor

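Transforms:
- PointCloudTranslate: shift the xyz coordinates by a random offset
- PointCloudRotate: rotate by a random angle around a given axis (z by default)
- PointCloudRotationalPerturbation: apply a small random rotation around the x, y and z axes
- PointCloudJitter: add clipped Gaussian(0, sigma^2) noise to every point
- PointCloudDropout: replace a random subset of points with a copy of one remaining point
- PointCloudNormalize: center the cloud at the origin and scale it into the unit sphere
- **(NOT USED)** PointCloudScale: scale the points by a random factor.

Other:
- PointCloudShuffle: randomly permute the order of the points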

In [None]:
def get_rotation_matrix(angle, axis):
    """Return a 3x3 float32 torch matrix rotating by ``angle`` radians about ``axis``.

    Implements Rodrigues' formula
        R = cos(a) * I + sin(a) * [u]_x + (1 - cos(a)) * u u^T,
    where ``u`` is ``axis`` scaled to unit length and ``[u]_x`` is its
    cross-product (skew-symmetric) matrix.
    """
    unit = axis / np.linalg.norm(axis)
    ux, uy, uz = unit

    skew = np.array([
        [0.0, -uz, uy],
        [uz, 0.0, -ux],
        [-uy, ux, 0.0],
    ])

    c = np.cos(angle)
    s = np.sin(angle)
    rodrigues = c * np.eye(3) + s * skew + (1.0 - c) * np.outer(unit, unit)
    return torch.from_numpy(rodrigues).float()

In [None]:
# Converters
class PointCloudToTensor:
    """Convert a numpy point array into a float32 torch tensor."""

    def __call__(self, x):
        tensor = torch.from_numpy(x)
        return tensor.float()


# If not stated otherwise, the below transformations work with normals too
# Transforms
class PointCloudScale:
    """Multiply the xyz columns by one random factor drawn uniformly from [lo, hi).

    Columns beyond the first three are left untouched. ``points`` is modified
    in place and returned.
    """

    def __init__(self, lo=0.8, hi=1.25):
        self.lo = lo
        self.hi = hi

    def __call__(self, points):
        factor = rng.uniform(self.lo, self.hi)
        points[:, :3] *= factor
        return points

    
class PointCloudTranslate:
    """Shift the xyz columns by one random offset from U(-translate_range, translate_range).

    A single scalar offset is added to x, y and z alike; extra columns are left
    untouched. ``points`` is modified in place and returned.
    """

    def __init__(self, translate_range=0.1):
        self.translate_range = translate_range

    def __call__(self, points):
        limit = self.translate_range
        offset = rng.uniform(-limit, limit)
        points[:, :3] += offset
        return points


class PointCloudRotate:
    """Rotate a cloud by a uniformly random angle in [0, 2*pi) about ``axis``.

    If the cloud has more than three columns, columns 3+ (presumably normals —
    confirm) are rotated with the same matrix and the input is modified in
    place. Otherwise a new tensor is returned.
    """

    def __init__(self, axis=np.array([0.0, 0.0, 1.0])):
        self.axis = axis

    def __call__(self, points):
        angle = rng.uniform() * 2 * np.pi
        rot_t = get_rotation_matrix(angle, self.axis).t()

        if points.shape[1] <= 3:
            return torch.matmul(points, rot_t)

        # Each right-hand side is computed before it is written back, and the two
        # column ranges do not overlap.
        points[:, 0:3] = torch.matmul(points[:, 0:3], rot_t)
        points[:, 3:] = torch.matmul(points[:, 3:], rot_t)
        return points


class PointCloudRotationalPerturbation:
    """Apply a small random rotation about the x, y and z axes.

    Each of the three angles is drawn from N(0, angle_sigma^2) and clipped to
    [-angle_clip, angle_clip]. The rotations are composed as Rz @ Ry @ Rx.
    Columns 3+ (presumably normals — confirm) are rotated with the same matrix.

    Args:
        angle_sigma: standard deviation of each angle, in radians.
        angle_clip: maximum absolute angle, in radians.
        generator: optional numpy ``Generator``; None uses the notebook-wide ``rng``.
    """

    def __init__(self, angle_sigma=0.06, angle_clip=0.18, generator=None):
        self.angle_sigma, self.angle_clip = angle_sigma, angle_clip
        self.generator = generator

    def _get_angles(self):
        gen = rng if self.generator is None else self.generator
        # numpy's Generator API has no ``randn``; ``standard_normal`` is its equivalent.
        angles = np.clip(
            self.angle_sigma * gen.standard_normal(3), -self.angle_clip, self.angle_clip
        )

        return angles

    def __call__(self, points):
        angles = self._get_angles()
        Rx = get_rotation_matrix(angles[0], np.array([1.0, 0.0, 0.0]))
        Ry = get_rotation_matrix(angles[1], np.array([0.0, 1.0, 0.0]))
        Rz = get_rotation_matrix(angles[2], np.array([0.0, 0.0, 1.0]))

        # Combined rotation matrix: x first, then y, then z
        rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx)

        has_normals = points.shape[1] > 3
        if not has_normals:
            return torch.matmul(points, rotation_matrix.t())
        else:
            pc_xyz = points[:, 0:3]
            pc_normals = points[:, 3:]
            points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
            points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())

            return points
        

class PointCloudJitter:
    """Add per-point Gaussian noise N(0, std^2), clipped to [-clip, clip], to xyz.

    Uses torch's global RNG. Columns beyond the first three are left untouched.
    ``points`` is modified in place and returned.
    """

    def __init__(self, std=0.01, clip=0.05):
        self.std = std
        self.clip = clip

    def __call__(self, points):
        noise = torch.normal(mean=0.0, std=self.std, size=(points.size(0), 3))
        noise.clamp_(-self.clip, self.clip)
        points[:, 0:3] += noise
        return points


# TODO: rewrite this in torch
class PointCloudDropout:
    """Randomly "drop" points by overwriting them with a copy of one surviving point.

    A ratio r ~ U(0, max_dropout_ratio) is drawn and each point is dropped with
    probability r. Every dropped row is replaced by the same randomly chosen
    kept row, so the number of rows (and hence batch shapes) never changes.
    ``points`` (a torch tensor) is modified in place and returned.

    Args:
        max_dropout_ratio: upper bound of r, in [0, 1).
        generator: optional numpy ``Generator``; None uses the notebook-wide ``rng``.
    """

    def __init__(self, max_dropout_ratio=0.875, generator=None):
        assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
        self.max_dropout_ratio = max_dropout_ratio
        self.generator = generator

    def __call__(self, points):
        gen = rng if self.generator is None else self.generator
        dropout_ratio = gen.random() * self.max_dropout_ratio  # in [0, max_dropout_ratio)
        drop_mask = gen.random(points.shape[0]) <= dropout_ratio
        keep_idx = np.flatnonzero(~drop_mask)

        # Nothing dropped: nothing to do. Everything dropped: there is no survivor
        # to copy from (this used to raise), so leave the cloud unchanged.
        if keep_idx.size == 0 or keep_idx.size == drop_mask.size:
            return points

        source_idx = int(gen.choice(keep_idx))
        drop_idx = torch.from_numpy(np.flatnonzero(drop_mask))
        points[drop_idx] = points[source_idx].clone()  # set to the random point
        return points


class PointCloudNormalize:
    """Center the cloud at the origin and scale it to fit inside the unit sphere.

    Only the xyz columns are changed. Extra columns (e.g. normals) are kept
    as-is, because translating or rescaling normals would corrupt them. The
    input is not modified: ``torch.from_numpy(...).float()`` can return a tensor
    that shares memory with the dataset's stored array. A degenerate cloud (all
    points identical) is centered but not scaled, instead of producing NaNs.
    """

    def __call__(self, x):
        x = x.clone()
        if x.shape[0] == 0:
            return x
        xyz = x[:, :3]  # view into the clone
        xyz -= torch.mean(xyz, dim=0)
        radius = xyz.norm(dim=1).max()
        if radius > 0:
            xyz /= radius
        return x

    
# Other
class PointCloudShuffle:
    """Return the rows of ``points`` in a random order (uses torch's global RNG)."""

    def __call__(self, points):
        order = torch.randperm(points.shape[0])
        return points[order]

In [None]:
NUM_POINTS = 2048  # points per chunk produced by split_object

# Evaluation pipeline: numpy array -> float32 tensor -> centered, unit-radius cloud.
valid_transforms = transforms.Compose([PointCloudToTensor(), PointCloudNormalize()])

# Training pipeline. Uncomment the augmentations you want to use:
# - be careful with PointCloudDropout and choose a reasonable dropout ratio;
# - PointCloudShuffle has no real effect, because PointTransformer is invariant
#   to the order of the points.
train_transforms = transforms.Compose([
    PointCloudToTensor(),
    # PointCloudTranslate(),
    # PointCloudRotate(),
    # PointCloudRotationalPerturbation(),
    # PointCloudJitter(),
    # PointCloudDropout(),  # dropout
    # PointCloudShuffle(),  # shuffling points
    PointCloudNormalize(),
])

# Load the data

In [None]:
import plotly.express as px


DISTINCT_LABELS = ["object", "noise"]

def pc_show(item, is_noise, category):
    """Draw an interactive plotly 3-D scatter of a point cloud, colored by label.

    Args:
        item: array whose first three columns are x, y, z.
        is_noise: per-point integer labels indexing DISTINCT_LABELS (0 = object, 1 = noise).
        category: object name, printed above the figure.
    """
    x, y, z = [item[:, i] for i in range(3)]
    labels = [DISTINCT_LABELS[point] for point in is_noise]

    df = pd.DataFrame(dict(
        x=x,
        y=z,  # y and z are swapped for display; the axes otherwise appear mixed up
        z=y,
        is_noise=labels,
        size=[2] * len(labels),
    ))

    print(category)

    # Take the two ends of the Turbo scale. Its first two entries are both dark
    # blue/purple, which made object and noise points hard to tell apart.
    turbo = px.colors.sequential.Turbo
    color_discrete_map = dict(zip(DISTINCT_LABELS, [turbo[0], turbo[-1]]))

    fig = px.scatter_3d(
        df, x="x", y="y", z="z", color="is_noise", size="size",
        # opacity=0.0 made every marker fully transparent (an empty plot).
        opacity=0.8,
        size_max=2,
        color_discrete_map=color_discrete_map,
        category_orders=dict(is_noise=DISTINCT_LABELS)
    )
    fig.show()

In [None]:
# Test

# NUM_POINTS = 2048

# choice = rng.choice(pts.shape[0], NUM_POINTS, replace=pts.shape[0] < NUM_POINTS)

# pc_show(pts[choice], labels[choice], fn.stem)

In [None]:
DISTINCT_LABELS = ["object", "noise"]

def pc_show_matplotlib(item, is_noise, category, with_noise=True):
    """Draw a static matplotlib 3-D scatter of a point cloud, one marker style per label.

    Object points (label 0) use "o". Noise points (label 1) use "^" and are
    drawn only when ``with_noise`` is True. For display, x is negated and the
    y/z columns are swapped.

    Args:
        item: numpy array whose first three columns are x, y, z.
        is_noise: numpy array of per-point labels (0 = object, 1 = noise).
        category: object name, used as the plot title (previously ignored).
        with_noise: whether to draw the noise points.
    """
    fig = plt.figure(figsize=(30, 30))
    ax = fig.add_subplot(projection="3d")

    markers = ["o", "^"] if with_noise else ["o"]
    for label, marker in enumerate(markers):
        indices = (is_noise == label).nonzero()[0]
        ax.scatter(-item[indices, 0], item[indices, 2], item[indices, 1], marker=marker)

    ax.set_title(category)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    plt.show()

In [None]:
# Test

# choice = rng.choice(xyz.shape[0], NUM_POINTS, replace=xyz.shape[0] < NUM_POINTS)

# pc_show_matplotlib(xyz, labels, fn.stem, False)
# pc_show_matplotlib(xyz[choice], labels[choice], fn.stem, False)
# pc_show_matplotlib(xyz[choice], labels[choice], fn.stem, True)

In [None]:
def split_object(pc, labels, num_points=None, generator=None):
    """Split one object into disjoint random chunks of ``num_points`` points.

    Each chunk is sampled without replacement from the points not yet used, so
    chunks never overlap. The ``len(pc) % num_points`` leftover points are not
    returned. An object with fewer than ``num_points`` points (but at least
    one) yields a single chunk sampled with replacement; previously such
    objects produced no chunks at all and were silently left out.

    Args:
        pc: [N, D] point array (not modified).
        labels: [N] per-point labels aligned with ``pc`` (not modified).
        num_points: chunk size; None uses the module-level NUM_POINTS.
        generator: numpy ``Generator``; None uses the module-level ``rng``.

    Returns:
        (list of [num_points, D] arrays, list of [num_points] label arrays).
    """
    if num_points is None:
        num_points = NUM_POINTS
    if generator is None:
        generator = rng

    pts = np.copy(pc)
    labels = np.copy(labels)

    total = pts.shape[0]
    num_chunks = max(1, total // num_points) if total > 0 else 0

    split_pcs, split_labels = [], []
    for _ in range(num_chunks):
        # Only true for an object smaller than one chunk (the loop then runs once).
        replace = pts.shape[0] < num_points
        choice = generator.choice(pts.shape[0], num_points, replace=replace)

        split_pcs.append(pts[choice])
        split_labels.append(labels[choice])

        # Remove the sampled points so that later chunks are disjoint.
        mask = np.ones(pts.shape[0], dtype=bool)
        mask[choice] = False
        pts = pts[mask]
        labels = labels[mask]

    return split_pcs, split_labels

# Test

# split_pcs, split_labels = split_object(xyz, labels)
# len(split_pcs), len(split_labels), xyz.shape, labels.shape

In [None]:
# Test

# import random

# to_plot = []
# for i in range(10):
#     obj = random.choice(list(zip(split_pcs, split_labels)))
#     to_plot.append(obj)

# for item in to_plot:
#     pc_show_matplotlib(item[0], item[1], fn.stem, True)

In [None]:
class PointCleanNetDataset(torch.utils.data.Dataset):
    """PointCleanNet outlier test set, split into fixed-size point chunks.

    Each ``<name>.outliers`` label file is paired with ``<name>.xyz`` in the same
    directory. Every object is split by ``split_object`` and each chunk is one
    dataset item: ``(points, labels, category)``, where ``category`` is the
    file stem.

    Args:
        ground_truth_filenames: paths (``Path`` or ``str``) of the ``.outliers`` files.
        transforms: callable applied to each chunk's point array, or None to
            return the raw numpy array.
        subset_ratio: if given, only the first ``int(len * subset_ratio)`` files are loaded.
    """

    def __init__(self, ground_truth_filenames, transforms, subset_ratio=None):

        self.ground_truth_filenames = ground_truth_filenames
        if subset_ratio is not None:
            num_files = int(len(self.ground_truth_filenames) * subset_ratio)
            self.ground_truth_filenames = self.ground_truth_filenames[:num_files]

        self.transforms = transforms

        self.classes = { 0: "object", 1: "noise" }

        # Prepare data
        self._prepare_pcs_and_labels()
        # Quick summary of what was loaded.
        print(len(self.pcs))
        print(len(self.labels))
        print(len(self.categories))
        print(list(self.full_objects.keys())[:10])

    def _prepare_pcs_and_labels(self):
        """Load every object, keep it whole in ``full_objects`` and store its chunks."""
        self.pcs, self.labels, self.categories = [], [], []
        self.full_objects = {}

        for gr_truth_fn in tqdm(self.ground_truth_filenames):
            gr_truth_path = Path(gr_truth_fn)
            category = gr_truth_path.stem
            # with_suffix swaps ".outliers" for ".xyz" without hard-coding its length.
            pc = o3d.io.read_point_cloud(str(gr_truth_path.with_suffix(".xyz")))
            pc = np.asarray(pc.points)
            # Assumed: one integer label per .xyz row (0 = object, 1 = noise).
            segm_labels = np.loadtxt(str(gr_truth_path)).astype(np.int32)

            self.full_objects[category] = { "pc": pc, "labels": segm_labels }

            split_pcs, split_labels = split_object(pc, segm_labels)
            self.pcs.extend(split_pcs)
            self.labels.extend(split_labels)
            self.categories.extend([category] * len(split_pcs))

    def __len__(self):
        return len(self.pcs)

    def __getitem__(self, idx):
        pc = self.pcs[idx]
        # transforms=None (as in this file's test cell) returns the raw chunk.
        x = pc if self.transforms is None else self.transforms(pc)

        y = self.labels[idx]
        category = self.categories[idx]

        return x, y, category

In [None]:
# Test

# ds = PointCleanNetDataset(ground_truth_filenames, None)

In [None]:
import multiprocessing as mp

class PointCleanNetDataModule(pl.LightningDataModule):
    """Lightning data module exposing the PointCleanNet outlier test set.

    Args:
        batch_size: chunks per batch.
        num_workers: DataLoader worker processes; defaults to the CPU count
            (evaluated once, when the class is defined).
    """

    def __init__(self, batch_size=16, num_workers=mp.cpu_count()):
        super().__init__()
        self.batch_size, self.num_workers = batch_size, num_workers

    def setup(self, subset_ratio=None, stage=None):
        """Create ``self.test_dset`` when ``stage`` is "test" or None.

        ``subset_ratio`` keeps only the first fraction of label files, which is
        useful for prototyping. NOTE(review): it comes before ``stage`` in the
        signature, so a caller passing the stage positionally would bind it to
        ``subset_ratio``.
        """
        if stage in ("test", None):
            self.test_dset = PointCleanNetDataset(
                ground_truth_filenames, valid_transforms, subset_ratio
            )

    def test_dataloader(self):
        """Return a deterministic loader over every test chunk (no shuffling, no dropped tail)."""
        return torch.utils.data.DataLoader(
            dataset=self.test_dset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
        )

In [None]:
BS = 32  # batch size; also logged as a hyper-parameter by SegmentationTask.on_train_start
dm = PointCleanNetDataModule(batch_size=BS)
# Loads every file. For prototyping, dm.setup(subset_ratio=0.1) loads only the
# first 10% of the files, which is much faster.
dm.setup()

In [None]:
# Number of point-cloud chunks (dataset items) in the test set.
len(dm.test_dset)

# Model

This section contains the core of the notebook: the Point Transformer model.

In [None]:
def get_neighbours(features, idx):
    """Gather per-point neighbour features.

    Input:
        features: [B, N, C] tensor of per-point features.
        idx: [B, ...] integer tensor (e.g. [B, M, K]) of indices along the N axis.
    Return:
        Tensor of shape idx.shape + (C,), e.g. [B, M, K, C], where
        out[b, m, k] = features[b, idx[b, m, k]].
    """
    num_channels = features.shape[-1]

    # Flatten every index of a batch element into one axis, then broadcast it
    # over the channel axis so that gather picks whole feature vectors.
    flat_idx = idx.reshape(idx.shape[0], -1).unsqueeze(-1).expand(-1, -1, num_channels)
    gathered = torch.gather(features, 1, flat_idx)

    return gathered.reshape(*idx.shape, -1)

In [None]:
class TransitionDownBlock(nn.Module):
    """Point Transformer "transition down" block.

    Farthest-point-samples a subset of the points. For each sampled point, it
    then max-pools MLP-transformed features over its k nearest input points.

    Args:
        in_dims: feature width of the input ``x``.
        out_dims: feature width after the shared point-wise MLP.
        num_neighbours: k for the kNN max-pooling (capped at the number of points).
        sampling_ratio: fraction of points kept by farthest point sampling.
    """

    def __init__(self, in_dims, out_dims, num_neighbours=16, sampling_ratio=0.25):
        super().__init__()

        self.num_neighbours = num_neighbours
        self.sampling_ratio = sampling_ratio

        self.mlp = nn.Sequential(
            nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False), # Should be the same as Linear since dims are transposed
            nn.BatchNorm1d(out_dims),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Args:
            x: features, [B, N, in_dims].
            pos: coordinates, [B, N, 3].
        Returns:
            (x_sampled [B, N_sampled, out_dims], pos_sampled [B, N_sampled, 3]).
        """
        batch_size, num_pos, coord_dims = pos.shape

        # Flatten pos from [B, N, 3] to [B * N, 3] for torch_cluster.
        flat_pos = pos.reshape(-1, coord_dims)

        # Batch id of every flattened point, created on pos's device so the block
        # also runs on CPU (it was previously hard-coded to .cuda()).
        batch = torch.arange(batch_size, device=pos.device).repeat_interleave(num_pos)

        # Farthest point sampling within each batch element.
        indices = fps(flat_pos, batch, ratio=self.sampling_ratio, random_start=True)

        # The reshape assumes fps returns the same number of indices for every
        # batch element, grouped by batch element.
        pos_sampled = flat_pos[indices].reshape(batch_size, -1, coord_dims)

        # KNN: distances from every sampled point to every input point, [B, N_sampled, N].
        rel_dist = (pos_sampled[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
        k = min(x.shape[1], self.num_neighbours)
        _, neighbour_indices = rel_dist.topk(k, largest=False)

        # MLP on all input points, then gather each sampled point's neighbourhood.
        x = self.mlp(x.transpose(1, 2)).transpose(1, 2)  # [B, N, out_dims]
        x_sampled = get_neighbours(x, neighbour_indices)  # [B, N_sampled, k, out_dims]

        # Max-pool every channel over the neighbourhood.
        x_sampled = torch.max(x_sampled, dim=2)[0]  # [B, N_sampled, out_dims]

        return x_sampled, pos_sampled

In [None]:
# Test

# features = torch.randn((16, 16, 256))
# pos = torch.randn((16, 16, 3))
# lateral_pos = torch.randn((16, 64, 3))

# rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
# weights, neighbour_indices = rel_dist.topk(3, largest=False)

# res = get_neighbours(features, neighbour_indices)
# res.shape

In [None]:
# Test

# features = torch.arange(2*8*2).reshape(2, 8, 2)
# features
# pos = torch.randn((2, 8, 3))
# lateral_pos = torch.randn((2, 32, 3))

# print(features[0])
# rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
# weights, neighbour_indices = rel_dist.topk(3, largest=False)
# print(neighbour_indices[0, :3])

# res = get_neighbours(features, neighbour_indices)
# res.shape, res[0, :3]

In [None]:
class TransitionUpBlock(nn.Module):
    """Point Transformer "transition up" block.

    Projects coarse features to ``out_dims`` and interpolates them onto the
    finer ``lateral_pos`` by inverse-distance weighting of the nearest coarse
    points. The projected skip-connection features are then added.

    Args:
        in_dims: channel count of the coarse ``features``.
        out_dims: channel count of ``lateral_features`` and of the output.
        num_neighbours: coarse points used per interpolation (default 3). Capped
            at the number of coarse points so that very small clouds do not make
            ``topk`` fail.
    """

    def __init__(self, in_dims, out_dims, num_neighbours=3):
        super().__init__()
        self.num_neighbours = num_neighbours

        self.up_mlp = nn.Sequential(
            nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_dims),
            nn.ReLU()
        )
        self.lateral_mlp = nn.Sequential(
            nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_dims),
            nn.ReLU()
        )

    def forward(self, features, pos, lateral_features, lateral_pos):
        """
            features: (B, N, in_dims) torch.Tensor
            pos: (B, N, 3) torch.Tensor
            lateral_features: (B, M, out_dims) torch.Tensor
            lateral_pos: (B, M, 3) torch.Tensor
        Returns (B, M, out_dims) features and lateral_pos.
        N is normally smaller than M, because this module upsamples features.
        """
        features = self.up_mlp(features.transpose(1, 2)).transpose(1, 2)

        # Nearest coarse points (in pos) of every fine point (in lateral_pos).
        k = min(self.num_neighbours, pos.shape[1])
        rel_dist = (lateral_pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)  # [B, M, N]
        dists, neighbour_indices = rel_dist.topk(k, largest=False)

        # Normalized inverse-distance interpolation weights; epsilon avoids 1/0.
        weights = 1.0 / (dists + 1e-8)
        weights = weights / torch.sum(weights, dim=2, keepdim=True)  # [B, M, k]

        interpolated_features = get_neighbours(features, neighbour_indices)  # [B, M, k, C]
        interpolated_features = torch.sum(interpolated_features * weights[..., None], dim=-2)

        lateral_features = self.lateral_mlp(lateral_features.transpose(1, 2)).transpose(1, 2)

        # Add the interpolated features to the skip-connection features.
        out = interpolated_features + lateral_features

        return out, lateral_pos


In [None]:
# Test

# features = torch.randn((16, 16, 256))
# pos = torch.randn((16, 16, 3))
# lateral_features = torch.randn((16, 64, 128))
# lateral_pos = torch.randn((16, 64, 3))

# up = TransitionUpBlock(in_channels=256, out_channels=128)

# with torch.no_grad():
#     feat, pos = up(features, pos, lateral_features, lateral_pos)
#     print(feat.shape)
#     print(pos.shape)

In [None]:
class PointTransformerLayer(nn.Module):
    """Vector self-attention over each point's k nearest neighbours.

    For a point i with neighbours j:
        delta_ij = theta(p_i - p_j)                             (learned position encoding)
        w_ij     = softmax_j(gamma(phi(x_i) - psi(x_j) + delta_ij))
        y_i      = sum_j w_ij * (alpha(x_j) + delta_ij)
    The softmax is taken per channel over the neighbour axis.
    """

    def __init__(
        self,
        dim,
        num_neighbors,
        pos_mlp_hidden_dim=64,
        attn_mlp_hidden_mult=4, # This comes from initial Transformer paper
        dropout=0,
    ):
        super().__init__()
        self.num_neighbors = num_neighbors

        # phi / psi / alpha: bias-free projections to queries, keys and values
        self.to_queries = nn.Linear(dim, dim, bias=False)
        self.to_keys = nn.Linear(dim, dim, bias=False)
        self.to_values = nn.Linear(dim, dim, bias=False)

        # theta: relative-position MLP
        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, dim),
        )

        # gamma: attention-weight MLP
        hidden = dim * attn_mlp_hidden_mult
        self.to_attn_weights = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        # k nearest neighbours of every point (including itself), capped at N
        k = min(self.num_neighbors, x.shape[1])
        dists = (pos[:, :, None, :] - pos[:, None, :, :]).norm(dim=-1)
        knn_idx = dists.topk(k, largest=False)[1]

        # Queries for every point; keys/values only for each point's neighbours
        queries = self.to_queries(x)
        keys = get_neighbours(self.to_keys(x), knn_idx)
        values = get_neighbours(self.to_values(x), knn_idx)

        # Relative position encoding delta_ij
        delta = self.pos_mlp(pos[:, :, None, :] - get_neighbours(pos, knn_idx))
        delta = self.drop(delta)

        # Subtraction relation plus position encoding -> per-channel weights
        attn_logits = self.to_attn_weights(queries[:, :, None, :] - keys + delta)
        attn = attn_logits.softmax(dim=-2)

        # Weighted sum of position-augmented values over the neighbours
        return (attn * (values + delta)).sum(dim=-2)

In [None]:
class PointTransformerBlock(nn.Module):
    """Residual Point Transformer block: Linear -> PointTransformerLayer -> Linear -> dropout, plus skip.

    Args:
        dim: input/output feature width.
        hidden_dim: feature width inside the attention layer.
        pos_mlp_hidden_dim, attn_mlp_hidden_mult: forwarded to PointTransformerLayer.
        dropouts: optional pair (attention-layer dropout, block-output dropout);
            None disables both.
        num_neighbours: k for the layer's kNN attention. None, the default, used to
            make forward fail; it now means 16, the PointTransformerSegmentator default.
    """

    def __init__(self, dim, hidden_dim, pos_mlp_hidden_dim=64, attn_mlp_hidden_mult=4, dropouts=None, num_neighbours=None):
        super().__init__()

        if num_neighbours is None:
            num_neighbours = 16

        self.fc_in = nn.Linear(dim, hidden_dim)
        self.point_transformer_layer = PointTransformerLayer(
            dim=hidden_dim,
            num_neighbors=num_neighbours,
            pos_mlp_hidden_dim=pos_mlp_hidden_dim,
            attn_mlp_hidden_mult=attn_mlp_hidden_mult, # This comes from initial Transformer paper
            dropout=dropouts[0] if dropouts else 0,
        )
        self.fc_out = nn.Linear(hidden_dim, dim)

        self.drop = nn.Dropout(dropouts[1] if dropouts else 0)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        out = self.fc_in(x)
        out = self.point_transformer_layer(out, pos)
        out = self.fc_out(out)
        out = self.drop(out)

        # Residual connection
        return out + x

In [None]:
# Test

# from itertools import tee

# def pairwise(iterable):
#     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
#     a, b = tee(iterable)
#     next(b, None)
#     return list(zip(a, b))

In [None]:
class PointTransformerSegmentator(nn.Module):
    """U-Net-style Point Transformer for per-point segmentation.

    Args:
        layer_config: ``[in_dims, *inner_dims, out_dims]``. Gives the input feature
            width, the width of every encoder level (the decoder mirrors them), and
            the number of outputs per point.
        num_neighbours: k used by every kNN step.
        sampling_ratio: fraction of points kept at each transition-down.
        dropouts: optional 3 rates. The first two go to every PointTransformerBlock;
            the last is applied before the output head.
    """

    def __init__(self, layer_config, num_neighbours=16, sampling_ratio=0.25, dropouts=None):
        super().__init__()
        assert dropouts is None or len(dropouts) == 3

        in_dims, *inner_dims, out_dims = layer_config
        # Consecutive level widths, e.g. [32, 64, 128] -> [(32, 64), (64, 128)].
        # Built with zip because no `pairwise` helper is defined in this
        # notebook's live code (calling it raised NameError).
        encoder_dims = list(zip(inner_dims[:-1], inner_dims[1:]))
        reversed_dims = inner_dims[::-1]
        decoder_dims = list(zip(reversed_dims[:-1], reversed_dims[1:]))
        block_dropouts = dropouts[:-1] if dropouts else None

        # Level 0 embeds raw coordinates with a point-wise MLP; each later level
        # downsamples with a TransitionDownBlock.
        self.encoder_transitions = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_dims, inner_dims[0]),
                nn.ReLU(),
                nn.Linear(inner_dims[0], inner_dims[0]),
            ),
        ] + [
            TransitionDownBlock(
                in_dims=d_in,
                out_dims=d_out,
                num_neighbours=num_neighbours,
                sampling_ratio=sampling_ratio,
            )
            for d_in, d_out in encoder_dims
        ])
        self.encoder_transformers = nn.ModuleList([
            PointTransformerBlock(
                dim=dims,
                hidden_dim=dims,
                pos_mlp_hidden_dim=dims,
                attn_mlp_hidden_mult=4,
                num_neighbours=num_neighbours,
                dropouts=block_dropouts,
            )
            for dims in inner_dims
        ])

        # The first decoder step is a plain Linear at the coarsest level; the rest
        # upsample with TransitionUpBlock and skip connections.
        self.decoder_transitions = nn.ModuleList([
            nn.Linear(inner_dims[-1], inner_dims[-1]),
        ] + [
            TransitionUpBlock(
                in_dims=d_in,
                out_dims=d_out,
            )
            for d_in, d_out in decoder_dims
        ])
        self.decoder_transformers = nn.ModuleList([
            PointTransformerBlock(
                dim=dims,
                hidden_dim=dims,
                pos_mlp_hidden_dim=dims,
                attn_mlp_hidden_mult=4,
                num_neighbours=num_neighbours,
                dropouts=block_dropouts,
            )
            for dims in reversed_dims
        ])

        self.drop = nn.Dropout(dropouts[-1] if dropouts else 0)

        self.fc_out = nn.Sequential(
            nn.Linear(inner_dims[0], inner_dims[0]),
            nn.ReLU(),
            nn.Linear(inner_dims[0], out_dims),
        )

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        """Map points [B, N, in_dims] (coordinates) to per-point outputs [B, N, out_dims]."""
        skip_features, skip_pos = [], []
        num_levels = len(self.encoder_transitions)

        for level, (transition, transformer) in enumerate(zip(self.encoder_transitions, self.encoder_transformers)):
            if level == 0:
                features = transition(pos)
            else:
                features, pos = transition(features, pos)
            features = transformer(features, pos)

            # Keep every level except the coarsest for the decoder's skip connections.
            if level < num_levels - 1:
                skip_features.append(features)
                skip_pos.append(pos)

        for level, (transition, transformer) in enumerate(zip(self.decoder_transitions, self.decoder_transformers)):
            if level == 0:
                features = transition(features)
            else:
                features, pos = transition(features, pos, skip_features[-level], skip_pos[-level])
            features = transformer(features, pos)

        return self.fc_out(self.drop(features))

# Model evaluation

In [None]:
from dataclasses import dataclass

@dataclass
class CMCounts:
    """Binary confusion-matrix counts (positive = label 1) and metrics derived from them.

    Instances can be added together (``a + b``, or ``sum([...])`` thanks to
    ``__radd__``), so counts accumulate across batches and categories. Every
    ratio metric returns 0.0 instead of raising ZeroDivisionError when its
    denominator is zero (e.g. empty counts, or no positives at all).
    """

    tp: int = 0  # predicted 1, target 1
    fp: int = 0  # predicted 1, target 0
    fn: int = 0  # predicted 0, target 1
    tn: int = 0  # predicted 0, target 0

    @classmethod
    def from_tensors(cls, target, preds):
        """Count TP/FP/FN/TN from two equally shaped 0/1 integer tensors.

        Note the argument order: ground truth first, predictions second.
        """
        tp, fp, fn, tn = torch.dstack((
            (preds & target) > 0,
            preds > target,
            preds < target,
            (preds | target) == 0,
        )).sum((0, 1))
        return cls(*[x.item() for x in [tp, fp, fn, tn]])

    @staticmethod
    def _ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    @property
    def f1(self):
        return self._ratio(self.tp, self.tp + 0.5 * (self.fp + self.fn))

    @property
    def f2(self):
        return self.f_beta(2)

    def f_beta(self, beta):
        """F-beta score; beta > 1 weighs recall more heavily than precision."""
        beta_sq = beta ** 2
        return self._ratio(
            (1 + beta_sq) * self.tp,
            (1 + beta_sq) * self.tp + beta_sq * self.fn + self.fp,
        )

    @property
    def accuracy(self):
        return self._ratio(self.tp + self.tn, self.tp + self.fp + self.fn + self.tn)

    @property
    def balanced_accuracy(self):
        true_positive_rate = self._ratio(self.tp, self.tp + self.fn)
        true_negative_rate = self._ratio(self.tn, self.tn + self.fp)
        return (true_positive_rate + true_negative_rate) / 2

    def __add__(self, other):
        return CMCounts(
            tp=self.tp + other.tp,
            fp=self.fp + other.fp,
            fn=self.fn + other.fn,
            tn=self.tn + other.tn,
        )

    def __radd__(self, other):
        # Lets the builtin sum() start from its default 0.
        return self if other == 0 else self + other

In [None]:
class SegmentationTask(pl.LightningModule):
    """Lightning wrapper that trains and validates a per-point binary (object vs noise) segmenter.

    Args:
        model: network mapping a batch of points to per-point logits; presumably
            [B, N, 1] as for PointTransformerSegmentator here — confirm for other models.
        max_lr, epochs: OneCycleLR settings (saved as hyper-parameters).
        steps_per_epoch: optimizer steps per epoch, required by OneCycleLR.
        num_classes: currently unused; the task uses a single logit per point.
    """

    def __init__(self, model, max_lr, epochs, steps_per_epoch, num_classes=2):
        super().__init__()
        self.save_hyperparameters("max_lr", "epochs")
        self.steps_per_epoch = steps_per_epoch
        self.model = model
        self.loss = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, prefix):
        x, y, _ = batch
        # squeeze(-1) removes only the single-logit channel. A bare squeeze() also
        # removed the batch axis for a 1-sample batch, breaking the loss shapes.
        logits = self.model(x).squeeze(-1)
        loss = self.loss(logits, y.float())

        y_preds = (logits >= 0.0).long()  # logit >= 0  <=>  sigmoid >= 0.5

        cm_counts = CMCounts.from_tensors(y.flatten(), y_preds.flatten())

        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True)

        # Key "f1" holds the raw confusion counts; F1 is computed per epoch below.
        return { "loss": loss, "f1": cm_counts }

    # TODO: this is hard-coded, should be automated
    def on_train_start(self):
        self.logger.log_hyperparams({
            "bs": BS, "num_points": NUM_POINTS,
            "max_lr": self.hparams.max_lr, "epochs": self.hparams.epochs,
            "optimizer": "AdamW(wd=1e-2)", "scheduler": "OneCycleLR",
        })

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def training_epoch_end(self, outs):
        # Start from an empty CMCounts so that an epoch without steps does not crash.
        cm_count_total = sum((out["f1"] for out in outs), CMCounts())
        self.log("train_f1", cm_count_total.f1)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid")

    def validation_epoch_end(self, outs):
        cm_count_total = sum((out["f1"] for out in outs), CMCounts())
        self.log("valid_f1", cm_count_total.f1)

    def configure_optimizers(self):
        # The AdamW base lr is irrelevant: OneCycleLR sets the lr on every step.
        optimizer = optim.AdamW(self.model.parameters(), weight_decay=1e-2)
        lr_scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.max_lr,
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.hparams.epochs
        )
        lr_dict = { "scheduler": lr_scheduler, "interval": "step" }
        return [optimizer], [lr_dict]

In [None]:
# 3 input coordinates -> five encoder/decoder widths -> 1 logit per point.
layer_config = [3, 32, 64, 128, 256, 512, 1]

model = PointTransformerSegmentator(
    layer_config,
    num_neighbours=16,
    sampling_ratio=0.25,
    dropouts=[0.0, 0.0, 0.0],  # dropout disabled
)

# Put your path to weights here. max_lr / epochs come from the hyper-parameters
# saved in the checkpoint; steps_per_epoch only feeds the LR scheduler.
segmentator = SegmentationTask.load_from_checkpoint(
    "path/to/weights.ckpt",
    model=model,
    steps_per_epoch=len(dm.test_dataloader()),
)

In [None]:
from collections import defaultdict

def evaluate(model, dl, threshold_values=None, quiet=False, device="cuda"):
    """Evaluate a per-point noise segmenter at one or more probability thresholds.

    Args:
        model: module mapping a batch of points to per-point logits (last axis of size 1).
        dl: DataLoader yielding ``(points, labels, categories)`` batches.
        threshold_values: thresholds applied to ``sigmoid(logits)``; default [0.5].
        quiet: if False, print F1/F2/accuracy/balanced accuracy at 0.5 (or at the
            first threshold when 0.5 is not among them).
        device: device the model and batches are moved to.

    Returns:
        dict {threshold: CMCounts summed over all batches and categories}, in the
        order of ``threshold_values``.
    """
    if threshold_values is None:
        threshold_values = [0.5]

    model.to(device)
    model.eval()

    cm_count_dict = { key: defaultdict(CMCounts) for key in threshold_values }
    with torch.no_grad():
        for x, y, cats in tqdm(dl, total=len(dl)):
            x, y = x.to(device), y.to(device)
            # squeeze(-1) removes only the logit channel. A bare squeeze() also
            # removed the batch axis when the last batch held a single chunk.
            probs = model(x).squeeze(-1).sigmoid()

            cats = np.asarray(cats)
            batch_categories = np.unique(cats)  # only categories present in this batch
            for th in threshold_values:
                segm_preds = (probs >= th).long()
                for cat in batch_categories:
                    mask = cats == cat
                    cm_count_dict[th][cat] += CMCounts.from_tensors(
                        y[mask].flatten(), segm_preds[mask].flatten()
                    )

    # Starting from CMCounts() keeps the result well-defined for an empty loader.
    cm_count_totals = {
        th_value: sum(category_values.values(), CMCounts())
        for th_value, category_values in cm_count_dict.items()
    }
    if not quiet:
        key = 0.5 if 0.5 in threshold_values else threshold_values[0]
        print(f"F1: {cm_count_totals[key].f1}")
        print(f"F2: {cm_count_totals[key].f2}")
        print(f"Acc: {cm_count_totals[key].accuracy}")
        print(f"Balanced acc: {cm_count_totals[key].balanced_accuracy}")

    return cm_count_totals


In [None]:
# Test

# evaluate(segmentator, dm.test_dataloader());

In [None]:
# Sweep the decision threshold over [a, b) in steps of `step` and record every
# metric of the summed confusion matrix at each threshold.
a, b, step = 0.05, 1.0, 0.05
threshold_dict = {
    "threshold_values": np.arange(a, b, step),
    "metric_values": {
        name: [] for name in ("f1", "f2", "accuracy", "balanced_accuracy")
    },
}
print(threshold_dict["threshold_values"])

count_totals = evaluate(
    segmentator,
    dm.test_dataloader(),
    threshold_values=threshold_dict["threshold_values"],
    quiet=True,
)
# count_totals is keyed in threshold order, so each list lines up with threshold_values.
for metric_name in threshold_dict["metric_values"]:
    threshold_dict["metric_values"][metric_name] = [
        getattr(count_totals[th_value], metric_name) for th_value in count_totals
    ]

In [None]:
# Plot every metric against the threshold, mark its maximum and, where available,
# draw the score reported for PointCleanNet as a dashed reference line.
# fig, axes = plt.subplots(1, 4, sharey=True, figsize=(25, 5)) # uncomment this for a 1x4 grid
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

# PointCleanNet reference score per metric (None = no reference line).
benchmark_values = [0.77, 0.86, None, None]

# Display titles (in Lithuanian) for each metric key.
translation_dict = {
    "f1": "f1 įvertis",
    "f2": "f2 įvertis",
    "accuracy": "Tikslumas",
    "balanced_accuracy": "Subalansuotas tikslumas",
}

x = threshold_dict["threshold_values"]
# Position of the best threshold for each metric (first one on ties).
max_value_indices = [
    values.index(max(values)) for values in threshold_dict["metric_values"].values()
]

for ax, metric_label, max_value_index, benchmark_value in zip(
    axes.flat, threshold_dict["metric_values"], max_value_indices, benchmark_values
):
    metric_curve = threshold_dict["metric_values"][metric_label]
    best_x = threshold_dict["threshold_values"][max_value_index]
    best_y = metric_curve[max_value_index]

    ax.plot(x, metric_curve, marker="o", markevery=[max_value_index])
    ax.set_xticks(np.arange(0.0, 1.0, 0.1))
    ax.set_title(translation_dict[metric_label])
    ax.set_xlabel("Slenksčio vertė")
    ax.annotate(
        f"Maksimumas taške ({best_x:.2f}; {best_y:.4f})",
        (best_x - 0.15, best_y + 0.015),
    )

    if benchmark_value is not None:
        ax.axhline(y=benchmark_value, color="grey", linestyle="--")
        ax.annotate(f"PointCleanNet: {benchmark_value:.2f}", (0.3, benchmark_value + 0.015))

    ax.set_ylim(0.5, 1)

# Choose a path where you want to save the figure
plt.savefig("benchmark_comparison_2x2.png", dpi=200)
# plt.show()

In [None]:
# check what happens when maximizing f1:
# re-evaluate at the swept threshold with the highest F1 and print its metrics
f1_max_value_threshold = threshold_dict["threshold_values"][max_value_indices[0]]

evaluate(segmentator, dm.test_dataloader(), threshold_values=[f1_max_value_threshold]);

In [None]:
# check what happens when maximizing f2:
# re-evaluate at the swept threshold with the highest F2 and print its metrics
f2_max_value_threshold = threshold_dict["threshold_values"][max_value_indices[1]]

evaluate(segmentator, dm.test_dataloader(), threshold_values=[f2_max_value_threshold]);