In [1]:
#Environment set-up and libraries

#Base libraries
import numpy as np
import random
import torch
import torch.nn as nn
from datetime import datetime

#Plotting libraries
%matplotlib inline
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import plotly
import plotly.graph_objects as go

#Utilities libraries
from glob import glob 
import os

import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Loading the point clouds

In [2]:
def load_file(file_name):
    print(file_name)

    if file_name.endswith(".las") or file_name.endswith(".laz"):
        print("[INFO] .las (.laz) file loading")
        try:
            # import lidar .las data and assign to variable
            pcd = laspy.read(file_name)
            # examine the available features for the lidar file we have read
            # list(las.point_format.dimension_names)
            #
            # set(list(las.classification))

            # Creating, Filtering, and Writing Point Cloud Data
            # To create 3D point cloud data, we can stack together with the X, Y, and Z dimensions, using Numpy like this.
            point_data = np.stack([pcd.X, pcd.Y, pcd.Z], axis=0).transpose((1, 0))
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(point_data)
            # points = point_data
            if pcd is not None:
                print("[Info] Successfully read", file_name)

                # Point cloud
                return pcd

        except Exception:
            print(".las, .laz file load failed")

    elif file_name.endswith(".e57"):
        print("[INFO] .e57 file loading")
        try:
            e57_file = pye57.E57(file_name)

            # other attributes can be read using:
            data = e57_file.read_scan(0)

            # 'data' is a dictionary with the point types as keys
            # assert isinstance(data["cartesianX"], np.ndarray)
            # assert isinstance(data["cartesianY"], np.ndarray)
            # assert isinstance(data["cartesianZ"], np.ndarray)

            point_xyz = np.stack([data["cartesianX"], data["cartesianY"], data["cartesianZ"]]).transpose((1, 0))
            # points_rgb = [data["colorRed"], data["colorGreen"], data["colorBlue"]]
            # points_intensity = data["intensity"]

            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(point_xyz)
            # points = o3d.utility.Vector3dVector(point_xyz)
            # points = point_xyz
            # pcd.colors = o3d.utility.Vector3dVector(points_rgb)
            # pcd.colors[:, 0] = points_intensity
            print("[Info] Successfully read", file_name)
            return pcd

        except Exception:
            print(".e57 file load failed")

    elif file_name.endswith(".bin"):
        print("[INFO] .bin file loading")
        try:
            size_float = 4
            list_pcd = []
            with open(file_name, "rb") as f:
                byte = f.read(size_float * 4)
                while byte:
                    x, y, z, intensity = struct.unpack("ffff", byte)
                    list_pcd.append([x, y, z])
                    byte = f.read(size_float * 4)
            np_pcd = np.asarray(list_pcd)
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(np_pcd)
            print("[Info] Successfully read", file_name)
            return pcd

        except Exception:
            print(".bin file load failed")

    elif file_name.endswith(".ply"):
        pcd = o3d.io.read_point_cloud(file_name)
        points_xyz = np.asarray(pcd.points)
        #pcd = o3d.geometry.PointCloud() # No need to do that already a PointCloud
        pcd.points = o3d.utility.Vector3dVector(points_xyz)
        # points = points_xyz
        if pcd is not None:
            print("[Info] Successfully read", file_name)
            # Point cloud
            return pcd

    elif file_name.endswith(".pts"):
        try:
            with open(file_name, "r") as f:
                # Log every 1000000 lines.
                LOG_EVERY_N = 1000000
                points_np = []
                for line in f:
                    if len(line.split()) == 4:
                        x, y, z, i = [num for num in line.split()]
                        points_np.append([float(x), float(y), float(z), float(i)])
                        if (len(points_np) % LOG_EVERY_N) == 0:
                            print('point', len(points_np))
                    elif len(line.split()) == 3:
                        x, y, z = [num for num in line.split()]
                        points_np.append([float(x), float(y), float(z)])
                        if (len(points_np) % LOG_EVERY_N) == 0:
                            print('point', len(points_np))
                    elif len(line.split()) == 5:
                        x, y, z, i, zeroes_v = [num for num in line.split()]
                        points_np.append([float(x), float(y), float(z), float(i)])
                        if (len(points_np) % LOG_EVERY_N) == 0:
                            print('point', len(points_np))
                    elif len(line.split()) == 7:
                        x, y, z, r, g, b, i = [num for num in line.split()]
                        points_np.append([float(x), float(y), float(z),
                                          float(r), float(g), float(b),
                                          float(i)])
                        if (len(points_np) % LOG_EVERY_N) == 0:
                            print('point', len(points_np))
                    else:
                        print("[Info] The file has unregistered format")
                        return
            print('loop end')
            points_arr = np.array(points_np).transpose()
            print(len(points_arr))
            point_xyz = points_arr[:3].transpose()
            print("xyz points shape", point_xyz.shape)
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(point_xyz)
            if len(points_arr) == 4:
                points_intensity = (points_arr[3])/255.0
                print("intensity points len", points_intensity.shape)
                points_intensity_rgb = np.vstack((points_intensity,
                                                  points_intensity,
                                                  points_intensity)).T
                print("intensity_rgb points shape", points_intensity_rgb.shape)
                pcd.colors = o3d.utility.Vector3dVector(points_intensity_rgb)
            elif len(points_arr) == 7:
                points_red = (points_arr[4]) / 255.0
                points_green = (points_arr[5]) / 255.0
                points_blue = (points_arr[6]) / 255.0
                points_rgb = np.vstack((points_red,
                                        points_green,
                                        points_blue)).T

                # points_intensity = ((points_arr[3]) / 255.0).T
                # print("intensity points len", points_intensity.shape)
                print("rgb points shape", points_rgb.shape)
                pcd.colors = o3d.utility.Vector3dVector(points_rgb)
                #pcd.intensities = o3d.utility.Vector3dVector(points_intensity)
            if pcd is not None:
                print("[Info] Successfully read", file_name)
                # Point cloud
                return pcd

        except Exception:
            print("[Info] Reading .pts file failed", file_name)

    # elif file_name.endswith(".kml"):
    #     try:
    #         with open(file_name, "r") as f:
    #             # Log every 1000000 lines.
    #             LOG_EVERY_N = 1000000
    #             points_np = []
    #             for line in f:
    #                 print(line)
    #                 if len(line.split(",")) == 3 and (line[0].isdigit() or line.startswith("-")):
    #                     y, x, z = [num for num in line.split(",")]
    #                     points_np.append([float(x), float(y), float(z)])
    #                     if (len(points_np) % LOG_EVERY_N) == 0:
    #                         print('point', len(points_np))
    #                 else:
    #                     print("[Info] The file has unregistered format")
    #         print('loop end')
    #         points_arr = np.array(points_np).transpose()
    #         print(len(points_arr))
    #         point_xyz = points_arr[:3].transpose()
    #         # points_intensity = points_arr[3]
    #         pcd = o3d.geometry.PointCloud()
    #         pcd.points = o3d.utility.Vector3dVector(point_xyz)
    #         if pcd is not None:
    #             print("[Info] Successfully read", file_name)
    #             # Point cloud
    #             return pcd
    #
    #     except Exception:
    #         print("[Info] Reading .kml file failed", file_name)

    else:
        pcd = None
        geometry_type = o3d.io.read_file_geometry_type(file_name)
        print(geometry_type)

        mesh = None
        if geometry_type & o3d.io.CONTAINS_TRIANGLES:
            mesh = o3d.io.read_triangle_model(file_name)
        if mesh is None:
            print("[Info]", file_name, "appears to be a point cloud")
            cloud = None
            try:
                cloud = o3d.io.read_point_cloud(file_name)
                # print(type(cloud))
            except Exception:
                print("[Info] Unknown filename", file_name)
            if cloud is not None:
                print("[Info] Successfully read", file_name)

                if not cloud.has_normals():
                    cloud.estimate_normals()
                cloud.normalize_normals()
                pcd = cloud
                #points = cloud.points
                pcd.points = o3d.utility.Vector3dVector(cloud.points)
            else:
                print("[WARNING] Failed to read points", file_name)

        if pcd is not None or mesh is not None:
            try:
                if mesh is not None:
                    # Triangle model
                    _scene.scene.add_model("__model__", mesh)
                else:
                    # Point cloud
                    return pcd

            except Exception as e:
                print(e)


