In [1]:
import torch
import mytorch
import mytorch.math as math
from matplotlib import pyplot as plt

import torchgroup

In [2]:
class LieGroup(mytorch.autograd.Function):
    """Base class mapping Lie-algebra coordinates ``w`` to a matrix group element.

    Subclasses provide three hooks:

    * ``algebra(w)``    -- the matrix that is exponentiated in the forward pass,
    * ``adjoint(w)``    -- the matrix fed to the Pade/scaling-and-squaring
      evaluation in the backward pass,
    * ``derivative(y, dy)`` -- the contraction of the output ``y`` with the
      incoming gradient, expressed in algebra coordinates.

    ``_create_impl`` wires these hooks into a ``torch.autograd.Function``.
    """
    def __init__(self) -> None:
        super().__init__()
    
    def _create_impl(self, pade_order, dtype=None, device=None):
        """Build the custom autograd function for this group.

        Args:
            pade_order: order handed to ``math.pade`` for both numerator and
                denominator; must be at least 2.
            dtype: dtype used to build the series coefficients. ``None`` leaves
                the choice to ``torch.arange``/true division.
                NOTE(review): with ``dtype=None`` the coefficients appear to end
                up in the default float dtype even when ``w`` is double/complex128
                -- confirm this precision is intended.
            device: device the stacked coefficient table is moved to.

        Returns:
            The ``apply`` callable of a locally defined ``torch.autograd.Function``
            whose forward is ``torch.matrix_exp(self.algebra(w))``.
        """
        assert pade_order>=2
        # pade coefficient of (1-e^{-x})/x
        k = torch.arange(1,2*pade_order,dtype=dtype)
        c = -torch.cumprod(-1/k,-1)
        a1,b1 = math.pade(c,pade_order,pade_order)
        # pade coefficient of e^{-x}
        k = torch.arange(2*pade_order-1,dtype=dtype)
        k[0] = 1  # avoid dividing by zero for the constant term
        c = -torch.cumprod(-1/k,-1)
        a2,b2 = math.pade(c,pade_order,pade_order)
        
        # Layout: c[power, numerator(0)/denominator(1), function(0: (1-e^{-x})/x, 1: e^{-x})]
        c = torch.stack(
            [
                torch.stack([a1,b1],dim=-1),
                torch.stack([a2,b2],dim=-1)
            ],
            dim=-1
        ).to(device=device)
        
        class impl(torch.autograd.Function):
            @staticmethod
            def forward(ctx:torch.autograd.function.FunctionCtx, w:torch.Tensor):
                # The exponential is evaluated on detached data; the gradient is
                # supplied entirely by the hand-written backward below.
                al = self.algebra(w.detach())
                y = torch.matrix_exp(al)
                ctx.save_for_backward(y,w.detach())
                return y
            
            @staticmethod
            def backward(ctx:torch.autograd.function.FunctionCtx, dy:torch.Tensor):
                y,w = ctx.saved_tensors
                ad = self.adjoint(w)
                e = torch.eye(ad.size(-1),dtype=ad.dtype,device=ad.device)
                # Append one singleton axis per axis of ``ad`` so the coefficient
                # table broadcasts against (batched) matrices.
                c_ = c[(...,)+(None,)*ad.ndim]
                # scaling
                # Clamp the bound to >= 1 *before* taking the logarithm: for a
                # zero generator log2(0) is -inf and converting that to an
                # integer is not well defined. Values below 1 need no scaling,
                # so this yields the same n as clamping afterwards.
                bound = ad.abs().max()*ad.size(-1)
                n = max(int(torch.log2(bound.clamp_min(1)).ceil().int()),0)
                ad = ad/(2.0**n)
                # pade approximation
                # NOTE(review): powers up to pade_order-1 are accumulated here;
                # confirm this matches the number of coefficients math.pade returns.
                r = c_[0]*e+c_[1]*ad
                p = ad
                for i in range(2,pade_order):
                    p = ad@p
                    r = r+c_[i]*p
                # numerator @ denominator^{-1}, for both functions at once
                r = r[0]@torch.inverse(r[1])
                # squaring
                # phi(2x) = phi(x) (1 + e^{-x}) / 2  and  e^{-2x} = (e^{-x})^2,
                # with phi(x) = (1-e^{-x})/x held in r[0] and e^{-x} in r[1].
                for _ in range(n):
                    r[0] = r[0]@(e+r[1])/2
                    r[1] = r[1]@r[1]
                # chain rule
                # Row-vector times matrix per batch element. The explicit
                # unsqueeze/squeeze keeps torch.matmul from treating a batched
                # (B, g) gradient as a single matrix broadcast over the batch.
                grad = self.derivative(y, dy.conj())
                dw = (grad.unsqueeze(-2)@r[0]).squeeze(-2)
                # dw = torch.einsum('...ij,...ik,...mn,mkj->...n',dy,y,r,self.al)
                return dw.conj()
            
        return impl.apply
            
    def algebra(self, w:torch.Tensor) -> torch.Tensor:
        """Return the matrix exponentiated in the forward pass; subclasses must override."""
        raise NotImplementedError(f"LieGroup [{type(self).__name__}] is missing the required \"algebra\" function")
    
    def adjoint(self, w:torch.Tensor) -> torch.Tensor:
        """Return the matrix used by the backward Pade evaluation; subclasses must override."""
        raise NotImplementedError(f"LieGroup [{type(self).__name__}] is missing the required \"adjoint\" function")
    
    def derivative(self, y:torch.Tensor, dy:torch.Tensor) -> torch.Tensor:
        """Return the gradient contraction of ``y`` and ``dy`` in algebra coordinates; subclasses must override."""
        raise NotImplementedError(f"LieGroup [{type(self).__name__}] is missing the required \"derivative\" function")



