# Notebook to learn Alexander Spannier cochains 

In [85]:
import numpy as np
from scipy import sparse
from scipy.sparse import coo_matrix,diags
from scipy.sparse.linalg import inv
import gudhi as gd
import copy
import random
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
import torch
import torch.nn as nn
from itertools import permutations
import math


### Some useful functions 

In [None]:
def generate_diagonal_paths(num_paths=100,eps = 0.2, num_pts = 10):
    
    Paths = []
    for i in range(num_paths): 
        x = np.sort(np.random.uniform(low=-1, high=1, size=num_pts).astype('f'))
        noise_x= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')
        noise_y= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')

        x_trans = np.random.randint(-5,5)
        y_trans = np.random.randint(-5,5)

        x_values = list(x+noise_x +x_trans)
        y_values = list(np.sin(x+noise_y)+y_trans)

        path = np.stack((x_values,y_values))
    
        Paths.append(path.T)
        
    return Paths


def generate_antidiagonal_paths(num_paths=100,eps = 0.2, num_pts = 10):
    
    Paths = []
    for i in range(num_paths): 
        x = np.sort(np.random.uniform(low=-1, high=1, size=num_pts).astype('f'))

        noise_x= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')
        noise_y= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')

        x_trans = np.random.randint(-5,5)
        y_trans = np.random.randint(-5,5)
        x_values = list(x+noise_x +x_trans)
        y_values = list(-np.sin(x+noise_y)+y_trans)
        path = np.stack((x_values,y_values))
    
        Paths.append(path.T)
        
    return Paths
        



def generate_circular_paths(num_paths=100,eps = 0.2, num_pts = 10):
    
    Paths = []
    for i in range(num_paths): 
        endpoint = np.random.randint(0,num_pts)
        
        sample_angles = list(np.sort(np.random.uniform(0,2*np.pi, num_pts)).astype('f'))
        angles= sample_angles[endpoint:]+ sample_angles[:endpoint]
        angles = np.array(angles)
        
        noise_x= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')
        noise_y= np.random.uniform(low= -eps, high = eps,  size = num_pts).astype('f')

        #x_trans = np.random.randint(-5,5)
        #y_trans = np.random.randint(-5,5)

        r = np.random.uniform(0.5, 2.5)

        x_values = r*np.cos(angles)+noise_x
        y_values = r*np.sin(angles)+noise_y

        path = np.stack((x_values,y_values))
    
        Paths.append(path.T)

    return Paths

In [86]:
def perm_parity(lst):
    ''' Given a permutation of the digits 0..N in order as a list, 
    returns its parity (or sign): +1 for even parity; -1 for odd.
    '''
    parity = 1
    for i in range(0,len(lst)-1):
        if lst[i] != i:
            parity *= -1
            mn = min(range(i,len(lst)), key=lst.__getitem__)
            lst[i],lst[mn] = lst[mn],lst[i]
    return parity   

#permute the entries of a simplex in torch tensor format
def permute_simplex(simplex,perm):
    """ permute the entries of a simplex in torch tensor format"""
    permuted_simplex = torch.zeros(simplex.shape)
    for i in range(simplex.shape[0]):
        permuted_simplex[i] = simplex[perm[i]]
    return permuted_simplex



In [87]:
# initialise a AS cochain as an MLP
# m = dimension of space the complex lives in 
# k = dimension of cochain 

m =2
k = 1
f = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, 1)
)

simp = torch.tensor([(1,0),(0,1)]).float() ## a 1-simplex with node values in R^2
print(simp.shape)
simp = simp.reshape(1,-1)
print(simp.shape)
a = f(simp)
print(a)
print(a.shape)



torch.Size([2, 2])
torch.Size([1, 4])
tensor([[0.0152]], grad_fn=<AddmmBackward0>)
torch.Size([1, 1])


In [88]:
## we can also learn l different cochains 
l=2
F = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, l)
)

simp = torch.tensor([(1,0),(0,1)]).float() ## a 1-simplex with node values in R^2
print(simp.shape)
simp = simp.reshape(1,-1)
print(simp)
print(simp.shape)
a = F(simp)[0]
print(a)
print(a.shape)


torch.Size([2, 2])
tensor([[1., 0., 0., 1.]])
torch.Size([1, 4])
tensor([-0.0498,  0.1354], grad_fn=<SelectBackward0>)
torch.Size([2])


In [91]:
def cochain_eval(cochains,simplex):
    """ Evaluate a set of l k-cochains on a k-simplex """
    
    # check size of simplex is compatible with cochain
    assert simplex.shape == ((k+1),m), "dimension of simplex and cochain does not match"
    
    l = cochains[-1].out_features
    new_simplex = simplex.reshape(1,-1)
    out = torch.zeros(l)
    out = cochains(new_simplex)[0]

    return out 

simp = torch.tensor([(1,0),(0,1)]).float()
cochain_eval(F,simp)

tensor([ 0.1956, -0.1005], grad_fn=<SelectBackward0>)

In [92]:
# function to obtain the alternation of a cochain on a simplex

def alternate_cochain(cochain,simplex):
    
    l = cochain[-1].out_features
    s = simplex.shape[0]

    alt = torch.zeros(l)
    perm = permutations(range(simplex.shape[0]))

    for i in list(perm): 
        simplex = permute_simplex(simplex,i)   
        alt += perm_parity(list(i))*cochain_eval(cochain,simplex)
    return alt/math.factorial(s)