In [3]:
filepath_mob1 = "/home/mekala/PycharmProjects/SabreProject_code/Sabre_proj/SABRE - Selected Static Scan Data/SABRE ADVANCED 3D - Selected MMS Data/"
filename1 = "SABRE MMS_S3 - 0002.pts"
filepath1 = filepath_mob1 + filename1
pc1 = load_file(filepath1)

filepath_static2 = "/home/mekala/PycharmProjects/SabreProject_code/Sabre_proj/SABRE - Selected Static Scan Data/SABRE - Selected Static Scan Data/"
filename2 = "SABRE Static Scan_T17_003.pts"
filepath2 = filepath_static2 + filename2
pc2 = load_file(filepath2)

/home/mekala/PycharmProjects/SabreProject_code/Sabre_proj/SABRE - Selected Static Scan Data/SABRE ADVANCED 3D - Selected MMS Data/SABRE MMS_S3 - 0002.pts
point 1000000
point 2000000
point 3000000
point 4000000
point 5000000
point 6000000
point 7000000
point 8000000
point 9000000
point 10000000
point 11000000
loop end
4
xyz points shape (11008825, 3)
intensity points len (11008825,)
intensity_rgb points shape (11008825, 3)
[Info] Successfully read /home/mekala/PycharmProjects/SabreProject_code/Sabre_proj/SABRE - Selected Static Scan Data/SABRE ADVANCED 3D - Selected MMS Data/SABRE MMS_S3 - 0002.pts
/home/mekala/PycharmProjects/SabreProject_code/Sabre_proj/SABRE - Selected Static Scan Data/SABRE - Selected Static Scan Data/SABRE Static Scan_T17_003.pts
point 1000000
point 2000000
point 3000000
point 4000000
point 5000000
point 6000000
point 7000000
point 8000000
point 9000000
point 10000000
point 11000000
point 12000000
point 13000000
point 14000000
point 15000000
point 16000000
point 17

#### pc1 and pc2 are the original size point clouds

In [4]:
'''
function that returns the down sampled point cloud 
and fpfh parameters of the down sampled point cloud
the down sampling defined by the parameter voxel_size
'''
def preprocess_point_cloud(pcd, voxel_size):
    pcd_down = pcd.voxel_down_sample(voxel_size)
    pcd_down.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2.0,
                                             max_nn=30))
    pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
        pcd_down,
        o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 5.0,
                                             max_nn=100))
    return pcd_down, pcd_fpfh

In [5]:
# Downsampling
voxel_size = 0.3

pc1_down, pc1_fpfh = preprocess_point_cloud(pc1, voxel_size)
pc2_down, pc2_fpfh = preprocess_point_cloud(pc2, voxel_size)

In [6]:
# Sizes of the downsampled point clouds points sets
pc1_down_points = np.asarray(pc1_down.points)
pc2_down_points = np.asarray(pc2_down.points)
print(pc1_down_points.shape)
print(pc2_down_points.shape)

(56746, 3)
(123234, 3)


#### Defining device for this notebook

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


### Initial RANSAC alignment

In [9]:
print("voxel_size = ", voxel_size)
distance_threshold = 2.5 * voxel_size
print("Distance threshold: ", distance_threshold)
mutual_filter = True
print("mutual_filter = ", mutual_filter)
max_iterations = 1000000
print("max_iterations = ", max_iterations)
max_validation = np.min([len(pc1_down.points), len(pc2_down.points)]) // 2
print("max_validation = ", max_validation)