In [5]:
class SO(torchgroup.lie.LieGroup):
    """Lie group generated by the elementary antisymmetric ``n x n`` matrices.

    The group has ``gdim = n*(n-1)//2`` coordinates, one per strictly
    lower-triangular matrix position. Instead of contracting against dense
    generator tensors, the class stores flat gather indices plus sign/coefficient
    masks for both the algebra matrix and the adjoint matrix.
    """
    def __init__(self, n:int, device=None) -> None:
        """Precompute index/coefficient tables for size ``n`` and register the autograd impl."""
        super().__init__()
        self.mdim = n
        self.gdim = n*(n-1)//2

        # Generator `gen` has +1 at (row, col) and -1 at (col, row);
        # `entry_slot` remembers which generator owns each matrix entry.
        generators = torch.zeros(self.gdim,self.mdim,self.mdim,dtype=torch.int)
        entry_slot = torch.zeros(self.mdim,self.mdim,dtype=torch.long,device=device)
        lower_pairs = ((row,col) for row in range(1,n) for col in range(row))
        for gen,(row,col) in enumerate(lower_pairs):
            generators[gen,row,col] = 1
            generators[gen,col,row] = -1
            entry_slot[row,col] = gen
            entry_slot[col,row] = gen

        self.__coef_al = generators.sum(0).to(dtype=torch.double,device=device)
        self.__index_al = entry_slot.flatten()

        # Same bookkeeping for the adjoint: commutators of each generator with
        # all generators, expressed in coordinates via `vectorize`.
        structure = torch.zeros(self.gdim,self.gdim,self.gdim,dtype=torch.int)
        bracket_slot = torch.zeros(self.gdim,self.gdim,dtype=torch.long,device=device)
        for gen in range(self.gdim):
            commutator = generators[gen]@generators-generators@generators[gen]
            structure[gen] = self.vectorize(commutator).transpose(-2,-1)
            bracket_slot[structure[gen]!=0] = gen

        self.__coef_ad = structure.sum(0).to(dtype=torch.double,device=device)
        self.__index_ad = bracket_slot.flatten()

        self._set(self._create_impl(dtype=torch.double,device=device))

    def vectorize(self, x:torch.Tensor):
        """Pick coordinate entries out of the last two axes of ``x``.

        The positions come from ``mytorch.count_to_index(torch.arange(self.mdim))``
        -- presumably the strictly lower-triangular (row, col) pairs; verify
        against that helper.
        """
        rows,cols = mytorch.count_to_index(torch.arange(self.mdim))
        return x[...,rows,cols]

    def algebra(self, w:torch.Tensor):
        """Gather ``w`` into an ``mdim x mdim`` matrix and apply the stored sign mask."""
        # Reference formulation left by the author:
        # torch.einsum('...i,ijk->...jk',w,self.al)
        gathered = w[...,self.__index_al]
        return gathered.unflatten(-1,(self.mdim,self.mdim))*self.__coef_al

    def adjoint(self, w:torch.Tensor):
        """Gather ``w`` into a ``gdim x gdim`` matrix and apply the stored adjoint coefficients."""
        # Reference formulation left by the author:
        # torch.einsum('...i,ijk->...jk',w,self.ad)
        gathered = w[...,self.__index_ad]
        return gathered.unflatten(-1,(self.gdim,self.gdim))*self.__coef_ad

    def derivative(self, y:torch.Tensor, dy:torch.Tensor):
        """Scatter-add the masked entries of ``y^T @ dy`` into their generator slots.

        Returns a tensor shaped like the batch axes of ``y`` plus a trailing
        axis of length ``gdim``, with ``y``'s dtype and device.
        """
        weighted = (y.transpose(-2,-1)@dy)*self.__coef_al
        accumulator = torch.zeros(
            y.size()[:-2]+(self.gdim,),
            dtype=y.dtype,
            device=y.device
        )
        return accumulator.index_add_(-1,self.__index_al,weighted.flatten(-2))
    

