# Landmark Registration with Affine (Linear + Translation) Transformations

## Test Data Generation

In [None]:
import numpy as np

# Seed NumPy's legacy global RNG so the notebook is reproducible; the noise
# drawn in the next cell comes from this same global stream.
np.random.seed(seed=42)

# Five reference landmarks in 3-D, sampled from a standard normal
# distribution (same stream as np.random.randn(5, 3)).
fixed_landmark = np.random.standard_normal(size=(5, 3))
fixed_landmark

In [None]:
from scipy.spatial.transform import Rotation

# Synthesize the "moving" set from the fixed one: scale about the centroid,
# rotate, shift back to the centroid, then add Gaussian jitter (sigma = 0.1).
# NOTE: the random draws below happen in a fixed order (noise first, then the
# rotation); Rotation.random() with no seed argument presumably also uses
# NumPy's global RNG (SciPy-version dependent) — reordering would change data.
center = fixed_landmark.mean(axis=0)
noise = 0.1 * np.random.randn(*fixed_landmark.shape)
rotation = Rotation.random().as_matrix()
# Per-axis scale factors 0.8 / 1 / 1.2, applied before the rotation.
scaling = np.diag([0.8, 1, 1.2])

# Row-vector convention: each point p maps to (p - center) S R + center + noise.
moving_landmark = (
    (fixed_landmark - center) @ scaling @ rotation
    + center
    + noise
)
moving_landmark

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px

def point_sets_to_dataframe(point_sets):
    """Stack labelled 3-D point sets into one long-format DataFrame.

    Parameters
    ----------
    point_sets : iterable of (array_like, str)
        Pairs ``(data, label)`` where each row of ``data`` is an ``(x, y, z)``
        point.

    Returns
    -------
    pandas.DataFrame
        Columns ``x``, ``y``, ``z``, ``label``; one row per point, in the order
        the sets were given, with a fresh ``0..n-1`` index. When
        ``point_sets`` is empty, an empty frame with those columns is returned.
    """
    frames = [
        pd.DataFrame(data, columns=['x', 'y', 'z']).assign(label=label)
        for data, label in point_sets
    ]
    if not frames:
        # pd.concat raises on an empty list; keep the schema so plotting
        # code can still find the expected columns.
        return pd.DataFrame(columns=['x', 'y', 'z', 'label'])
    # One concat instead of repeated DataFrame.append: append was removed in
    # pandas 2.0, and appending in a loop copies the frame every iteration.
    return pd.concat(frames, ignore_index=True)


def plot_point_sets_3d(point_sets):
    """Build an interactive 3-D scatter plot of several labelled point sets.

    Parameters
    ----------
    point_sets : iterable of (array_like, str)
        Pairs ``(data, label)``; each row of ``data`` is an ``(x, y, z)``
        point, and points are coloured by their ``label``.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure produced by ``plotly.express.scatter_3d``.
    """
    df = point_sets_to_dataframe(point_sets)
    return px.scatter_3d(df, x='x', y='y', z='z', color='label')

In [None]:
# Before registration: overlay the reference landmarks and the deformed copy.
plot_point_sets_3d(
    [
        (fixed_landmark, 'fixed landmark'),
        (moving_landmark, 'moving landmark'),
    ]
)

## Landmark Registration

In [None]:
import abakit.registration as reg

import torch
# Make float64 torch's global default floating dtype. Presumably this keeps
# tensors built from the float64 NumPy landmarks inside abakit at double
# precision — confirm against abakit's tensor construction.
torch.set_default_dtype(torch.float64)

# Set up a landmark registration; arguments are passed as (fixed, moving).
registration = reg.LandmarkRegistration(
    fixed_landmark,
    moving_landmark
)

# Affine transform model. It is kept in its own variable because later cells
# query it (transform_numpy, get_linear_matrix, get_translation) after run().
transform = reg.LandmarkAffineTransform()
registration.set_transform(transform)
# Similarity/loss between the landmark sets; presumably mean squared error,
# judging by the class name — confirm in abakit.
registration.set_similarity(reg.MSESimilarity())
# Only the optimizer *class* is supplied; presumably abakit instantiates it
# with its own defaults (learning rate, iterations) — TODO confirm.
registration.set_optimizer_class(torch.optim.Adam)

# Fit the transform parameters.
registration.run()

In [None]:
# Apply the fitted transform to the moving landmarks. If the registration
# converged, this should lie close to `fixed_landmark` (up to the synthetic
# noise).
moving_landmark_transformed = transform.transform_numpy(moving_landmark)
moving_landmark_transformed

In [None]:
# After registration: the transformed moving set should overlay the reference.
plot_point_sets_3d(
    [
        (fixed_landmark, 'fixed landmark'),
        (moving_landmark_transformed, 'moving landmark transformed'),
    ]
)

## Transformation Parameters

In [None]:
# Linear part of the fitted affine transform. It is used later as
# `x @ a_mat + t`, i.e. it multiplies row-vector points from the right
# (shape presumably 3x3 — confirm).
a_mat = transform.get_linear_matrix()
a_mat

In [None]:
# Translation part of the fitted affine transform, added after the linear map.
t = transform.get_translation()
t

In [None]:
# Re-apply the transform by hand from its parameters (row-vector convention
# x @ A + t). This is expected to reproduce `moving_landmark_transformed`.
# NOTE(review): consider an explicit np.allclose check once the return types
# of get_linear_matrix/get_translation are confirmed.
moving_landmark @ a_mat + t

In [None]:
# Output of `transform.transform_numpy`, shown for side-by-side comparison
# with the manual reconstruction in the previous cell.
moving_landmark_transformed