# getting the current date and time
start = datetime.now()
# getting the date and time from the current date and time in the given format
start_date_time = start.strftime("%m/%d/%Y, %H:%M:%S")
print('\nRANSAC Started', start_date_time, '\n')
print('Running RANSAC\n')
result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
    pc1_down, pc2_down, pc1_fpfh, pc2_fpfh,
    mutual_filter=mutual_filter,
    max_correspondence_distance=distance_threshold,
    estimation_method=o3d.pipelines.registration.
    TransformationEstimationPointToPoint(True),
    ransac_n=3,
    checkers=[
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)
    ],
    criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(
        max_iterations, max_validation))  # max_validation replaces args.confidence in mobile-static
# getting the current date and time
finish = datetime.now()
# getting the date and time from the current date and time in the given format
finish_date_time = finish.strftime("%m/%d/%Y, %H:%M:%S")
print('RANSAC Finished', finish_date_time,
      "\nGlobal registration took %.3f sec.\n" % (finish - start).total_seconds())


voxel_size =  0.3
Distance threshold:  0.75
mutual_filter =  True
max_iterations =  1000000
max_validation =  28373

RANSAC Started 10/23/2023, 10:29:21 

Running RANSAC

RANSAC Finished 10/23/2023, 10:29:25 
Global registration took 4.869 sec.



#### RANSAC transformation matrix

In [10]:
trans = result.transformation
print("The estimated transformation matrix:")
print(trans)
print("Saving the transformation matrix in ransac_transformation_matrix.txt ...")
np.savetxt('ransac_transformation_matrix.txt', trans)
print("")

The estimated transformation matrix:
[[-9.61128621e-01 -1.34747476e-01  8.03138338e-03  1.73479135e+06]
 [ 1.34977219e-01 -9.60040389e-01  4.57516409e-02  6.79695670e+06]
 [ 1.59241253e-03  4.64239161e-02  9.69449274e-01 -2.94502707e+05]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]
Saving the transformation matrix in ransac_transformation_matrix.txt ...



#### Applying RANSAC transformation on original and downsampled point clouds and visualising the result with original point clouds

In [11]:
pc1_down_ransac = pc1_down.transform(result.transformation)
pc1_ransac = pc1.transform(result.transformation)

In [16]:
pc1_down_ransac_points = np.asarray(pc1_down_ransac.points)
print(pc1_down_ransac_points.shape)
print(pc2_down_points.shape)

(56746, 3)
(123234, 3)


In [13]:
# coloring the point clouds
source_color=(1, 0.706, 0)
target_color=(0, 0.651, 0.929)
pc1_ransac.paint_uniform_color(source_color)
pc2.paint_uniform_color(target_color)

PointCloud with 17814760 points.

In [14]:
o3d.visualization.draw_geometries([pc1_ransac, pc2])

#### Cropping downsampled point cloud 2 (static scan) to the size of downsampled and transformed point cloud 1 (mobile scan)

In [15]:
# Crop point cloud 2 to the size of transformed point cloud 1
oriented_bounding_box = pc1_down_ransac.get_oriented_bounding_box()
oriented_bounding_box.color = (0, 1, 0)
pc2_down_croppped = pc2_down.crop(oriented_bounding_box)


In [17]:
# coloring the downsampled, traansformed point cloud 1 
# and the cropped point cloud 2
pc1_down_ransac.paint_uniform_color(source_color)
pc2_down_croppped.paint_uniform_color(target_color)

PointCloud with 88128 points.

In [84]:
# Both point clouds with visualization of the bbox
o3d.visualization.draw_geometries([pc1_down_ransac, pc2_down_croppped, oriented_bounding_box, pc1_down_ransac])

### Chamfer distance

In [19]:
def compute_chamfer_distance(pcd1, pcd2):
    """
    Compute the Chamfer distance between two point clouds.

    Parameters:
    - pcd1, pcd2: Open3D point cloud objects.

    Returns:
    - chamfer_distance: The Chamfer distance between the two point clouds.
    """
    
    # Compute distance from pcd1 to pcd2
    distances_1_to_2 = pcd1.compute_point_cloud_distance(pcd2)
    avg_distance_1_to_2 = np.mean([np.min(dist) for dist in distances_1_to_2])

    # Compute distance from pcd2 to pcd1
    distances_2_to_1 = pcd2.compute_point_cloud_distance(pcd1)
    avg_distance_2_to_1 = np.mean([np.min(dist) for dist in distances_2_to_1])

    # Compute the Chamfer distance
    chamfer_distance = (avg_distance_1_to_2 + avg_distance_2_to_1) / 2

    return chamfer_distance

In [20]:
chamfer_dist = compute_chamfer_distance(pc1_down_ransac, pc2_down_croppped)
print(f"Chamfer Distance: {chamfer_dist}")

Chamfer Distance: 2.518630661043157


#### RANSAC Evaluation

In [21]:
#RANSAC Evaluation

fitness = result.fitness
print("Fitness:")
print(fitness)
print("")

rmse = result.inlier_rmse
print("RMSE of all inlier correspondences:")
print(rmse)
print("")

# trans = result.transformation
# print("The estimated transformation matrix:")
# print(trans)
# print("Saving the transformation matrix in ransac_transformation_matrix.txt ...")
# np.savetxt('ransac_transformation_matrix.txt', trans)
# print("")

correspondences = result.correspondence_set
print("Correspondence Set:")
print(correspondences)
print("")

Fitness:
0.7057589962288091

RMSE of all inlier correspondences:
0.3158831683339917

Correspondence Set:
std::vector<Eigen::Vector2i> with 40049 elements.
Use numpy.asarray() to access data.