In [6]:
class GW(torchgroup.lie.LieGroup):
    """Lie group whose algebra matrix is a banded, frequency-weighted sliding-window matrix.

    All tables are derived on the fly from the length of the coordinate vector,
    so the instance itself carries no size.
    """
    def __init__(self, device=None) -> None:
        """Register the autograd implementation built by ``_create_impl``."""
        super().__init__()
        self._set(self._create_impl(device=device))

    def coef_al(self, dim:int, device=None):
        """Return the ``dim x dim`` coefficient mask used by ``derivative``.

        Entries with ``|i-j| > (dim-1)//2`` are zeroed; the remaining ones carry
        ``-1j`` times the column's value from ``arange(-dim//2+1, dim//2+1)``.
        """
        positions = torch.arange(dim,device=device)
        in_band = torch.abs(positions[:,None]-positions[None,:])<=(dim-1)//2
        column_weight = torch.arange(-dim//2+1,dim//2+1,device=device)
        return -1j*in_band*column_weight

    def index_al(self, dim:int, device=None):
        """Return, flattened, which coordinate feeds each entry of the algebra matrix.

        Built like ``algebra`` but on ``arange(dim)``; zero-padded positions map
        to index 0.
        """
        half = (dim-1)//2
        padded = torch.nn.functional.pad(
            torch.arange(dim,device=device,dtype=torch.long),
            [half,half]
        )
        step = padded.stride(-1)
        # windows[i, j] == padded[i + j]; the flip turns it into padded[i + dim-1-j]
        windows = torch.as_strided(padded,size=(dim,dim),stride=(step,step))
        return windows.flip([-1]).flatten()

    def algebra(self, w:torch.Tensor):
        """Build the matrix exponentiated in the forward pass.

        ``w`` is zero-padded along its last axis, viewed as overlapping
        length-``dim`` windows, flipped along the last axis, and each column is
        multiplied by ``-1j`` times ``arange(-dim//2+1, dim//2+1)``.
        """
        dim = w.shape[-1]
        half = (dim-1)//2
        padded = torch.nn.functional.pad(w,[half,half])
        step = padded.stride(-1)
        windows = torch.as_strided(
            padded,
            size=padded.shape[:-1]+(dim,dim),
            stride=padded.stride()[:-1]+(step,step)
        )
        column_weight = -1j*torch.arange(-dim//2+1,dim//2+1,device=w.device)
        return windows.flip([-1]).mul(column_weight)

    def adjoint(self, w:torch.Tensor):
        """Build the matrix consumed by the backward pass.

        Uses the same flipped sliding-window view of ``w`` as ``algebra``,
        multiplied elementwise by a ``dim x dim`` weight table read from
        ``-1j*arange(-3*(dim-1)//2, 3*(dim-1)//2+1)`` with strides (1, 2) and
        flipped along the row axis.
        """
        dim = w.shape[-1]
        half = (dim-1)//2
        padded = torch.nn.functional.pad(w,[half,half])
        c = torch.arange(-3*(dim-1)//2,3*(dim-1)//2+1,device=w.device)
        step = padded.stride(-1)
        windows = torch.as_strided(
            padded,
            size=padded.shape[:-1]+(dim,dim),
            stride=padded.stride()[:-1]+(step,step)
        ).flip([-1])
        weights = torch.as_strided(
            -1j*c,
            size=(dim,dim),
            stride=(c.stride(-1),2*c.stride(-1))
        ).flip([-2])
        return windows*weights

    def derivative(self, y:torch.Tensor, dy:torch.Tensor):
        """Scatter-add the masked entries of ``y^T @ dy`` into coordinate slots.

        Returns a tensor shaped like the batch axes of ``y`` plus a trailing
        axis of length ``y.size(-1)``, with ``y``'s dtype and device.
        """
        dim = y.size(-1)
        accumulator = torch.zeros(
            y.size()[:-2]+(dim,),
            dtype=y.dtype,
            device=y.device
        )
        slots = self.index_al(dim, device=y.device)
        weighted = (y.transpose(-2,-1)@dy)*self.coef_al(dim, device=y.device)
        return accumulator.index_add_(-1,slots,weighted.flatten(-2))

    def element_(self, w: torch.Tensor) -> torch.Tensor:
        """Reference group element: plain ``matrix_exp`` of ``algebra(w)`` without the custom backward."""
        generator = self.algebra(w)
        return torch.matrix_exp(generator)
    

In [7]:
# Gradient check of the SO implementation at a large-magnitude random point.
n = 16
device = 'cpu'
g = SO(n,device=device)
# Scale first and only then enable gradients, so gradcheck receives a leaf
# tensor rather than the output of a tracked multiplication.
w = (torch.randn(g.gdim,dtype=torch.double,device=device)*1e+3).requires_grad_()
torch.autograd.gradcheck(g.__call__,w)

True

In [101]:
# Gradient check of the GW implementation at the zero point (note the final `*0`).
n = 8
device = 'cuda'
g = GW(device=device)

# Assemble the coefficients without autograd tracking: in-place slice writes on
# a tracked tensor are unnecessary here, and only the final input needs grad.
half = torch.randn(n,dtype=torch.double,device=device) +0j
half[1:] = (half[1:] + 1j*torch.randn(n-1,dtype=torch.double,device=device))\
    /torch.arange(1,n,dtype=torch.double,device=device)
half[4:] = 0
# Conjugate-mirrored layout [conj(reversed tail), head..tail], zeroed, then made a leaf.
w = (torch.cat([half[1:].flip([-1]).conj(),half])*0).requires_grad_()

# The failure is the point of this cell, so report it instead of raising.
try:
    torch.autograd.gradcheck(g.__call__,w)
except RuntimeError as e:
    print(e)

While considering the imaginary part of complex outputs only, Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j],
        ...,
        [0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j, 0.+0.j,  ..., 0.+0.j, 0.+0.j, 0.+0.j]],
       device='cuda:0', dtype=torch.complex128)
