## linear operator base

In [1]:
from scipy.sparse import csr_matrix
offsets = csr_matrix([[1, 0, 2], [0, -1, 0], [0, 0, 3]])
print(offsets)

<Compressed Sparse Row sparse matrix of dtype 'int64'
	with 4 stored elements and shape (3, 3)>
  Coords	Values
  (0, 0)	1
  (0, 2)	2
  (1, 1)	-1
  (2, 2)	3


In [None]:
import numpy as np
from scipy.sparse.linalg import LinearOperator
def mv(v):
    return np.array([2*v[0], 3*v[1]])

A = LinearOperator((2,2), matvec=mv)

print(A.matvec(np.ones(2)))

print(A*np.ones(2))

print(A@np.ones(2))

[2. 3.]
[2. 3.]
[2. 3.]


In [None]:
from functools import reduce
product = lambda c: reduce(lambda a, b: a * b, c)

## scalars_nn

In [None]:
import torch
x=torch.rand((2,3,4))

print(x)

scalars = torch.einsum('bix,bjx->bij', x, x)
print(scalars)



tensor([[[0.1739, 0.5695, 0.2003, 0.4404],
         [0.5501, 0.7865, 0.0176, 0.8035],
         [0.9046, 0.3442, 0.2906, 0.8839]],

        [[0.2787, 0.4166, 0.2794, 0.6750],
         [0.9659, 0.0798, 0.3630, 0.0471],
         [0.2742, 0.3525, 0.6014, 0.9480]]])
tensor([[[0.5886, 0.9009, 0.8008],
         [0.9009, 1.5671, 1.4837],
         [0.8008, 1.4837, 1.8026]],

        [[0.7848, 0.4356, 1.0311],
         [0.4356, 1.0733, 0.5559],
         [1.0311, 0.5559, 1.4597]]])


In [None]:
B1=x[0,:,:]
B2=x[1,:,:]
M1=B1@B1.T
M2=B2@B2.T
print(M1)
print(M2)

B=torch.stack((M1,M2))

tensor([[0.5886, 0.9009, 0.8008],
        [0.9009, 1.5671, 1.4837],
        [0.8008, 1.4837, 1.8026]])
tensor([[0.7848, 0.4356, 1.0311],
        [0.4356, 1.0733, 0.5559],
        [1.0311, 0.5559, 1.4597]])
tensor([[[0.5886, 0.9009, 0.8008],
         [0.9009, 1.5671, 1.4837],
         [0.8008, 1.4837, 1.8026]],

        [[0.7848, 0.4356, 1.0311],
         [0.4356, 1.0733, 0.5559],
         [1.0311, 0.5559, 1.4597]]])


In [None]:
G = torch.diag(-torch.ones(4))
G[0,0] = 1
print(G.unsqueeze(0))
print(x)
G = torch.einsum('bix,bxj->bij', x, G.unsqueeze(0))
# G = torch.einsum('cix,cxj->cij', x, G.unsqueeze(0))
print(G)

print()
scalars = torch.einsum('bij,bkj->bik', G, x)
print(scalars)

# simplified version: since the matrix is symmetric, take the upper trianglar part, flatten it
scalars = torch.triu(scalars).view(-1, 3**2)
# print(scalars)
# print(torch.nonzero(scalars[0]))
scalars = scalars[:, torch.nonzero(scalars[0]).squeeze(-1)]
print(scalars)

tensor([[[ 1.,  0.,  0.,  0.],
         [ 0., -1.,  0.,  0.],
         [ 0.,  0., -1.,  0.],
         [ 0.,  0.,  0., -1.]]])
tensor([[[0.1739, 0.5695, 0.2003, 0.4404],
         [0.5501, 0.7865, 0.0176, 0.8035],
         [0.9046, 0.3442, 0.2906, 0.8839]],

        [[0.2787, 0.4166, 0.2794, 0.6750],
         [0.9659, 0.0798, 0.3630, 0.0471],
         [0.2742, 0.3525, 0.6014, 0.9480]]])
tensor([[[ 0.1739, -0.5695, -0.2003, -0.4404],
         [ 0.5501, -0.7865, -0.0176, -0.8035],
         [ 0.9046, -0.3442, -0.2906, -0.8839]],

        [[ 0.2787, -0.4166, -0.2794, -0.6750],
         [ 0.9659, -0.0798, -0.3630, -0.0471],
         [ 0.2742, -0.3525, -0.6014, -0.9480]]])

tensor([[[-0.5281, -0.7096, -0.4862],
         [-0.7096, -0.9619, -0.4884],
         [-0.4862, -0.4884, -0.1660]],

        [[-0.6295,  0.1027, -0.8783],
         [ 0.1027,  0.7926, -0.0263],
         [-0.8783, -0.0263, -1.3094]]])
tensor([[-0.5281, -0.7096, -0.4862,  0.0000, -0.9619, -0.4884,  0.0000,  0.0000,
         -0.

In [26]:
G = torch.diag(-torch.ones(4))
G[0,0] = 1

M1=x[0,:,:]@G@x[0,:,:].T
M2=x[1,:,:]@G@x[1,:,:].T

print(torch.stack((M1,M2)))

tensor([[[-0.5281, -0.7096, -0.4862],
         [-0.7096, -0.9619, -0.4884],
         [-0.4862, -0.4884, -0.1660]],

        [[-0.6295,  0.1027, -0.8783],
         [ 0.1027,  0.7926, -0.0263],
         [-0.8783, -0.0263, -1.3094]]])


In [34]:
N=x.shape[0]
scalars = torch.einsum('bik,bjl->bijkl', x, x) #[N, n, n, dim, dim]
print(scalars.shape)
print(scalars)

torch.Size([2, 3, 3, 4, 4])
tensor([[[[[3.0230e-02, 9.9015e-02, 3.4818e-02, 7.6565e-02],
           [9.9015e-02, 3.2432e-01, 1.1404e-01, 2.5078e-01],
           [3.4818e-02, 1.1404e-01, 4.0102e-02, 8.8185e-02],
           [7.6565e-02, 2.5078e-01, 8.8185e-02, 1.9392e-01]],

          [[9.5643e-02, 1.3675e-01, 3.0546e-03, 1.3970e-01],
           [3.1327e-01, 4.4791e-01, 1.0005e-02, 4.5757e-01],
           [1.1016e-01, 1.5750e-01, 3.5182e-03, 1.6090e-01],
           [2.4224e-01, 3.4635e-01, 7.7365e-03, 3.5382e-01]],

          [[1.5728e-01, 5.9850e-02, 5.0527e-02, 1.5369e-01],
           [5.1516e-01, 1.9604e-01, 1.6550e-01, 5.0339e-01],
           [1.8115e-01, 6.8934e-02, 5.8196e-02, 1.7701e-01],
           [3.9835e-01, 1.5159e-01, 1.2797e-01, 3.8925e-01]]],


         [[[9.5643e-02, 3.1327e-01, 1.1016e-01, 2.4224e-01],
           [1.3675e-01, 4.4791e-01, 1.5750e-01, 3.4635e-01],
           [3.0546e-03, 1.0005e-02, 3.5182e-03, 7.7365e-03],
           [1.3970e-01, 4.5757e-01, 1.6090e-01, 3