In [22]:
def registration_error(sour, targ):
    # # Make source and target of the same size
    # minimum_len = min(len(sour), len(targ))
    # source = sour[:minimum_len, :3]
    # target = sour[:minimum_len, :3]
    # # Apply transformation to point cloud
    # source_transformed = np.dot(transformation[:3, :3], source.T).T + transformation[:3, 3]
    # # Compute the difference between the transformed source and target point clouds
    # diff = np.subtract(target, source_transformed)
    # # RMSE of the difference
    # rmse = np.sqrt(np.mean(np.sum(diff ** 2, axis=1)))
    # # Compute the rotational error using quaternions
    # r = R.from_matrix(transformation)
    # q = r.as_quat()
    # q_target = R.from_matrix(np.identity(3)).as_quat()
    # rot_error = np.arccos(np.abs(np.dot(q, q_target))) * 180 / np.pi
    # # Compute the translational error
    # trans_error = np.linalg.norm(transformation - np.array([0, 0, 0]))
    # return rmse, rot_error, trans_error
    print('Calculating errors...')
    # Calculate the centroid of the source and target points
    source_centroid = np.mean(sour, axis=0)
    target_centroid = np.mean(targ, axis=0)
    print(f'Sour centroid: {source_centroid}')
    print(f'Targ centroid: {target_centroid}')

    # Calculate the covariance matrix of the source and target points
    source_covariance = np.cov(sour.T)
    target_covariance = np.cov(targ.T)

    # Calculate the singular value decomposition of the covariance matrices
    U_source, S_source, Vt_source = np.linalg.svd(source_covariance)
    U_target, S_target, Vt_target = np.linalg.svd(target_covariance)

    # Calculate the rotation matrix
    rot = Vt_target.T @ U_source.T

    # Calculate the translation vector
    transl = target_centroid - rot @ source_centroid
    print(f'Transl vector: {transl}')

    rot_err = rot - np.eye(3)
    # Mean Absolute error for each axis (row in rot_err)
    rot_mae_xyz = np.mean(np.abs(rot_err), axis=1)

    # Calculating translational error
    transl_xyz = np.divide(np.abs(transl), (np.abs(source_centroid)+np.abs(target_centroid)+np.abs(transl))/3)
    transl_xyz_mae = np.divide(transl_xyz, 100)
    # Calculate the mean squared error
    #mse = np.mean(np.sum((targ - (sour @ rot.T + transl)) ** 2, axis=1))

    return rot_mae_xyz, transl_xyz_mae


In [23]:
# We have pc1_down_ransac_points need pc2_down_croppped_points
pc2_down_croppped_points = np.asarray(pc2_down_croppped.points)

rot_err, transl_err = registration_error(pc1_down_ransac_points, pc2_down_croppped_points)
print(f'Rotational MAE error xyz: {rot_err}, Translational MAE error xyz: {transl_err}')
print(f'Rotational MAE: {np.mean(rot_err)}, Translational MAE: {np.mean(transl_err)}')
print("")

Calculating errors...
Sour centroid: [3.71370502e+05 7.96958317e+05 6.60483966e+01]
Targ centroid: [3.71370216e+05 7.96933154e+05 7.02526278e+01]
Transl vector: [-299824.58809461  229044.24017572    7707.697467  ]
Rotational MAE error xyz: [0.18740532 0.19354766 0.04054352], Translational MAE error xyz: [0.00862751 0.00376937 0.02947871]
Rotational MAE: 0.1404988334065751, Translational MAE: 0.013958528487968835



## TRANSFORMER

### Data Preparation

#### Downsampling again to fit the memory with transformer

In [26]:
# Preprocess point cloud data. One more downsampling
voxel_size = 0.8  # Adjust as needed
source_pc_down, source_fpfh = preprocess_point_cloud(pc1_down_ransac, voxel_size)
target_pc_down, target_fpfh = preprocess_point_cloud(pc2_down_croppped, voxel_size)


In [27]:
# Downsampling result
source_pc_down_points = np.asarray(source_pc_down.points)
target_pc_down_points = np.asarray(target_pc_down.points)
print(source_pc_down_points.shape)
print(target_pc_down_points.shape)

(8886, 3)
(16880, 3)


In [51]:
o3d.visualization.draw_geometries([source_pc_down])

#### Creating batches

In [28]:
# Define the desired number of points for each batch 
# batch_size = len(batch_sizes) = 8
batch_size = 54
batch_sizes = [320]*batch_size  # Adjust as needed
#batch_sizes[0] = 2048
print(sum(batch_sizes))
batch_sizes

17280


[320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320]

#### The function below creates for each point cloud batch non overlapping batch with fpfh parameters batch. When a batch is smaller then a batch_size the function adds the padding. At the end it makes the cuda points and fpfh tensors of floats for the PyTorch Transformer input.

In [36]:
# Non overlaping batches

def create_batches_with_padding(pcd, batch_sizes):
    num_batches = len(batch_sizes)
    batches_points = []
    batches_fpfh = []
    batch_start = 0
    points = np.asarray(pcd.points)

    for i in range(num_batches):
        batch_size = batch_sizes[i]
        print('batch_size', batch_size)

        # Initialize empty arrays for the current batch
        batch_points = []
        batch_fpfh = []
        
        # Cut the point cloud points to the size of the batch
        if (len(points)-batch_start)>0:
            batch_points = points[batch_start:(batch_start+batch_size)]
        
        # Calculate padding sizes
        pad_points = batch_size - len(batch_points)
        print('pad_points ', pad_points)

        # Pad point cloud and FPFH to match the batch size
        if len(batch_points)>0:
            batch_points = np.pad(batch_points, [(0, pad_points), (0, 0)], mode='constant')


            # FPFH for the points cut
            batch_point_cloud = o3d.geometry.PointCloud()
            batch_point_cloud.points = o3d.utility.Vector3dVector(batch_points)

            batch_point_cloud.estimate_normals(
                o3d.geometry.KDTreeSearchParamHybrid(
                    radius=voxel_size * 2.0, max_nn=30))
            fpfh = o3d.pipelines.registration.compute_fpfh_feature(
                batch_point_cloud, o3d.geometry.KDTreeSearchParamHybrid(
                    radius=voxel_size * 5.0, max_nn=100))



        # Convert the batch to PyTorch tensors
        batch_points = torch.FloatTensor(batch_points).cuda()#, dtype=torch.float32)
        #batch_fpfh = torch.tensor(fpfh, dtype=torch.float32)
        batch_fpfh = torch.FloatTensor(np.asarray(fpfh.data).copy()).T.cuda()

        batches_points.append(batch_points)
        batches_fpfh.append(batch_fpfh)
        batch_start += batch_size

    return batches_points, batches_fpfh