analytical:tensor([[nan+nanj, nan+nanj, nan+nanj,  ..., nan+nanj, nan+nanj, nan+nanj],
        [nan+nanj, nan+nanj, nan+nanj,  ..., nan+nanj, nan+nanj, nan+nanj],
        [nan+nanj, nan+nanj, nan+nanj,  ..., nan+nanj, nan+nanj, nan+nanj],
        ...,
        [nan+nanj, nan+nanj, nan+nanj,  ..., nan+nanj, nan+nanj, nan+nanj],
        [nan+nanj, nan+nanj, nan+nanj,  ..., nan+nanj, nan+nanj, nan+nanj],
        [nan+nanj, nan

In [8]:
# Synthetic regression target: apply a random GW element to random data.
n = 64
batch = 32
device = 'cuda'
dtype = torch.complex64
g = GW(device=device)

x = torch.randn(batch,2*n+1,10,dtype=dtype,device=device)

# Random coefficients damped by 1/k and scaled by 1e-3.
decay = torch.arange(1,n+1,device=device)
w_ = torch.randn(batch,n,device=device,dtype=dtype)/decay*1e-3
# Full coordinate vector: [conj(reversed w_), 0, w_]
zero_mode = torch.zeros_like(w_[...,:1])
coords = torch.cat([w_.flip([-1]).conj(),zero_mode,w_],dim=-1)
A = g(coords)
y = A@x

In [9]:
# Recover the coefficients from (x, y) by gradient descent with Adam.
w = torch.nn.Parameter(torch.zeros(batch,n,dtype=dtype,device=device))
optimizer = torch.optim.Adam([w])

# The 1/k damping is the same every step, so build it once outside the loop.
decay = torch.arange(1,n+1,device=device)

for i in range(10000):
    w_ = w/decay*1e-2
    # Full coordinate vector: [conj(reversed w_), 0, w_]
    A = g(torch.cat([w_.flip([-1]).conj(),torch.zeros_like(w_[...,:1]),w_],dim=-1))
    loss = torch.dist(y,A@x)
    # Carriage return keeps the running loss on a single output line.
    print(f'{loss}\r', end='')
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

0.0079725272953510285

KeyboardInterrupt: 