In [642]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from numpy.random import default_rng as rg
%matplotlib inline
backend = plt.get_backend()
import jpcm
plt.switch_backend(backend)
cs = jpcm.get('fuyu').resampled(10)
rng = rg(12345)
from tqdm import tqdm

In [318]:
import torch
import torch.optim as optim

# Make a group representation!

In [610]:
def act(x):
    return x #torch.floor(x) #1/(1+torch.exp(-x))
def inv_act(x):
    return x #torch.log(x/(1-x))


# base FCN AE, 4 layer
class AE(torch.nn.Module):
    def __init__(self):
        super(AE,self).__init__()
        self.activation = act##torch.nn.functional.sigmoid #functional.relu #
        self.invactivation = inv_act#torch.nn.functional.sigmoid #functional.relu ##
        p = 12
        self.rep_shape = [2,2]
        self.rep_n = self.rep_shape[0] * self.rep_shape[1]
        self.dim = self.rep_n * 2
        # encoder
        self.linear1e = torch.nn.Linear(2, p)
        self.linear2e = torch.nn.Linear(p, p)
        self.linear3e = torch.nn.Linear(p, p)
        self.linear4e = torch.nn.Linear(p, p)
        self.linear5e = torch.nn.Linear(p, self.dim)
        
        # decoder
        self.linear5d = torch.nn.Linear(self.dim, p)
        self.linear4d = torch.nn.Linear(p, p)
        self.linear3d = torch.nn.Linear(p, p)
        self.linear2d = torch.nn.Linear(p, p)
        self.linear1d = torch.nn.Linear(p, 2)
    
    def forward_e(self, x):
        x = self.linear1e(x)
        x = self.activation(x)
        x = self.linear2e(x)
        x = self.activation(x)
        x = self.linear3e(x)  
        x = self.activation(x)
        x = self.linear4e(x)  
        x = self.activation(x)
        x = self.linear5e(x)
        return x
    
    def forward_d(self, x):
        x = self.linear5d(x)
        x = self.invactivation(x)
        x = self.linear4d(x)
        x = self.invactivation(x)  
        x = self.linear3d(x)
        x = self.invactivation(x)
        x = self.linear2d(x)
        x = self.invactivation(x)
        x = self.linear1d(x)        
        return x
    
    def forward(self, x1, x2):
        x1 = self.forward_e(x1)
        x2 = self.forward_e(x2)
        # complex matrix multiplication
        r1 = x1[:self.rep_n].reshape(self.rep_shape)
        i1 = x1[self.rep_n:].reshape(self.rep_shape)
        r2 = x2[:self.rep_n].reshape(self.rep_shape)
        i2 = x2[self.rep_n:].reshape(self.rep_shape)
        #
        r3 = r1 @ r2 - i1 @ i2
        i3 = r1 @ i2 + i1 @ r2
        #
        x3 = torch.hstack([r3.flatten(),i3.flatten()])
        return self.forward_d(x3)

In [611]:
# mod AE, 4 layer
class AE_M(torch.nn.Module):
    def __init__(self):
        super(AE_M,self).__init__()

        self.activation = torch.nn.functional.sigmoid #functional.relu #
        self.invactivation = torch.nn.functional.sigmoid #functional.relu ##
        self.acty = lambda x: x%2
        self.actz = lambda x: x%6
        p = 12
        l = 6
        
        self.rep_shape = [2,2]
        self.rep_n = self.rep_shape[0] * self.rep_shape[1]
        self.dim= self.rep_n * 2
        # encoder  
        self.linear1e = torch.nn.Linear(2, p)
        self.linear2e = torch.nn.Linear(p, l)
        self.linear3e = torch.nn.Linear(p, p-l)
        self.linear4e = torch.nn.Linear(p, p)
        self.linear5e = torch.nn.Linear(p, self.dim)
        
        # decoder
        self.linear5d = torch.nn.Linear(self.dim, p)
        self.linear4d = torch.nn.Linear(p, l)
        self.linear3d = torch.nn.Linear(p, p-l)
        self.linear2d = torch.nn.Linear(p, p)
        self.linear1d = torch.nn.Linear(p, 2)
    
    def forward_e(self, x):
        x = self.linear1e(x)
        x = self.activation(x)
        y = self.linear2e(x)
        y = self.acty(y)
        z = self.linear3e(x)  
        z = self.actz(z)
        x = self.linear4e(torch.hstack([y,z]))  
        x = self.activation(x)
        x = self.linear5e(x)
        return x
    
    def forward_d(self, x):
        x = self.linear5d(x)
        x = self.invactivation(x)
        y = self.linear4d(x)
        y = self.acty(y)  
        z = self.linear3d(x)
        z = self.actz(z)
        x = self.linear2d(torch.hstack([y,z]))
        x = self.invactivation(x)
        x = self.linear1d(x)        
        return x
    
    def forward(self, x1, x2):
        x1 = self.forward_e(x1)
        x2 = self.forward_e(x2)
        # complex matrix multiplication
        r1 = x1[:self.rep_n].reshape(self.rep_shape)
        i1 = x1[self.rep_n:].reshape(self.rep_shape)
        r2 = x2[:self.rep_n].reshape(self.rep_shape)
        i2 = x2[self.rep_n:].reshape(self.rep_shape)
        #
        r3 = r1 @ r2 - i1 @ i2
        i3 = r1 @ i2 + i1 @ r2
        #
        x3 = torch.hstack([r3.flatten(),i3.flatten()])
        return self.forward_d(x3)