In [37]:
sour_batches_points, sour_batches_fpfh = create_batches_with_padding(source_pc_down, batch_sizes)


batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  74
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
b

In [38]:
print(len(sour_batches_points))
print(len(sour_batches_fpfh))
print(sour_batches_points[1].shape)
print(sour_batches_fpfh[1].shape)

54
54
torch.Size([320, 3])
torch.Size([320, 33])


In [40]:
targ_batches_points, targ_batches_fpfh = create_batches_with_padding(target_pc_down, batch_sizes)


batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320

In [41]:
print(len(targ_batches_points))
print(len(targ_batches_fpfh))
print(targ_batches_points[1].shape)
print(targ_batches_fpfh[1].shape)

54
54
torch.Size([320, 3])
torch.Size([320, 33])


#### Test

In [60]:
(targ_batches_points[52][300].cpu().detach().numpy() != [0,0,0]).all()

False

In [48]:
targ_batches_points[1].shape

torch.Size([320, 3])

In [82]:
target_test_points = []
target_test_pcd = o3d.geometry.PointCloud()

for i in range(len(targ_batches_points)):
    # Convert the aligned data to a NumPy array of shape (N, 3)
    target_batch = targ_batches_points[i].cpu().detach().numpy()  # Assuming 'aligned_source' is a PyTorch tensor
    target_batch_points = []
    if len(target_batch)>0:
        for point in target_batch:
            if (point != [0,0,0]).all():
                target_batch_points.append(point)
    
    if len(target_batch_points)>0:
        # Store aligned data
        target_test_points.append(target_batch_points)
        # Create an Open3D point cloud and assign the aligned data
        target_batch_pcd = o3d.geometry.PointCloud()
        target_batch_pcd.points = o3d.utility.Vector3dVector(target_batch_points)
        target_test_pcd += target_batch_pcd
        o3d.visualization.draw_geometries([target_batch_pcd])
    
o3d.visualization.draw_geometries([target_test_pcd])    

#### Simple PyTorch transformer model

In [67]:
from torch.nn import Transformer

# model = Transformer().cuda()
# src = torch.rand((10, 32, 512)).float().cuda()
# tgt = torch.rand((20, 32, 512)).float().cuda()

# print(model(src, tgt).shape)

aligned_pcd_points = []
results = []
aligned_point_cloud = o3d.geometry.PointCloud()

for i in range(len(targ_batches_points)):
    # Initialize the transformer model
    input_dim = len(targ_batches_points[i])*2 # Define your input dimension
    print(f'input_dim = {input_dim}')
    num_heads = 3   # Number of attention heads
    num_layers = 6  # Number of transformer layers
    d_model = 3   # Dimension of the embedding vectors
    print(f'd_model = {d_model}')
    hidden_dim = 256 # Hidden dimension in feed-forward layers
    
    target_points = targ_batches_points[i]
    target_fpfh = targ_batches_fpfh[i]
    print(f'Processing target batch {i} with {len(target_points)} points')

    model = Transformer(d_model=d_model, nhead=num_heads).cuda()
    source_points = sour_batches_points[i]
    source_fpfh = sour_batches_fpfh[i]
    print(f'Target and source batch {i} with {len(source_points)} points')

    if len(target_points)>0 and len(source_points)>0:
        result = model(source_points, target_points)
        results.append(result)
        # print(result)
        # Convert the aligned data to a NumPy array of shape (N, 3)
        aligned_data_numpy = result.cpu().detach().numpy()  # Assuming 'aligned_source' is a PyTorch tensor

        aligned_batch_points = []
        if len(aligned_data_numpy)>0:
            for point in aligned_data_numpy:
                if (point != [0,0,0]).all():
                    aligned_batch_points.append(point)

        if len(aligned_batch_points)>0:
            # Store aligned data
            aligned_pcd_points.append(aligned_batch_points)
            # Create an Open3D point cloud and assign the aligned data
            aligned_batch_pcd = o3d.geometry.PointCloud()
            aligned_batch_pcd.points = o3d.utility.Vector3dVector(aligned_batch_points)
            aligned_point_cloud += aligned_batch_pcd
            o3d.visualization.draw_geometries([aligned_batch_pcd])
            
        rot_err, transl_err = registration_error(aligned_data_numpy, targ_batches_points[i].cpu().detach().numpy())
        print(f'Rotational MAE error xyz: {rot_err}, \nTranslational MAE error xyz: {transl_err}')
        print(f'Rotational MAE: {np.mean(rot_err)}, \nTranslational MAE: {np.mean(transl_err)}')
        print("")
        

o3d.visualization.draw_geometries([aligned_point_cloud])

input_dim = 640
d_model = 3
Processing target batch 0 with 320 points
Target and source batch 0 with 320 points
Calculating errors...
Sour centroid: [-0.9369685   1.3400666  -0.40309793]
Targ centroid: [3.7136744e+05 7.9690088e+05 7.3881683e+01]
Transl vector: [3.71367250e+05 7.96899223e+05 7.41528393e+01]
Rotational MAE error xyz: [0.51383794 0.23968485 0.50780588], 
Translational MAE error xyz: [0.01499998 0.01499997 0.01498667]
Rotational MAE: 0.4204428872414474, 
Translational MAE: 0.014995538735124428

