In [22]:
import torch
from torch import nn
import ip_cuda


class DotFunction(torch.autograd.Function):
    """Autograd wrapper around the ``ip_cuda`` extension's forward/backward calls.

    Gradients are produced for ``input_left`` and ``input_right`` only; the
    ``metric`` argument always receives ``None`` from ``backward``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input_left, metric, input_right):
        """Call ``ip_cuda.forward`` and return the first element of its output.

        All three inputs are saved on ``ctx`` because ``backward`` passes them
        on to ``ip_cuda.backward``.
        """
        result = ip_cuda.forward(input_left, metric, input_right)[0]
        ctx.save_for_backward(input_left, metric, input_right)
        return result

    @staticmethod
    def backward(ctx, grad_h):
        """Return ``(grad_input_left, None, grad_input_right)``.

        ``grad_h`` gets a leading axis added and is made contiguous before it
        is handed to ``ip_cuda.backward`` together with the saved inputs.
        """
        saved_left, saved_metric, saved_right = ctx.saved_tensors
        upstream = grad_h[None].contiguous()
        grad_left, grad_right = ip_cuda.backward(
            upstream, saved_left, saved_metric, saved_right
        )
        # No gradient is computed for `metric`.
        return grad_left, None, grad_right


class Dot(torch.nn.Module):
    """Module holding two learnable vectors and contracting them with a matrix.

    ``input_left`` is initialised to ``[10, 11, ..., size + 9]`` and
    ``input_right`` to ``[0, 1, ..., size - 1]`` (both float32).
    """

    def __init__(self, size):
        super().__init__()
        ramp = torch.arange(size).float()
        self.input_left = nn.Parameter(ramp + 10)
        self.input_right = nn.Parameter(ramp)

    def forward(self, metric):
        """Return the scalar ``sum_ij input_left[i] * metric[i, j] * input_right[j]``."""
        # The custom kernel can be swapped in for comparison with:
        #     DotFunction.apply(self.input_left, metric, self.input_right)
        left, right = self.input_left, self.input_right
        return torch.einsum('i, ij, j ->', left, metric, right)

In [23]:
# Random 1024 x 1024 matrix on the GPU; passed as `metric` to Dot.forward in the benchmark cells.
metric = torch.randn(1024, 1024).cuda()

In [24]:
%%timeit -n 128 -r 128
# Benchmark: build a fresh Dot module on the GPU, then run one forward and one backward pass.
# NOTE(review): there is no torch.cuda.synchronize() in this cell, so the timing may
# not include all queued GPU work — confirm before comparing against other cells.
module = Dot(1024).cuda()
module(metric).sum().backward()

396 µs ± 6.13 µs per loop (mean ± std. dev. of 128 runs, 128 loops each)


In [18]:
# Sanity check: run one forward/backward pass and display the gradient stored on input_left.
module = Dot(1024).cuda()
module(metric).sum().backward()
module.input_left.grad

tensor([  0.0000,   1.0463,  -0.8177,  ..., -86.8748, 194.3297, 217.4943],
       device='cuda:0')

In [20]:
%%timeit -n 128 -r 128
# Same benchmark body as the other Dot timing cell (fresh module, one forward + backward).
# NOTE(review): presumably recorded with a different Dot.forward implementation than the
# other timing; execution counts are out of order, so verify which one produced this number.
module = Dot(1024).cuda()
module(metric).sum().backward()

269 µs ± 5.28 µs per loop (mean ± std. dev. of 128 runs, 128 loops each)


In [21]:
# Repeat of the gradient sanity check: one forward/backward pass, then display input_left.grad.
module = Dot(1024).cuda()
module(metric).sum().backward()
module.input_left.grad

tensor([  0.0000,   1.0463,  -0.8177,  ..., -86.8748, 194.3297, 217.4943],
       device='cuda:0')

In [25]:
from pykeops.torch import LazyTensor

In [42]:
# Random CPU tensors for the LazyTensor experiment:
# A is 4-D with shape (10, 100000, 1, 500); B is 2-D with shape (10000, 500).
A = torch.randn(10, 100000, 1, 500)
B = torch.randn(10000, 500)


In [43]:
# Wrap both tensors as LazyTensors after inserting singleton axes:
# A gains one axis before its last dimension, B gains two leading axes.
A_i = LazyTensor(A[:, :, None, :])
B_j = LazyTensor(B[None, None, :, :])

# Print the resulting symbolic shapes.
print(A_i.shape, B_j.shape)

(10, 100000, 1, 1, 500) (1, 1, 10000, 500)


In [44]:
# NOTE(review): this rebinds `LazyTensor` to the NumPy variant, shadowing the
# pykeops.torch import made earlier in the notebook.
from pykeops.numpy import LazyTensor
import numpy as np

A = np.random.randn(64, 8)
B = np.random.randn(1024, 32, 8)

# Insert singleton axes before wrapping the arrays as LazyTensors.
# NOTE(review): the recorded run of this cell fails with
# "Incompatible batch dimensions: (64,) and (1024,)" — the leading axes (64 vs 1024)
# do not match, so this layout does not broadcast the way it was intended to.
A_lazy = LazyTensor(A[:, None, :, None])  # indexed array has shape (64, 1, 8, 1)
B_lazy = LazyTensor(B[:, :, None, :])  # indexed array has shape (1024, 32, 1, 8)

# Intended computation: elementwise product followed by a sum over the last axis.
# NOTE(review): given the recorded ValueError, presumably this statement never completed — verify.
result = (A_lazy * B_lazy).sum(dim=-1)  # sum over the last dimension

# NOTE(review): not reached in the recorded run; confirm that `result` actually provides `.numpy()`.
result_np = result.numpy()  # intended conversion to a NumPy array


ValueError: Incompatible batch dimensions: (64,) and (1024,).

In [61]:
a = torch.randn(32, 1024, 8)
b = torch.randn(1, 64, 8)
# Compute the inner product
result = torch.matmul(a, 


RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 1

In [57]:
# Shape check: insert a singleton axis just before the last dimension of `a`.
a[..., None, :].shape

torch.Size([32, 1024, 1, 8])

In [55]:
# Shape check: append a trailing singleton axis to `b`.
# NOTE(review): the recorded output (32, 32, 8, 1) comes from a `b` that is not
# defined in any cell visible in this notebook (hidden kernel state).
b[..., None].shape

torch.Size([32, 32, 8, 1])

In [62]:
# Two sets of 8-dimensional vectors: 1024 rows in `a`, 64 rows in `b`.
a = torch.randn(1024, 8)
b = torch.randn(64, 8)

In [67]:
# Shape check: append a trailing singleton axis, (1024, 8) -> (1024, 8, 1).
a[:, :, None].shape

torch.Size([1024, 8, 1])

In [68]:
# All pairwise dot products between rows of `a` and rows of `b`; the result is (1024, 64).
a @ b.T

tensor([[-1.6063, -0.4657, -0.3143,  ..., -2.0547, -1.8664,  0.1886],
        [-2.0564, -1.6379, -0.2408,  ..., -2.5238, -1.0698,  1.7206],
        [ 4.3503,  3.2547,  1.9063,  ...,  6.3056,  0.8429, -1.9911],
        ...,
        [-7.5581, -7.3256,  2.4146,  ..., -0.4229,  6.6804,  0.5273],
        [ 2.8416, -1.0463,  0.2835,  ..., -2.8278,  1.5270,  6.4478],
        [ 4.2925,  0.2785,  1.7750,  ...,  2.6925,  3.6638,  7.7907]])

In [69]:
# Random rank-3 tensor of shape (8, 8, 8) used as the contraction table below.
M = torch.randn(8, 8, 8)

In [76]:
# Two-step matmul contraction: first contract `a` with axis 0 of M, then contract
# the last axis of that result with every row of `b`.
# out[n, j, m] = sum_{i,k} a[n, i] * M[i, j, k] * b[m, k], shape (1024, 8, 64).
((a @ M.view(8, 64)).view(1024 * 8, 8) @ b.T).view(1024, 8, 64)

tensor([[[ 1.3855e-01,  1.5969e+00, -2.4437e+00,  ..., -6.4465e-02,
          -8.2941e-01, -4.0274e+00],
         [-3.8261e+00, -3.8148e+00, -8.3644e-01,  ...,  1.2726e+00,
           6.9474e+00, -3.7049e+00],
         [-1.0028e+01, -9.0609e+00,  2.7482e+00,  ..., -7.5631e+00,
           4.4813e+00,  9.3730e-01],
         ...,
         [-3.8833e+00, -4.0569e+00,  5.0064e+00,  ..., -3.1980e+00,
           1.4440e+00,  5.3010e+00],
         [-3.5084e+00, -9.3427e-01,  3.5669e+00,  ..., -3.1311e+00,
          -2.9810e+00,  7.4572e-01],
         [-3.4574e+00,  2.6998e+00,  5.5797e+00,  ...,  6.4986e-01,
           2.7864e+00,  9.6783e-01]],

        [[-2.0395e+00,  2.2895e+00, -1.5884e+01,  ..., -2.9234e+00,
          -1.8282e+01, -1.7510e+01],
         [-2.0357e+01, -1.2736e+01, -2.2818e+00,  ..., -9.4535e+00,
           2.3496e+00, -1.7532e+01],
         [-5.8222e+00, -5.7164e+00,  1.7074e+00,  ..., -5.8486e+00,
           8.9439e-01, -2.8988e+00],
         ...,
         [-4.2642e+00, -6

In [100]:
# Inputs for the batched contraction: A and B both hold 1024 rows of 8 components,
# M is an (8, 8, 8) table.
A = torch.randn(1024, 8)
B = torch.randn(1024, 8)
M = torch.randn(8, 8, 8)

In [108]:
# Batched contraction: result1[b, j] = sum_{i,k} A[b, i] * M[i, j, k] * B[b, k].
result1 = torch.einsum('bi, ijk, bk -> bj', A, M, B)

In [109]:
# Shape check for the einsum result.
result1.shape

torch.Size([1024, 8])

In [106]:
# Matmul-based attempt at the same contraction. Each (row of A, j) pair is multiplied
# against every row of B, so the result is (1024 * 8, 1024) — all pairs of rows —
# rather than the per-row (1024, 8) result of the 'bi, ijk, bk -> bj' einsum.
result2 = ((A @ M.view(8, 64)).view(1024 * 8, 8) @ B.T)

result2.shape
# Leftover from the 64-row variant of this expression: .view(1024, 8, 64)

torch.Size([8192, 1024])

In [97]:
# Check that the matmul + view form of the first contraction step matches
# the einsum 'bi, ijk -> bjk'.
left = (A @ M.view(8, 64)).view(1024, 8, 8) 
torch.allclose(left, torch.einsum('bi, ijk -> bjk', A, M))

True

In [2]:

import torch


In [1]:
# NOTE(review): in the recorded run this raised NameError because `A` was not yet
# defined in that kernel session; the cell relies on state from other cells.
left = (A @ M.view(8, 64)).view(1024 * 8, 8)


NameError: name 'A' is not defined

In [88]:
# %%timeit -n 256 -r 256
# Matmul-based contraction on the GPU with a backward pass through A and B.
A = torch.randn(32, 32, 8, device='cuda', requires_grad=True)
B = torch.randn(32, 32, 8, device='cuda', requires_grad=True)
# M holds random 0/1 entries converted to float.
M = torch.randint(0, 2, (8, 8, 8), device='cuda').float()
# NOTE(review): B.view(-1, 8) has 32 * 32 = 1024 rows, so the product has
# 1024 * 8 * 1024 = 8388608 elements and cannot be viewed as (1024, 8, 512); this is
# the RuntimeError recorded for this cell. The hard-coded 512 matches a B of shape
# (16, 32, 8), as used in another cell; for this B the last view dimension would be 1024.
result1 = ((A.view(-1, 8) @ M.view(8, 64)).view(1024 * 8, 8) @ B.view(-1, 8).T).view(1024, 8, 512).transpose(1, 2)
loss = result1.sum()
loss.backward()
# Wait for queued GPU work to finish.
torch.cuda.synchronize()

RuntimeError: shape '[1024, 8, 512]' is invalid for input of size 8388608

In [1]:
import torch

def geometric_product(left, cayley, right):
    """Contract two batches of vectors with a rank-3 table.

    Computes ``out[..., j] = sum_{i,k} left[..., i] * cayley[i, j, k] * right[..., k]``,
    the same contraction as
    ``torch.einsum('...i, ijk, ...k -> ...j', left, cayley, right)``, using two
    matrix multiplications.

    Args:
        left: Tensor of shape ``(*batch, I)``. May be non-contiguous.
        cayley: Tensor of shape ``(I, J, K)``.
        right: Tensor with exactly the same shape as ``left`` (hence ``K == I``).

    Returns:
        Tensor of shape ``(*batch, J)``.

    Raises:
        AssertionError: if ``left`` and ``right`` have different shapes.
    """
    assert left.shape == right.shape
    shape = left.shape
    # reshape (not view) so transposed / sliced inputs are accepted as well.
    left = left.reshape(-1, shape[-1])
    right = right.reshape(-1, shape[-1])
    rows = left.shape[0]
    dim_j, dim_k = cayley.shape[1], cayley.shape[2]
    # Step 1: contract over i -> (rows, J * K), then unfold to (rows, J, K).
    partial = left.matmul(cayley.reshape(cayley.shape[0], -1)).reshape(rows, dim_j, dim_k)
    # Step 2: contract over k with the matching row of `right` -> (rows, J).
    result = partial.matmul(right[..., None]).squeeze(-1)
    # The output width is J (cayley.shape[1]), which need not equal the input width.
    return result.reshape(*shape[:-1], dim_j)

In [2]:
%%timeit -n 64 -r 64
# Benchmark: geometric_product forward + backward on (32, 32, 32, 32, 8) GPU inputs.
# Tensor creation is inside the timed body, so it is included in the measurement.
A = torch.randn(32, 32, 32, 32, 8, device='cuda', requires_grad=True)
B = torch.randn(32, 32, 32, 32, 8, device='cuda', requires_grad=True)
M = torch.randint(0, 2, (8, 8, 8), device='cuda').float()
result1 = geometric_product(A, M, B)
loss = result1.sum()
loss.backward()
# Wait for queued GPU work so the timing covers the whole pass.
torch.cuda.synchronize()


The slowest run took 5.03 times longer than the fastest. This could mean that an intermediate result is being cached.
6.92 ms ± 3.22 ms per loop (mean ± std. dev. of 64 runs, 64 loops each)


In [3]:
%%timeit -n 64 -r 64
# Benchmark: the same contraction written as a single einsum, forward + backward,
# on the same input shapes as the geometric_product timing cell.
A = torch.randn(32, 32, 32, 32, 8, device='cuda', requires_grad=True)
B = torch.randn(32, 32, 32, 32, 8, device='cuda', requires_grad=True)
M = torch.randint(0, 2, (8, 8, 8), device='cuda').float()
result2 = torch.einsum('b...i, ijk, b...k -> b...j', A, M, B)
loss = result2.sum()
loss.backward()
# Wait for queued GPU work so the timing covers the whole pass.
torch.cuda.synchronize()

6.55 ms ± 22.8 µs per loop (mean ± std. dev. of 64 runs, 64 loops each)


In [69]:
%%timeit -n 256 -r 256
# Benchmark: all-pairs contraction between the (32, 32) batch of A and the (16, 32)
# batch of B, written as one einsum, forward + backward.
A = torch.randn(32, 32, 8, device='cuda', requires_grad=True)
B = torch.randn(16, 32, 8, device='cuda', requires_grad=True)
M = torch.randint(0, 2, (8, 8, 8), device='cuda').float()
# NOTE(review): `result1` is not defined in this cell — the reshape depends on a
# value left in the kernel by another cell (hidden state).
result2 = torch.einsum('b...i, ijk, cak -> b...caj', A, M, B).view(result1.shape)
loss = result2.sum()
loss.backward()
# Wait for queued GPU work so the timing covers the whole pass.
torch.cuda.synchronize()

591 µs ± 6.81 µs per loop (mean ± std. dev. of 256 runs, 256 loops each)


In [66]:
# Compare the shapes of the matmul-based and einsum-based results.
result1.shape, result2.shape

(torch.Size([1024, 512, 8]), torch.Size([1024, 512, 8]))

In [61]:
# Elementwise difference between the two results; the displayed entries of the
# recorded output are all zero.
result1 - result2.view(result1.shape)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [None]:
# Random 1-D tensor with 3 elements.
x = torch.randn(3)