In [612]:
# base CNN AE, 4 layer
class AE_C(torch.nn.Module):
    def __init__(self):
        super(AE_C,self).__init__()

        self.activation = torch.nn.functional.sigmoid #functional.relu #
        self.invactivation = torch.nn.functional.sigmoid #functional.relu ##
        p = 20
        self.rep_shape = [2,2]
        self.rep_n = self.rep_shape[0] * self.rep_shape[1]
        self.dim = self.rep_n * 2
        kernel_size = 5
        # encoder
        self.linear1e = torch.nn.Linear(2, p)
        self.linear2e = torch.nn.Conv1d(1,1, kernel_size=kernel_size, padding='same', bias=True)
        self.linear3e = torch.nn.Conv1d(1,1, kernel_size=kernel_size, padding='same', dilation=2, bias=True)
        self.linear4e = torch.nn.Linear(p, p)
        self.linear5e = torch.nn.Linear(p, self.dim)
        
        # decoder
        self.linear5d = torch.nn.Linear(self.dim, p)
        self.linear4d = torch.nn.Linear(p, p)
        self.linear3d = torch.nn.Conv1d(1,1, kernel_size=kernel_size, padding='same', bias=True)
        self.linear2d = torch.nn.Conv1d(1,1, kernel_size=kernel_size, padding='same', dilation=2, bias=True)
        self.linear1d = torch.nn.Linear(p, 2)
    
    def forward_e(self, x):
        x = self.linear1e(x)
        x = self.activation(x)
        x = self.linear2e(x.reshape(1, 1, -1))
        x = self.activation(x)
        x = self.linear3e(x)  
        x = self.activation(x).reshape(-1)
        x = self.linear4e(x)  
        x = self.activation(x)
        x = self.linear5e(x)
        return x
    
    def forward_d(self, x):
        x = self.linear5d(x)
        x = self.invactivation(x)
        x = self.linear4d(x)
        x = self.invactivation(x)  
        x = self.linear3d(x.reshape(1, 1, -1))
        x = self.invactivation(x)
        x = self.linear2d(x)
        x = self.invactivation(x).reshape(-1)
        x = self.linear1d(x)        
        return x
    
    def forward(self, x1, x2):
        x1 = self.forward_e(x1)
        x2 = self.forward_e(x2)
        # complex matrix multiplication
        r1 = x1[:self.rep_n].reshape(self.rep_shape)
        i1 = x1[self.rep_n:].reshape(self.rep_shape)
        r2 = x2[:self.rep_n].reshape(self.rep_shape)
        i2 = x2[self.rep_n:].reshape(self.rep_shape)
        #
        r3 = r1 @ r2 - i1 @ i2
        i3 = r1 @ i2 + i1 @ r2
        #
        x3 = torch.hstack([r3.flatten(),i3.flatten()])
        return self.forward_d(x3)