input_dim = 640
d_model = 3
Processing target batch 1 with 320 points
Target and source batch 1 with 320 points
Calculating errors...
Sour centroid: [-1.3848673   0.64962715  0.73524034]
Targ centroid: [3.7138044e+05 7.9690300e+05 7.3456146e+01]
Transl vector: [3.71381883e+05 7.96903884e+05 7.33679388e+01]
Rotational MAE error xyz: [0.34411054 0.57547547 0.45376091], 
Translational MAE error xyz: [0.015      0.015      0.01491629]
Rotational MAE: 0.4577823056251975, 
Translational M

Calculating errors...
Sour centroid: [-1.2606348   0.39213404  0.86850023]
Targ centroid: [3.713884e+05 7.969263e+05 7.143745e+01]
Transl vector: [3.71388298e+05 7.96924754e+05 7.16745844e+01]
Rotational MAE error xyz: [0.86452867 0.69193159 0.58800932], 
Translational MAE error xyz: [0.01499997 0.01499998 0.01493422]
Rotational MAE: 0.7148231943268838, 
Translational MAE: 0.014978059539020037

input_dim = 640
d_model = 3
Processing target batch 17 with 320 points
Target and source batch 17 with 320 points
Calculating errors...
Sour centroid: [ 0.36253163  0.51300716 -0.875539  ]
Targ centroid: [3.7137691e+05 7.9692844e+05 7.0720100e+01]
Transl vector: [3.71376783e+05 7.96927380e+05 7.08862834e+01]
Rotational MAE error xyz: [0.33886444 0.46812981 0.44133815], 
Translational MAE error xyz: [0.01499999 0.01499999 0.01492532]
Rotational MAE: 0.41611079840260734, 
Translational MAE: 0.014975098823191908

input_dim = 640
d_model = 3
Processing target batch 18 with 320 points
Target and sour

### Transformer 2 (with overlapping batches)

#### Overlaping batches with 20 points overlap

In [72]:
print(source_pc_down_points.shape)
print(target_pc_down_points.shape)

(8886, 3)
(16880, 3)


In [73]:
# Define the desired number of points for each batch 
# batch_size = len(batch_sizes) = 8
batch_size = 58
batch_sizes = [320]*batch_size  # Adjust as needed
#batch_sizes[0] = 2048
print(sum(batch_sizes)-20*batch_size)
batch_sizes

17400


[320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320,
 320]

In [85]:
# Overlaping batches

def create_overlapping_batches_with_padding(pcd, batch_sizes):
    num_batches = len(batch_sizes)
    batches_points = []
    batches_fpfh = []
    batch_start = 0
    points = np.asarray(pcd.points)

    for i in range(num_batches):
        batch_size = batch_sizes[i]
        print('batch_size', batch_size)

        # Initialize empty arrays for the current batch
        batch_points = []
        batch_fpfh = []
        
        # Cut the point cloud points to the size of the batch
        if (len(points)-batch_start)>0:
            batch_points = points[batch_start:(batch_start+batch_size)]
        
        # Calculate padding sizes
        pad_points = batch_size - len(batch_points)
        print('pad_points ', pad_points)

        # Pad point cloud and FPFH to match the batch size
        if len(batch_points)>0:
            batch_points = np.pad(batch_points, [(0, pad_points), (0, 0)], mode='constant')


            # FPFH for the points cut
            batch_point_cloud = o3d.geometry.PointCloud()
            batch_point_cloud.points = o3d.utility.Vector3dVector(batch_points)

            batch_point_cloud.estimate_normals(
                o3d.geometry.KDTreeSearchParamHybrid(
                    radius=voxel_size * 2.0, max_nn=30))
            fpfh = o3d.pipelines.registration.compute_fpfh_feature(
                batch_point_cloud, o3d.geometry.KDTreeSearchParamHybrid(
                    radius=voxel_size * 5.0, max_nn=100))



        # Convert the batch to PyTorch tensors
        batch_points = torch.FloatTensor(batch_points).cuda()#, dtype=torch.float32)
        #batch_fpfh = torch.tensor(fpfh, dtype=torch.float32)
        batch_fpfh = torch.FloatTensor(np.asarray(fpfh.data).copy()).T.cuda()

        batches_points.append(batch_points)
        batches_fpfh.append(batch_fpfh)
        batch_start += (batch_size - 20)

    return batches_points, batches_fpfh

In [86]:
sour_batches_points, sour_batches_fpfh = create_overlapping_batches_with_padding(source_pc_down, batch_sizes)


batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  134
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batch_size 320
pad_points  320
batc

In [87]:
print(len(sour_batches_points))
print(len(sour_batches_fpfh))
print(sour_batches_points[1].shape)
print(sour_batches_fpfh[1].shape)

58
58
torch.Size([320, 3])
torch.Size([320, 33])


In [88]:
targ_batches_points, targ_batches_fpfh = create_overlapping_batches_with_padding(target_pc_down, batch_sizes)


batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320
pad_points  0
batch_size 320

In [89]:
print(len(targ_batches_points))
print(len(targ_batches_fpfh))
print(targ_batches_points[1].shape)
print(targ_batches_fpfh[1].shape)

58
58
torch.Size([320, 3])
torch.Size([320, 33])


#### Transformer result loop

In [90]:
from torch.nn import Transformer

# model = Transformer().cuda()
# src = torch.rand((10, 32, 512)).float().cuda()
# tgt = torch.rand((20, 32, 512)).float().cuda()

# print(model(src, tgt).shape)

aligned_pcd_points = []
results = []
aligned_point_cloud = o3d.geometry.PointCloud()

