In [25]:
import numpy as np
from scipy import sparse
from scipy.sparse import coo_matrix,diags
from scipy.sparse.linalg import inv
import gudhi as gd
import copy
import random
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
import torch
import torch.nn as nn
from itertools import permutations


In [26]:
def perm_parity(lst):
    ''' Given a permutation of the digits 0..N in order as a list, 
    returns its parity (or sign): +1 for even parity; -1 for odd.
    '''
    parity = 1
    for i in range(0,len(lst)-1):
        if lst[i] != i:
            parity *= -1
            mn = min(range(i,len(lst)), key=lst.__getitem__)
            lst[i],lst[mn] = lst[mn],lst[i]
    return parity   

#permute the entries of a simplex in torch tensor format
def permute_simplex(simplex,perm):
    """ permute the entries of a simplex in torch tensor format"""
    permuted_simplex = torch.zeros(simplex.shape)
    for i in range(simplex.shape[0]):
        permuted_simplex[i] = simplex[perm[i]]
    return permuted_simplex



In [48]:
# initialise a AS cochain as an MLP

# m = dimension of space the complex lives in 
# k = dimension of cochain 

m =2
k = 1
f = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, 1)
)
simp = torch.tensor([(1,0),(0,1)]).float() 
simp = simp.reshape(1,-1)
f(simp)  



tensor([[0.0160]], grad_fn=<AddmmBackward0>)

In [28]:
## we can also learn l different cochains 
l=2
F = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, l)
)

In [59]:
## Evaluate one cochain on a simplex 

def cochain_eval(cochain,simplex):

    """ Evaluate a k-cochains on a k-simplex """
    
    # check size of simplex is compatible with cochain
    assert simplex.shape == (m,(k+1)), "dimension of simplex and cochain does not match"
    simplex.reshape(1,-1)
    simplex = simplex.reshape(1,-1)
    out = cochain(simplex)[0]
    print(out.shape)
    return out 



simp = torch.tensor([(1,0),(0,1)]).float()
res = cochain_eval(f,simp)
print(res.shape)
print(res)


torch.Size([1])
torch.Size([1])
tensor([0.0160], grad_fn=<SelectBackward0>)


In [44]:
def mult_cochain_eval(cochains,simplex):
    """ Evaluate a set of l k-cochains on a k-simplex """
    
    # check size of simplex is compatible with cochain
    assert simplex.shape == (m,(k+1)), "dimension of simplex and cochain does not match"
    
    l = cochains[-1].out_features
    simplex.reshape(1,-1)
    simplex = simplex.reshape(1,-1)
    
    out = torch.zeros(l)
    
    out = cochains(simplex)
    
    
    return out 

simp = torch.tensor([(1,0),(0,1)]).float()
cochain_eval(F,simp)

tensor([[-0.1241,  0.1047]], grad_fn=<AddmmBackward0>)


tensor([[-0.1241,  0.1047]], grad_fn=<AddmmBackward0>)

In [74]:
### FIX this function later!!!!!

def alternate_cochain(cochain,simplex):
    
    l = cochain[-1].out_features
    #print(l)
    alt = torch.zeros(l)
    
    perm = permutations(range(simplex.shape[0]))
    for i in list(perm): 
        simplex = permute_simplex(simplex,i)
        #bla = cochain_eval(cochain,simplex)
        #print(bla.shape)
        #print(i)        
        alt = alt + perm_parity(i)*cochain_eval(cochain,simplex)[0]
    return alt 

In [75]:
simp = torch.tensor([(1,0),(0,1)]).float() 
a = alternate_cochain(f,simp)
print(a)

torch.Size([1])


TypeError: 'tuple' object does not support item assignment

In [17]:
perm = permutations(range(simp.shape[0]))
simp = torch.tensor([(1,0,3),(0,1,2)]).float() 



tensor([[1., 0., 3.],
        [0., 1., 2.]])
tensor(0.)
tensor(0.)


(0, 1, 2)
(0, 2, 1)
(1, 0, 2)
(1, 2, 0)
(2, 0, 1)
(2, 1, 0)


In [22]:
perm = permutations(range(2))
simp = torch.tensor([(1,0,3),(0,1,2)]).float()   
for i in perm: 
    print(permute_simplex(simp,i))
    




tensor([[1., 0., 3.],
        [0., 1., 2.]])
tensor([[0., 1., 2.],
        [1., 0., 3.]])