In [703]:
def primes():
    for p in [2,3,5,7]: yield p                 # base wheel primes
    gaps1 = [ 2,4,2,4,6,2,6,4,2,4,6,6,2,6,4,2,6,4,6,8,4,2,4,2,4,8 ]
    gaps = gaps1 + [ 6,4,6,2,4,6,2,6,6,4,2,4,6,2,6,4,2,4,2,10,2,10 ] # wheel2357
    def wheel_prime_pairs():
        yield (11,0); bps = wheel_prime_pairs() # additional primes supply
        p, pi = next(bps); q = p * p            # adv to get 11 sqr'd is 121 as next square to put
        sieve = {}; n = 13; ni = 1              #   into sieve dict; init cndidate, wheel ndx
        while True:
            if n not in sieve:                  # is not a multiple of previously recorded primes
                if n < q: yield (n, ni)         # n is prime with wheel modulo index
                else:
                    npi = pi + 1                # advance wheel index
                    if npi > 47: npi = 0
                    sieve[q + p * gaps[pi]] = (p, npi) # n == p * p: put next cull position on wheel
                    p, pi = next(bps); q = p * p  # advance next prime and prime square to put
            else:
                s, si = sieve.pop(n)
                nxt = n + s * gaps[si]          # move current cull position up the wheel
                si = si + 1                     # advance wheel index
                if si > 47: si = 0
                while nxt in sieve:             # ensure each entry is unique by wheel
                    nxt += s * gaps[si]
                    si = si + 1                 # advance wheel index
                    if si > 47: si = 0
                sieve[nxt] = (s, si)            # next non-marked multiple of a prime
            nni = ni + 1                        # advance wheel index
            if nni > 47: nni = 0
            n += gaps[ni]; ni = nni             # advance on the wheel
    for p, pi in wheel_prime_pairs(): yield p   # strip out indexes

In [704]:
class RootCell(torch.nn.Module):
    def __init__(self, n, use_primes=True):
        super().__init__()
        
        if use_primes:
            self.p = torch.from_numpy(np.fromiter(primes(),float,count=n)).float()
        else:
            self.p = torch.from_numpy(np.arange(1,n+1)).float()

    def forward(self, x, w):
        # w_times_x= torch.mm(x, self.weights.t())
        # torch.add(w_times_x, self.bias)  # w times x + b
        roots_r = torch.cos(2*np.pi * x / self.p) * w
        roots_i = torch.sin(2*np.pi * x / self.p) * w
        return roots_r, roots_i      

class PowerCell(torch.nn.Module):
    def __init__(self, n, use_primes=True):
        super().__init__()
        
        if use_primes:
            self.p = torch.from_numpy(np.fromiter(primes(),float,count=n)).float()
        else:
            self.p = torch.from_numpy(np.arange(1,n+1)).float()

    def forward(self, x_r, x_i, w):
        order = torch.abs(2*np.pi / torch.atan2(x_i,x_r))
        cypow = (self.p / order) * w
        return cypow
        
    
# rt = RootCell(12)
# rr,ri = rt(torch.Tensor([1]*12))
# pc = PowerCell(12)
# assert torch.mean(pc(rr,ri)) - 1.0 < 1e-12

In [893]:
class AE_G(torch.nn.Module):
    def __init__(self):
        super(AE_G,self).__init__()
        p = 10
        
        cgroup = torch.ones(p)
        self.weights = torch.nn.Parameter(cgroup)
        
        # base
        self.rt1 = RootCell(p,use_primes=False)
        self.pc1 = PowerCell(p,use_primes=False)
        
        # encoder
        self.linear1a = torch.nn.Linear(2, p)
        self.linear1b = torch.nn.Linear(2, p)
        self.linear1e = torch.nn.Linear(2*p,p)
        # decoder
        self.linear1d = torch.nn.Linear(p, 2)
        
    def forward(self,x1, x2):
        x1 = self.linear1a(x1)
        x2 = self.linear1b(x2)
        x = self.linear1e(torch.hstack([x1,x2]))
        rr,ri = self.rt1(x,w=self.weights)
        mod = self.pc1(rr,ri,w=self.weights)
        out = self.linear1d(mod)
        return out

class AE_G2(torch.nn.Module):
    def __init__(self):
        super(AE_G2,self).__init__()
        p = 10
        
        cgroup = torch.ones(p)
        self.weights = torch.nn.Parameter(cgroup)
        
        # base
        self.rt1 = RootCell(p,use_primes=False)
        self.pc1 = PowerCell(p,use_primes=False)
        
        # encoder
        self.linear1a = torch.nn.Linear(2, p)
        self.linear1b = torch.nn.Linear(2, p) # create powers of unity
        # decoder
        self.linear1d = torch.nn.Linear(p, 2)
        
    def forward(self,x1, x2):
        
        x1 = self.linear1a(x1) # encode
        rr1,ri1 = self.rt1(x1,w=self.weights)
        
        x2 = self.linear1b(x2) # encode
        rr2,ri2 = self.rt1(x2,w=self.weights)
        
        r3 = rr1 @ rr2 - ri1 @ ri2
        i3 = rr1 @ ri2 + ri1 @ rr2
                
        mod = self.pc1(r3,i3,w=self.weights)
        out = self.linear1d(mod)
        return out  
    
