# Developing an Implementation of Spherical Interpolation

In [4]:
import torch
import torch.nn as nn
import plotly.express as px

In [115]:
class HypersphereLERP(nn.Module):
    """Approximate spherical interpolation (SLERP) with normalized linear interpolation (LERP).

    SLERP(a, b; alpha) = sin((1-alpha) * theta) / sin(theta) * a + sin(alpha * theta) / sin(theta) * b,
    where theta = arccos(<a, b>) is the angle between a and b and alpha is the interpolation weight in [0, 1].
    Following the nGPT paper, this is approximated by
    LERP(a, b; alpha) = a + alpha * (b - a) = (1 - alpha) * a + alpha * b,
    whose result is then projected back onto the unit hypersphere.
    In the limit theta -> 0, SLERP(a, b; alpha) -> LERP(a, b; alpha).

    The learnable per-dimension weight ``lerp_weight`` is stored at magnitude ``lerp_scale``.
    In ``forward`` it is multiplied by ``lerp_init / lerp_scale``, so the effective
    (pre-constraint) weight starts at ``lerp_init``.

    Args:
        dim: size of the last (feature) dimension of the inputs.
        lerp_scale: magnitude at which ``lerp_weight`` is stored; defaults to ``sqrt(dim)``.
        lerp_init: initial effective interpolation weight, before the constraint is applied.
        lerp_weight_constraint: how the effective weight is constrained before use:
            'none'    -- unconstrained (as in the nGPT paper);
            'sigmoid' -- mapped to (0, 1);
            'abs'     -- made non-negative (it can still exceed 1);
            'clamp'   -- clipped to [0, 1].

    Raises:
        ValueError: if ``lerp_weight_constraint`` is not one of the names above.
    """

    def __init__(self, dim, lerp_scale=None, lerp_init=1.0, lerp_weight_constraint='none'):
        super().__init__()

        self.dim = dim
        self.lerp_init = lerp_init
        self.lerp_scale = lerp_scale if lerp_scale is not None else self.dim ** 0.5
        self.lerp_weight = nn.Parameter(torch.ones(self.dim) * self.lerp_scale, requires_grad=True)
        self.forward_lerp_weight_scale = self.lerp_init / self.lerp_scale

        # NOTE: the nGPT paper does not constrain the interpolation weight alpha. Without a constraint
        # the update is not a convex combination of x and y (alpha may leave [0, 1]), which makes it
        # harder to interpret. The other options bound alpha to varying degrees (see class docstring).
        constraint_fns = {
            'none': type(self)._constrain_none,
            'sigmoid': type(self)._constrain_sigmoid,
            'abs': type(self)._constrain_abs,
            'clamp': type(self)._constrain_clamp,
        }
        # Raise explicitly rather than `assert`. Asserts are stripped under `python -O`, which would
        # defer the failure to `forward` with an obscure "NoneType is not callable" error.
        if lerp_weight_constraint not in constraint_fns:
            raise ValueError(
                f"lerp_weight_constraint must be one of {list(constraint_fns)}, got {lerp_weight_constraint!r}"
            )
        self.lerp_weight_constraint = lerp_weight_constraint
        # Named class functions (unlike lambdas) are picklable by qualified name,
        # so pickle.dumps(module) / torch.save(module) work.
        self.lerp_weight_constraint_fn = constraint_fns[lerp_weight_constraint]

    @staticmethod
    def _constrain_none(w):
        """Identity: leave the interpolation weight unconstrained."""
        return w

    @staticmethod
    def _constrain_sigmoid(w):
        """Map the interpolation weight to (0, 1)."""
        return w.sigmoid()

    @staticmethod
    def _constrain_abs(w):
        """Make the interpolation weight non-negative (no upper bound)."""
        return torch.abs(w)

    @staticmethod
    def _constrain_clamp(w):
        """Clip the interpolation weight to [0, 1]."""
        return w.clamp(0, 1)

    def forward(self, x, y):
        """Interpolate from ``x`` toward ``y`` and re-project onto the unit hypersphere.

        Both inputs are first L2-normalized along the last dimension. Typically (e.g. in a ResNet
        with this residual-stream method) ``x`` is already unit-norm. The constrained weight has
        shape [dim] and broadcasts against the last dimension of the inputs.

        Returns:
            normalize(x + alpha * (y - x)) along the last dimension.
        """
        x, y = torch.nn.functional.normalize(x, p=2, dim=-1), torch.nn.functional.normalize(y, p=2, dim=-1)

        interpolation_weight = self.lerp_weight_constraint_fn(self.lerp_weight * self.forward_lerp_weight_scale)
        x = x + interpolation_weight * (y - x)
        x = torch.nn.functional.normalize(x, p=2, dim=-1)

        return x

class HypersphereSLERP(nn.Module):
    """Spherical linear interpolation (SLERP) between two vectors on the unit-norm hypersphere.

    SLERP(a, b; alpha) = sin((1-alpha) * theta) / sin(theta) * a + sin(alpha * theta) / sin(theta) * b,
    where theta = arccos(<a, b>) is the angle between a and b.

    Unlike HypersphereLERP, this does not use a linear approximation. The learnable weight is
    passed through a sigmoid, so alpha always lies in (0, 1).

    Args:
        dim: size of the last (feature) dimension of the inputs.
        single_weight: if True (default), one scalar alpha is shared by all dimensions. This is exact
            SLERP. If False, each dimension gets its own alpha. The result is then no longer a
            spherical interpolation, so it is re-normalized to unit norm.
    """

    def __init__(self, dim, single_weight=True):
        super().__init__()

        self.dim = dim
        self.single_weight = single_weight

        # Raw (pre-sigmoid) weight: shape [1] if single_weight, else [dim].
        # Initialized to ones, so alpha = sigmoid(1) ~= 0.73 at init.
        self.slerp_weight = nn.Parameter(torch.ones(1) if single_weight else torch.ones(self.dim), requires_grad=True)
        # With single_weight=False every coordinate is mixed with its own alpha, so the output generally
        # leaves the great-circle arc between x and y. Its geometric interpretation is an open question.

    def forward(self, x, y):
        """Return SLERP(x, y; alpha) computed along the last dimension.

        Args:
            x, y: tensors of shape [batch_size, ..., dim]; both are L2-normalized along the last
                dimension first.

        Returns:
            Tensor of shape [batch_size, ..., dim], interpolated between x and y. For
            single_weight=True its norm is ~1 (it can drift for nearly antipodal inputs, where SLERP
            is ill-defined). Otherwise it is explicitly re-normalized.
        """
        x, y = torch.nn.functional.normalize(x, p=2, dim=-1), torch.nn.functional.normalize(y, p=2, dim=-1)
        cos_theta = (x * y).sum(dim=-1, keepdim=True)  # shape: [batch_size, ..., 1]

        # Numerical guard. Three failure modes exist at |cos_theta| = 1:
        # - rounding can push cos_theta just outside [-1, 1], where acos returns NaN;
        # - at exactly +-1, sin(theta) = 0, so the coefficients below are 0 / 0 = NaN;
        # - acos has an infinite derivative there, so gradients become NaN.
        # Clamping strictly inside (-1, 1) keeps the forward and backward passes finite.
        # For (nearly) identical inputs the coefficients tend to (1 - alpha, alpha), i.e. the result is ~x.
        eps = torch.finfo(cos_theta.dtype).eps
        cos_theta = cos_theta.clamp(-1.0 + eps, 1.0 - eps)
        theta = torch.acos(cos_theta)  # shape: [batch_size, ..., 1]
        sin_theta = torch.sin(theta)  # shape: [batch_size, ..., 1], strictly positive after the clamp

        # sigmoid maps the interpolation weight to (0, 1)
        alpha = self.slerp_weight.sigmoid()  # shape: [1] or [dim]

        x = torch.sin((1 - alpha) * theta) / sin_theta * x + torch.sin(alpha * theta) / sin_theta * y
        # shape: [batch_size, ..., dim], each vector interpolated between x and y

        # With per-dimension weights this is not strictly spherical interpolation and need not preserve unit norm
        if not self.single_weight:
            x = torch.nn.functional.normalize(x, p=2, dim=-1)

        return x

class AdaptiveHypersphereSLERP(nn.Module):
    """Spherical linear interpolation (SLERP) with an instance-dependent interpolation weight.

    SLERP(a, b; alpha) = sin((1-alpha) * theta) / sin(theta) * a + sin(alpha * theta) / sin(theta) * b,
    where theta = arccos(<a, b>) is the angle between a and b.

    Here alpha = sigmoid(Linear(y)) is computed per instance from the (normalized) second input. This
    makes the interpolation adaptive and acts as a natural gating mechanism. Unlike HypersphereLERP,
    no linear approximation is used, and the sigmoid keeps alpha in (0, 1).

    Args:
        dim: size of the last (feature) dimension of the inputs.
        single_weight: if True (default), the linear map produces one alpha per instance, shared
            by all dimensions (exact SLERP). If False, it produces one alpha per dimension. The
            result is then no longer a spherical interpolation, so it is re-normalized to unit norm.
    """

    def __init__(self, dim, single_weight=True):
        super().__init__()

        self.dim = dim
        self.single_weight = single_weight

        # Linear map from y to the raw (pre-sigmoid) interpolation weight:
        # one output if single_weight, otherwise one output per dimension.
        self.slerp_weight_map = nn.Linear(dim, 1) if single_weight else nn.Linear(dim, dim)

    def forward(self, x, y):
        """Return SLERP(x, y; alpha(y)) computed along the last dimension.

        Args:
            x, y: tensors of shape [batch_size, ..., dim]; both are L2-normalized along the last
                dimension first.

        Returns:
            Tensor of shape [batch_size, ..., dim], interpolated between x and y. For
            single_weight=True its norm is ~1 (it can drift for nearly antipodal inputs, where SLERP
            is ill-defined). Otherwise it is explicitly re-normalized.
        """
        x, y = torch.nn.functional.normalize(x, p=2, dim=-1), torch.nn.functional.normalize(y, p=2, dim=-1)
        cos_theta = (x * y).sum(dim=-1, keepdim=True)  # shape: [batch_size, ..., 1]

        # Numerical guard. Three failure modes exist at |cos_theta| = 1:
        # - rounding can push cos_theta just outside [-1, 1], where acos returns NaN;
        # - at exactly +-1, sin(theta) = 0, so the coefficients below are 0 / 0 = NaN;
        # - acos has an infinite derivative there, so gradients become NaN.
        # Clamping strictly inside (-1, 1) keeps the forward and backward passes finite.
        # For (nearly) identical inputs the coefficients tend to (1 - alpha, alpha), i.e. the result is ~x.
        eps = torch.finfo(cos_theta.dtype).eps
        cos_theta = cos_theta.clamp(-1.0 + eps, 1.0 - eps)
        theta = torch.acos(cos_theta)  # shape: [batch_size, ..., 1]
        sin_theta = torch.sin(theta)  # shape: [batch_size, ..., 1], strictly positive after the clamp

        # sigmoid maps the instance-dependent interpolation weight to (0, 1)
        alpha = self.slerp_weight_map(y).sigmoid()  # shape: [batch_size, ..., 1] or [batch_size, ..., dim]

        x = torch.sin((1 - alpha) * theta) / sin_theta * x + torch.sin(alpha * theta) / sin_theta * y
        # shape: [batch_size, ..., dim], each vector interpolated between x and y

        # With per-dimension weights this is not strictly spherical interpolation and need not preserve unit norm
        if not self.single_weight:
            x = torch.nn.functional.normalize(x, p=2, dim=-1)

        return x

    # IDEA: alpha currently depends only on y; it could also be made a learnable function of x (or of both).

In [100]:
# dimension of the toy vectors (3, so they can be drawn on the unit sphere)
dim = 3

# two random, not-yet-normalized vectors with a leading batch axis of size 1 (x is drawn before y)
x, y = torch.randn(1, dim), torch.randn(1, dim)

In [101]:
import numpy as np
import plotly.graph_objects as go


def draw_sphere(fig, opacity=0.5):
    """Add a translucent light-blue unit-sphere mesh to the plotly figure ``fig`` (in place).

    The sphere is sampled on a regular (polar, azimuth) grid with step pi/32. alphahull=0 asks
    plotly to triangulate the sampled points as their convex hull.
    """
    step = np.pi / 32

    # polar angle from 0 to pi (inclusive), azimuth from 0 up to (excluding) 2*pi
    polar, azimuth = np.mgrid[0:np.pi + step:step, 0:2 * np.pi:step]

    # spherical -> Cartesian, flattened to 1-D point lists for Mesh3d
    ring_radius = np.sin(polar)
    xs = (ring_radius * np.cos(azimuth)).ravel()
    ys = (ring_radius * np.sin(azimuth)).ravel()
    zs = np.cos(polar).ravel()

    sphere = go.Mesh3d(x=xs, y=ys, z=zs, color='lightblue', opacity=opacity, alphahull=0)
    fig.add_trace(sphere)


In [102]:
import plotly.graph_objects as go
import numpy as np

def plot_vectors_unit_sphere(vecs, names=None):
    """Draw 3-D vectors as segments from the origin inside a translucent unit sphere and show the figure.

    Args:
        vecs: sequence of 3-component vectors (each indexable as v[0], v[1], v[2]).
        names: optional legend labels, one per vector; defaults to 'Vector 1', 'Vector 2', ....
            Vectors and names are paired with zip, so any surplus in the longer sequence is ignored.
    """
    if names is None:
        names = [f'Vector {k}' for k in range(1, len(vecs) + 1)]

    fig = go.Figure()

    for v, label in zip(vecs, names):
        # one line segment from the origin to the tip of the vector
        segment = go.Scatter3d(
            x=[0, v[0]],
            y=[0, v[1]],
            z=[0, v[2]],
            mode='lines+markers',
            marker=dict(size=4),
            line=dict(width=5),
            name=label,
        )
        fig.add_trace(segment)

    draw_sphere(fig, opacity=0.5)

    # fixed [-1, 1] cube so the unit sphere is not distorted
    scene = dict(
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1]),
        aspectmode='cube',
    )
    fig.update_layout(scene=scene, title='Vectors on the unit sphere')

    fig.show()

In [103]:
# sanity check of the plotting helper: two orthogonal unit vectors along the first two axes
x, y = [1, 0, 0], [0, 1, 0]

plot_vectors_unit_sphere([x, y], names=['x', 'y'])

In [108]:
# fresh random unit-norm vectors (batch of 1) used by the interpolation demos below; x is drawn before y
x, y = (torch.nn.functional.normalize(torch.randn(1, dim), p=2, dim=-1) for _ in range(2))

## Hypersphere LERP

Computes the linear interpolation between two vectors, then projects the result back onto the unit sphere: $\mathrm{Normalize}(x + \alpha (y - x))$.

In [109]:
# constraint 'none': the effective alpha equals lerp_init = 1.0 at init, so lerp_interp lands on (normalized) y
hyperspherelerp = HypersphereLERP(
    dim=dim, lerp_scale=1.0, lerp_init=1.0, lerp_weight_constraint='none'
)
lerp_interp = hyperspherelerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), lerp_interp[0].detach().numpy()],
    names=['x', 'y', 'lerp_interp'],
)

In [110]:
# constraint 'sigmoid': alpha = sigmoid(1.0) ~= 0.73 at init, so lerp_interp lies between x and y, nearer y
hyperspherelerp = HypersphereLERP(
    dim=dim, lerp_scale=1.0, lerp_init=1.0, lerp_weight_constraint='sigmoid'
)
lerp_interp = hyperspherelerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), lerp_interp[0].detach().numpy()],
    names=['x', 'y', 'lerp_interp'],
)

In [111]:
# constraint 'abs': alpha = |1.0| = 1 at init, so (as with 'none') lerp_interp lands on (normalized) y
hyperspherelerp = HypersphereLERP(
    dim=dim, lerp_scale=1.0, lerp_init=1.0, lerp_weight_constraint='abs'
)
lerp_interp = hyperspherelerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), lerp_interp[0].detach().numpy()],
    names=['x', 'y', 'lerp_interp'],
)

## Hypersphere SLERP

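Returns $\mathrm{SLERP}(x, y; \alpha) = \frac{\sin((1-\alpha) \theta_{xy})}{\sin(\theta_{xy})} x + \frac{\sin(\alpha \theta_{xy})}{\sin(\theta_{xy})} y$, where $\theta_{xy} = \arccos(\langle x, y \rangle)$ is the angle between $x$ and $y$, and $\alpha \in (0,1)$ is the interpolation parameter (the sigmoid of a learned weight).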

In [None]:
# single shared weight: alpha = sigmoid(1) ~= 0.73 at init, i.e. ~73% of the way along the arc from x to y
hypersphereslerp = HypersphereSLERP(dim=dim, single_weight=True)
slerp_interp = hypersphereslerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), slerp_interp[0].detach().numpy()],
    names=['x', 'y', 'slerp_interp'],
)

In [113]:
# per-dimension weights: all start at the same raw value (ones), so at init this matches the single-weight result
hypersphereslerp = HypersphereSLERP(dim=dim, single_weight=False)
slerp_interp = hypersphereslerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), slerp_interp[0].detach().numpy()],
    names=['x', 'y', 'slerp_interp'],
)

In [114]:
# per-dimension weights with random raw values: each coordinate now uses a different alpha,
# so the (re-normalized) result generally leaves the arc between x and y
hypersphereslerp = HypersphereSLERP(dim=dim, single_weight=False)
hypersphereslerp.slerp_weight.data = torch.randn(dim)  # overwrite the ones-initialization

slerp_interp = hypersphereslerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), slerp_interp[0].detach().numpy()],
    names=['x', 'y', 'slerp_interp'],
)

## AdaptiveHypersphereSLERP

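The interpolation weight $\alpha$ is instance-dependent: $\alpha = \mathrm{sigmoid}(W_\alpha \mathbf{y} + b_\alpha)$, where $\mathbf{y}$ is the normalized second input and $W_\alpha, b_\alpha$ are the parameters of a learned linear map (one output for a shared weight, or one output per dimension).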

In [116]:
# adaptive, single weight: alpha comes from a (randomly initialized) linear map of y
hypersphereslerp = AdaptiveHypersphereSLERP(dim=dim, single_weight=True)
slerp_interp = hypersphereslerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), slerp_interp[0].detach().numpy()],
    names=['x', 'y', 'slerp_interp'],
)

In [117]:
# adaptive, per-dimension weights: one alpha per coordinate, each a (random-init) linear function of y
hypersphereslerp = AdaptiveHypersphereSLERP(dim=dim, single_weight=False)
slerp_interp = hypersphereslerp(x, y)

plot_vectors_unit_sphere(
    [x[0].numpy(), y[0].numpy(), slerp_interp[0].detach().numpy()],
    names=['x', 'y', 'slerp_interp'],
)