for i in range(len(targ_batches_points)):
    # Initialize the transformer model
    input_dim = len(targ_batches_points[i])*2 # Define your input dimension
    print(f'input_dim = {input_dim}')
    num_heads = 3   # Number of attention heads
    num_layers = 6  # Number of transformer layers
    d_model = 3   # Dimension of the embedding vectors
    print(f'd_model = {d_model}')
    hidden_dim = 256 # Hidden dimension in feed-forward layers
    
    target_points = targ_batches_points[i]
    target_fpfh = targ_batches_fpfh[i]
    print(f'Processing target batch {i} with {len(target_points)} points')

    model = Transformer(d_model=d_model, nhead=num_heads).cuda()
    source_points = sour_batches_points[i]
    source_fpfh = sour_batches_fpfh[i]
    print(f'Target and source batch {i} with {len(source_points)} points')

    if len(target_points)>0 and len(source_points)>0:
        result = model(source_points, target_points)
        results.append(result)
        # print(result)
        # Convert the aligned data to a NumPy array of shape (N, 3)
        aligned_data_numpy = result.cpu().detach().numpy()  # Assuming 'aligned_source' is a PyTorch tensor

        aligned_batch_points = []
        if len(aligned_data_numpy)>0:
            for point in aligned_data_numpy:
                if (point != [0,0,0]).all():
                    aligned_batch_points.append(point)

        if len(aligned_batch_points)>0:
            # Store aligned data
            aligned_pcd_points.append(aligned_batch_points)
            # Create an Open3D point cloud and assign the aligned data
            aligned_batch_pcd = o3d.geometry.PointCloud()
            aligned_batch_pcd.points = o3d.utility.Vector3dVector(aligned_batch_points)
            aligned_point_cloud += aligned_batch_pcd
            o3d.visualization.draw_geometries([aligned_batch_pcd])
            
        rot_err, transl_err = registration_error(aligned_data_numpy, targ_batches_points[i].cpu().detach().numpy())
        print(f'Rotational MAE error xyz: {rot_err}, \nTranslational MAE error xyz: {transl_err}')
        print(f'Rotational MAE: {np.mean(rot_err)}, \nTranslational MAE: {np.mean(transl_err)}')
        print("")
        

o3d.visualization.draw_geometries([aligned_point_cloud])

input_dim = 640
d_model = 3
Processing target batch 0 with 320 points
Target and source batch 0 with 320 points
Calculating errors...
Sour centroid: [ 0.13274597 -1.0261794   0.8934337 ]
Targ centroid: [3.7136744e+05 7.9690088e+05 7.3881683e+01]
Transl vector: [3.71368604e+05 7.96901578e+05 7.37681830e+01]
Rotational MAE error xyz: [0.64927631 0.35285773 0.53253294], 
Translational MAE error xyz: [0.01500002 0.015      0.01489832]
Rotational MAE: 0.5115556604206045, 
Translational MAE: 0.014966112497844757

input_dim = 640
d_model = 3
Processing target batch 1 with 320 points
Target and source batch 1 with 320 points
Calculating errors...
Sour centroid: [-1.2145312  0.8861224  0.3284089]
Targ centroid: [3.7137991e+05 7.9690256e+05 7.3312202e+01]
Transl vector: [3.71378978e+05 7.96901343e+05 7.34488540e+01]
Rotational MAE error xyz: [0.87561226 0.73625275 0.58231752], 
Translational MAE error xyz: [0.01499996 0.01499998 0.01498044]
Rotational MAE: 0.731394176559853, 
Translational MAE: 

Calculating errors...
Sour centroid: [-0.07806135  1.1417259  -1.0636644 ]
Targ centroid: [3.7138312e+05 7.9692762e+05 7.0983849e+01]
Transl vector: [3.71382201e+05 7.96926382e+05 7.11906131e+01]
Rotational MAE error xyz: [0.42560011 0.23866861 0.49220298], 
Translational MAE error xyz: [0.01499998 0.01499998 0.01491026]
Rotational MAE: 0.3854905656437612, 
Translational MAE: 0.014970074106599299

input_dim = 640
d_model = 3
Processing target batch 17 with 320 points
Target and source batch 17 with 320 points
Calculating errors...
Sour centroid: [-0.98806936 -0.06574559  1.0538146 ]
Targ centroid: [3.7138906e+05 7.9692588e+05 7.1603203e+01]
Transl vector: [3.71389712e+05 7.96924599e+05 7.18078923e+01]
Rotational MAE error xyz: [0.83492381 0.62316414 0.57832595], 
Translational MAE error xyz: [0.01499999 0.01499999 0.01491183]
Rotational MAE: 0.6788046321000124, 
Translational MAE: 0.014970604708726956

input_dim = 640
d_model = 3
Processing target batch 18 with 320 points
Target and so

Target and source batch 40 with 0 points
input_dim = 640
d_model = 3
Processing target batch 41 with 320 points
Target and source batch 41 with 0 points
input_dim = 640
d_model = 3
Processing target batch 42 with 320 points
Target and source batch 42 with 0 points
input_dim = 640
d_model = 3
Processing target batch 43 with 320 points
Target and source batch 43 with 0 points
input_dim = 640
d_model = 3
Processing target batch 44 with 320 points
Target and source batch 44 with 0 points
input_dim = 640
d_model = 3
Processing target batch 45 with 320 points
Target and source batch 45 with 0 points
input_dim = 640
d_model = 3
Processing target batch 46 with 320 points
Target and source batch 46 with 0 points
input_dim = 640
d_model = 3
Processing target batch 47 with 320 points
Target and source batch 47 with 0 points
input_dim = 640
d_model = 3
Processing target batch 48 with 320 points
Target and source batch 48 with 0 points
input_dim = 640
d_model = 3
Processing target batch 49 with 320

### Transformer 3 self made (with overlapping batches)