class AE_G3(torch.nn.Module):
    def __init__(self):
        super(AE_G3,self).__init__()
        p = 10
        rep = 2
        self.rep_shape=[rep,rep]
        cgroup = torch.ones(p)
        self.weights = torch.nn.Parameter(cgroup)
        
        # base
        self.rt1 = RootCell(p,use_primes=False)
        self.pc1 = PowerCell(p,use_primes=False)
        
        # encoder
        self.linear1 = torch.nn.Linear(2, p)
        self.linear2 = torch.nn.Linear(p,rep**2) # map to matrix rep
        
        # decoder
        self.linear2inv = torch.nn.Linear(rep**2,p)
        self.linear1d = torch.nn.Linear(p, 2)
        
    def forward(self,x1, x2):
        
        x1 = self.linear1(x1) # encode
        rr1,ri1 = self.rt1(x1,w=self.weights)
        rr1 = self.linear2(rr1)
        ri1 = self.linear2(ri1)
        r1 = rr1.reshape(self.rep_shape)
        i1 = ri1.reshape(self.rep_shape)
        
        x2 = self.linear1(x2) # encode
        rr2,ri2 = self.rt1(x2,w=self.weights)
        rr2 = self.linear2(rr2)
        ri2 = self.linear2(ri2)
        r2 = rr2.reshape(self.rep_shape)
        i2 = ri2.reshape(self.rep_shape)
        
        r3 = r1 @ r2 - i1 @ i2 # mult
        i3 = r1 @ i2 + i1 @ r2
                
        r4 = self.linear2inv(r3.flatten()) # decode
        i4 = self.linear2inv(i3.flatten())
        mod = self.pc1(r4,i4,w=self.weights)
        out = self.linear1d(mod)
        return out    
    
class AE_G4(torch.nn.Module):
    def __init__(self):
        super(AE_G4,self).__init__()
        
        p = 10            # possible orders (2-p)
        self.rep = 2           # dimension of C matrix rep
        internal_mult = 2 # number of generators (cyclic)
        
        self.internal_mult = 2 
        
        
        self.rep2 = self.rep**2
        self.rep_shape=[self.rep,self.rep]
        
        cgroup = torch.ones(p)
        ogroup = torch.ones(p)
        self.weights = torch.nn.Parameter(cgroup) # a multiplicative selection of generators
        self.orders = torch.nn.Parameter(ogroup) # the powers of each generator (which would be one unless otherwise necessary for the complex rep)
        
        # base
        self.rt = RootCell(p,use_primes=False)
        self.pc = PowerCell(p,use_primes=False)
        
        # encoder
        self.linear1 = torch.nn.Linear(2, self.internal_mult)
        self.maps = [torch.nn.Linear(p,self.rep2) for _ in range(self.internal_mult)] # map to matrix rep
        
        # decoder
        self.linear2inv = torch.nn.Linear(self.rep2,p)
        self.linear1d = torch.nn.Linear(p, 2)
        
    def generate(self):
        rp = []
        ip = []
        rr, ri = self.rt(self.orders,w=self.weights)
        for i in range(self.internal_mult):
            rr1 = self.maps[i](rr)
            ri1 = self.maps[i](ri)
            r = rr1.reshape(self.rep_shape)
            i = ri1.reshape(self.rep_shape)
            rp.append(r)
            ip.append(i)
        return rp, ip
        

    def encode(self,x,rp,ip):
        x = self.linear1(x)
        ra = torch.eye(self.rep)
        ia = torch.zeros(self.rep_shape)
        for j in range(self.internal_mult):
            for i in range(int(x[j])):
                ra, ia = self.mul(ra,ia,rp[j],ip[j])
        return ra, ia
    
    def decode(self,r,i):
        r4 = self.linear2inv(r.flatten()) # decode
        i4 = self.linear2inv(i.flatten())
        
        mod = self.pc(r4,i4,w=self.weights)
        out = self.linear1d(mod)
        return out
    
    def mul(self,r1,i1,r2,i2):
        r3 = r1 @ r2 - i1 @ i2
        i3 = r1 @ i2 + i1 @ r2
        return r3,i3
        
    def forward(self,x1, x2):
        
        # make generators
        rp,ip = self.generate()
        
        # encode, decode
        
        r1, i1 = self.encode(x1,rp,ip)
        r2, i2 = self.encode(x2,rp,ip)
        r3, i3, = self.mul(r1,i1,r2,i2)
        out = self.decode(r3,i3)               

        return out    
    

