In [129]:
import jax
import numpy as onp
import jax.numpy as jnp

In [136]:
class nvector:
    def __init__(self, x):
        self.x = x
    
    def __str__(self):
        return str(f'{len(self.x)}-vector:\n{onp.asarray(self.x)}')

    def __len__(self):
        '''Redefines builtin len() function to return n for an n-vector'''
        return len(self.x)
    
    def __xor__(self, other):
        '''This method redefines xor ^ operator to the wedge operator 
        for n-vectors'''
        return nvector(self.x + other.x)
    
class nform:
    def __init__(self, x):
        self.x = x

    def __str__(self):
        return str(f'{len(self.x)}-form:\n{onp.asarray(self.x)}')
    
    def __len__(self):
        '''Redefines builtin len() function to return n for an n-form'''
        return len(self.x)

    def __xor__(self, other):
        print()
        '''This method redefines xor ^ operator to the wedge operator 
        for n-form'''
        return nform(self.x + other.x)

    def __call__(self, other):
        '''Makes forms callable to implement the contraction operation α(u)'''
        n, m = len(self.x), len(other.x)
        if n != m:
            raise NotImplementedError('Partial Contraction has not been implemented yet')
        elif n in [1,2,3]:
            C = jnp.array([[a.dot(u) for u in other.x] for a in self.x])
            return jnp.linalg.det(C)
        else:
            raise ValueError('That type of contraction is invalid')
    
def hodge_star(a):
    if not isinstance(a, nvector):
        raise ValueError('# (sharp) operator only defined for nvectors')
    elif len(a.x) == 0:
        return vector(1,0,0) ^ vector(0,1,0) ^ vector(0,0,1)
    elif len(a.x) == 1:
        tmp = jnp.array([a.x[2], a.x[0], a.x[1]])
        b = jnp.cross(a.x[0], tmp)
        c = jnp.cross(tmp, b)/b.dot(b)
        return nvector([b, c])
    elif len(a.x) == 2:  
        return nvector([jnp.cross(a.x[0], a.x[0])])
    elif len(a.x) == 3:
        return nvector([])
def sharp(a):
    if not isinstance(a, nvector):
        raise ValueError('# (sharp) operator only defined for nvectors')
    return nform(a.x)

def flat(a):
    if not isinstance(a, nform):
        raise ValueError('# (sharp) operator only defined for nvectors')
    return nvector(a.x)

In [137]:
vector = lambda a, b, c: nvector([jnp.array([a, b, c])])
form = lambda a, b, c: nform([jnp.array([a, b, c])])

In [138]:
α = form(1,0,0)
β = form(0,1,0)
γ = form(0,0,1)
u = vector(1,0,0)
v = vector(0,1,0)
w = vector(1,0,1)

uvw = u^v^w
αβγ = α^β^γ
print(αβγ(uvw))

vuw = v^u^w
print(αβγ(uvw))

print(u^v)
print(v^u)



1.0
1.0
2-vector:
[[1 0 0]
 [0 1 0]]
2-vector:
[[0 1 0]
 [1 0 0]]