Tutorial: https://deeplearning.neuromatch.io/tutorials/W2D5_AttentionAndTransformers/student/W2D5_Tutorial1.html#training-the-transformer

In [92]:
class DotProductAttention(nn.Module):
  """ Scaled dot product attention. """

  def __init__(self, dropout, **kwargs):
    """
    Constructs a Scaled Dot Product Attention Instance.

    Args:
      dropout: Integer
        Specifies probability of dropout hyperparameter

    Returns:
      Nothing
    """
    super(DotProductAttention, self).__init__(**kwargs)
    self.dropout = nn.Dropout(dropout)

  def calculate_score(self, queries, keys):
      """
      Compute the score between queries and keys.

      Args:
      queries: Tensor
        Query is your search tag/Question
        Shape of `queries`: (`batch_size`, no. of queries, head,`k`)
      keys: Tensor
        Descriptions associated with the database for instance
        Shape of `keys`: (`batch_size`, no. of key-value pairs, head, `k`)
      """
      return torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(queries.shape[-1])

  def forward(self, queries, keys, values, b, h, t, k):
    """
    Compute dot products. This is the same operation for each head,
    so we can fold the heads into the batch dimension and use torch.bmm
    Note: .contiguous() doesn't change the actual shape of the data,
    but it rearranges the tensor in memory, which will help speed up the computation
    for this batch matrix multiplication.
    .transpose() is used to change the shape of a tensor. It returns a new tensor
    that shares the data with the original tensor. It can only swap two dimensions.

    Args:
      queries: Tensor
        Query is your search tag/Question
        Shape of `queries`: (`batch_size`, no. of queries, head,`k`)
      keys: Tensor
        Descriptions associated with the database for instance
        Shape of `keys`: (`batch_size`, no. of key-value pairs, head, `k`)
      values: Tensor
        Values are returned results on the query
        Shape of `values`: (`batch_size`, head, no. of key-value pairs,  `k`)
      b: Integer
        Batch size
      h: Integer
        Number of heads
      t: Integer
        Number of keys/queries/values (for simplicity, let's assume they have the same sizes)
      k: Integer
        Embedding size

    Returns:
      out: Tensor
        Matrix Multiplication between the keys, queries and values.
    """
    keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
    queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
    values = values.transpose(1, 2).contiguous().view(b * h, t, k)

    # Matrix Multiplication between the keys and queries
    score = self.calculate_score(queries, keys)  # size: (b * h, t, t)
    softmax_weights = F.softmax(score, dim=2)  # row-wise normalization of weights

    # Matrix Multiplication between the output of the key and queries multiplication and values.
    out = torch.bmm(self.dropout(softmax_weights), values).view(b, h, t, k)  # rearrange h and t dims
    out = out.transpose(1, 2).contiguous().view(b, t, h * k)

    return out

In [93]:
class SelfAttention(nn.Module):
  """  Multi-head self attention layer. """

  def __init__(self, k, heads=8, dropout=0.1):
    """
    Initiates the following attributes:
    to_keys: Transforms input to k x k*heads key vectors
    to_queries: Transforms input to k x k*heads query vectors
    to_values: Transforms input to k x k*heads value vectors
    unify_heads: combines queries, keys and values to a single vector

    Args:
      k: Integer
        Size of attention embeddings
      heads: Integer
        Number of attention heads

    Returns:
      Nothing
    """
    super().__init__()
    self.k, self.heads = k, heads

    self.to_keys = nn.Linear(k, k * heads, bias=False)
    self.to_queries = nn.Linear(k, k * heads, bias=False)
    self.to_values = nn.Linear(k, k * heads, bias=False)
    self.unify_heads = nn.Linear(k * heads, k)

    self.attention = DotProductAttention(dropout)

  def forward(self, x):
    """
    Implements forward pass of self-attention layer

    Args:
      x: Tensor
        Batch x t x k sized input

    Returns:
      unify_heads: Tensor
        Self-attention based unified Query/Value/Key tensors
    """
    b, t, k = x.size()
    h = self.heads

    # We reshape the queries, keys and values so that each head has its own dimension
    queries = self.to_queries(x).view(b, t, h, k)
    keys = self.to_keys(x).view(b, t, h, k)
    values = self.to_values(x).view(b, t, h, k)

    out = self.attention(queries, keys, values, b, h, t, k)

    return self.unify_heads(out)

In practice PyTorch’s torch.nn.MultiheadAttention() function is used.

In [None]:
class Transformer(nn.Module):
  """ Transformer Encoder network for classification. """

  def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
    """
    Initiates the Transformer Network

    Args:
      k: Integer
        Attention embedding size
      heads: Integer
        Number of self attention heads
      depth: Integer
        Number of Transformer Blocks
      seq_length: Integer
        Length of input sequence
      num_tokens: Integer
        Size of dictionary
      num_classes: Integer
        Number of output classes

    Returns:
      Nothing
    """
    super().__init__()

    self.k = k
    self.num_tokens = num_tokens
    self.token_embedding = nn.Embedding(num_tokens, k)
    self.pos_enc = PositionalEncoding(k)

    transformer_blocks = []
    for i in range(depth):
      transformer_blocks.append(TransformerBlock(k=k, heads=heads))

    self.transformer_blocks = nn.Sequential(*transformer_blocks)
    self.classification_head = nn.Linear(k, num_classes)

  def forward(self, x):
    """
    Forward pass for Classification within Transformer network

    Args:
      x: Tensor
        (b, t) sized tensor of tokenized words

    Returns:
      logprobs: Tensor
        Log-probabilities over classes sized (b, c)
    """
    x = self.token_embedding(x) * np.sqrt(self.k)
    x = self.pos_enc(x)
    x = self.transformer_blocks(x)

    sequence_avg = x.mean(dim=1)
    x = self.classification_head(sequence_avg)
    logprobs = F.log_softmax(x, dim=1)
    return logprobs