In [894]:
# example : dihedral, D6
def sample(n):
    x = rng.integers(low=0,high=6,size=n)
    y = rng.integers(low=0,high=2,size=n)
    return np.vstack([x,y]).T
def mult(e1,e2):
    x1 = e1[:,0]
    y1 = e1[:,1]
    x2 = e2[:,0]
    y2 = e2[:,1]
    x = (x1 - x2) % 6
    y = (y1 + y2) % 2
    return np.vstack([x,y]).T
# example2 : c6 x 1
def sample2(n):
    x = rng.integers(low=0,high=6,size=n)
    y = np.ones(n)
    return np.vstack([x,y]).T
def mult2(e1,e2):
    x1 = e1[:,0]
    y1 = e1[:,1]
    x2 = e2[:,0]
    y2 = e2[:,1]
    x = (x1 + x2) % 6
    y = np.ones(y1.shape)
    return np.vstack([x,y]).T
# def any_sample(n):
#     x = rng.integers(low=0,high=10,size=n)
#     y = rng.integers(low=0,high=10,size=n)
#     return np.vstack([x,y]).T
# def valid(data):
#     x = data[:,0]
#     y = data[:,1]
#     return np.where(x<6,True,False)*np.where(y<2,True,False)

In [895]:
device = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # as other apps are currently using gpu

n = 4000
tn = 5

x1 = torch.Tensor(sample(n+tn)).to(device)
x2 = torch.Tensor(sample(n+tn)).to(device)
y  = torch.Tensor(mult(x1,x2)).to(device)


# x1 = torch.Tensor(sample2(n+tn)).to(device)
# x2 = torch.Tensor(sample2(n+tn)).to(device)
# y  = torch.Tensor(mult2(x1,x2)).to(device)

In [896]:
batch_size=30

# AE_G() is modular arithmetic, known to work
# AE_G2() is an actual rep, but without matrices...
# AE_G3() is a proper C 2x2 matrix rep, but can only represent cyclic groups...
# AE_G4() is a proper C 2x2 matrix rep, can represent cyclic and dihedral

nets = [AE_G4()]  # AE(), AE_M(), AE_C(), do not work
err = []
for net in nets:
    net.to(device)
    criterion = torch.nn.MSELoss() #CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)
    # optimizer = optim.ASGD(net.parameters(), lr=0.002)
    for epoch in tqdm(range(24)): 
        running_loss = 0.0
        for i in range(n//batch_size):

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize

            outputs = torch.vstack([net(x1[j],x2[j]) for j in range(i,i+batch_size)])
            loss = criterion(outputs, y[i:i+batch_size])
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
        # print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
        # running_loss = 0.0
    err.append(criterion(torch.vstack([net(x1i,x2i) for x1i,x2i in zip(x1[n:],x2[n:])]),y[n:]).item())

100%|██████████████████████████████████████████████████████████████████████████████████| 24/24 [01:18<00:00,  3.26s/it]


In [892]:
[print(f'Net: {type(n).__name__} has test MSE error e={e}') for n,e in zip(nets,err)]

[]

In [885]:
[nets[0](x1i,x2i) for x1i,x2i in zip(x1[n:],x2[n:])]

[tensor([2.0856, 0.5882], grad_fn=<AddBackward0>),
 tensor([2.0856, 0.5882], grad_fn=<AddBackward0>),
 tensor([2.0856, 0.5882], grad_fn=<AddBackward0>),
 tensor([2.0856, 0.5882], grad_fn=<AddBackward0>),
 tensor([2.0856, 0.5882], grad_fn=<AddBackward0>)]

In [886]:
y[n:]

tensor([[3., 0.],
        [5., 1.],
        [1., 0.],
        [1., 0.],
        [4., 1.]])

In [887]:
[print(f'Net: {type(n).__name__} has weights={n.weights}') for n in nets]

Net: AE_G4 has weights=Parameter containing:
tensor([0.9805, 0.9885, 0.9344, 0.8784, 0.7432, 0.6506, 0.8213, 0.5637, 0.5261,
        0.8257], requires_grad=True)


[None]