simp = torch.tensor([(1,0,),(0,4)]).float() 
print(permute_simplex(simp,[1,0]))

l=2
m = 2
k = 1

F = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, l)
)

a = alternate_cochain(F,simp)
print(a)

tensor([[0., 4.],
        [1., 0.]])
tensor([-0.0317, -0.1362], grad_fn=<DivBackward0>)


In [144]:
p0 = generate_diagonal_paths(num_paths=10,eps = 0.2, num_pts = 5)
path = p0[0]
path[0][0].dtype

dtype('float32')

In [160]:
def cochain_eval_path(cochains,path):
    """ Evaluate a set of l k-cochains on a simplicial complex
     simplicial complex sc as array of simplices"""
    
    out = torch.zeros(path.shape[0],cochains[-1].out_features)

    for i in range(path.shape[0]-1):
        
        simplex = torch.tensor((path[i], path[i+1]))
        #print(simplex)
        out[i] = cochain_eval(cochains,simplex)
    return out


m =2
k = 1
f = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, 1)
)

cochain_eval_path(f,path)

tensor([[-0.1990],
        [-0.1586],
        [-0.1559],
        [-0.1419],
        [ 0.0000]], grad_fn=<CopySlices>)

## Learning part

In [183]:
# generate data
p0 = generate_diagonal_paths(num_paths=10,eps = 0.2, num_pts = 10)
p1 = generate_antidiagonal_paths(num_paths=10,eps = 0.2, num_pts = 10)
p2 = generate_circular_paths(num_paths=10,eps = 0.2, num_pts = 10)

# join together p0, p1, p2
paths = p0+p1+p2

# generate labels
labels = np.concatenate((np.zeros(10),np.ones(10),2*np.ones(10))).astype('f')

# perform a one hot encoding of the labels and transform to torch
labels = torch.nn.functional.one_hot(torch.tensor(labels).to(torch.int64))


In [180]:
labels[0].dtype

dtype('float32')

In [181]:
l=3 ## three classes so three outputs 
m = 2 ## the paths live in R^2
k = 1 # we deal with one simplices 





F = nn.Sequential(
    nn.Linear(m*(k+1), 10),
    nn.ReLU(),
    nn.Linear(10, 200), ## random number 
    nn.ReLU(),
    nn.Linear(200, l)
)


In [219]:
epochs = 1

orig_labels = np.concatenate((np.zeros(10),np.ones(10),2*np.ones(10)))

optim = torch.optim.SGD(F.parameters(), lr=1e-4, momentum=0.9)

criterion = nn.CrossEntropyLoss()


for e in range(epochs):

    print(e)

    ## shuffle the data
    idx = np.random.permutation(len(paths))
    paths = [paths[i] for i in idx]
    labels = labels[idx]
    orig_labels = orig_labels[idx]

    correct_pred = 0

    for i in range(len(paths)):

        p = paths[i]
        y = labels[i]
        
        X = cochain_eval_path(F,p)

        #print(X)
        X = torch.sum(X, dim = 0) 
        
        ## do we want to do this ? 
        print(X)
        #print("*****")

        sm = torch.nn.functional.softmax(X, dim=0) 


        print(sm.shape)
    
        
        print(sm)
        #print(sm)

        # and the target y
        loss = criterion(sm,y.float())
        loss.backward()

        # get the index of the max log-probability
        #pred = sm.argmax(keepdim=True).float()

        #print(labels[i])

        #if pred == orig_labels[i]: ## 
         #   correct_pred += 1
        
        print("y =", y)
        print("sm =", sm)

        optim.step()

        optim.zero_grad()

    #print(correct_pred/len(paths))



0
tensor([ 6.2697,  2.9502, -2.7518], grad_fn=<SumBackward1>)
torch.Size([3])
tensor([9.6498e-01, 3.4903e-02, 1.1656e-04], grad_fn=<SoftmaxBackward0>)


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [214]:
input = torch.randn(3, 5, requires_grad=True)
print(input)

tensor([[ 0.0201, -0.0215,  0.2926, -1.5066, -1.9511],
        [ 0.3444, -0.8287, -0.5833, -0.1126, -1.2810],
        [ 0.6459,  1.0271,  0.3656, -0.3480, -0.5376]], requires_grad=True)


### Graveyard 

In [56]:
## Evaluate one cochain on a simplex 

def cochain_eval(cochain,simplex):

    """ Evaluate a k-cochains on a k-simplex """
    
    # check size of simplex is compatible with cochain
    assert simplex.shape == ((k+1),m), "dimension of simplex and cochain does not match"
    simplex.reshape(1,-1)
    simplex = simplex.reshape(1,-1)
    out = torch.zeros(1)
    out = cochain(simplex)[0]
    print(out.shape)
    return out 


simp = torch.tensor([(1,0),(0,1)]).float()
res = cochain_eval(f,simp)
print(res.shape)
print(res)


torch.Size([1])
torch.Size([1])
tensor([0.0332], grad_fn=<SelectBackward0>)


In [None]:
def path2sc(path):

    sc = torch.zeros(len(path)-1, 4)
    for i in range(len(path)):
        sc[i,0] = path[i]
        sc[i,1] = path[i+1]



    return sc