In [9]:
from bggcomplex import BGGComplex
from fast_module import FastLieAlgebraCompositeModule, FastModuleFactory, BGGCohomology
import numpy as np
%load_ext cython

In [2]:
# Root system of type B3 and the fast-module factory built on its Lie algebra.
d = 'B3'
BGG = BGGComplex(d)
f_factory = FastModuleFactory(BGG.LA)

# Building blocks of the composite module, keyed by the names used in
# `components` below.
# NOTE(review): key 'u' is built from the 'g' component -- confirm intended.
component_dic = {
    'b': f_factory.build_component('b', 'ad'),
    'u': f_factory.build_component('g', 'ad'),
}

# One composite: (component key, number of factor columns, 'sym' | 'wedge').
components = [
    [('b', 2, 'sym'), ('u', 2, 'wedge')],
]

module = FastLieAlgebraCompositeModule(f_factory.weight_dic, components, component_dic)

In [3]:
def get_action_tensor(action_mat, dim_n=None):
    """Pack a sparse action table into a dense "linked-list" int64 tensor.

    ``action_mat`` maps ``(i, j) -> {k: C_ijk}``: acting element ``i`` sends
    basis index ``j`` to the terms ``C_ijk * e_k``.

    The result ``T`` has shape ``(rows, n_cols, 3)``; every used entry is a
    triple ``T[r, j] = (link, k, C_ijk)`` holding one term:

    * ``T[i, j]`` holds the first term of ``i . e_j``;
    * ``link == -1`` marks the last term of the chain;
    * an all-zero entry (``link == 0``) means ``i . e_j`` has no terms;
    * any other ``link`` is the row of ``T[:, j]`` holding the next term.

    Overflow rows (used for the 2nd, 3rd, ... terms) start at
    ``max(dim_n, max_i + 1)`` so they can never alias a real acting index.
    ``n_cols`` covers every index that appears as a source ``j`` *or* a target
    ``k``, because targets become sources when actions are chained.

    :param action_mat: dict ``{(i, j): {k: C_ijk}}``; an empty inner dict is
        treated as the zero action and skipped.
    :param dim_n: number of acting elements. Defaults to
        ``len(f_factory.basis['n'])`` using the notebook-global ``f_factory``.
    :returns: ``np.int64`` array of shape ``(rows, n_cols, 3)``.
    """
    if dim_n is None:
        dim_n = len(f_factory.basis['n'])

    acting = [i for i, _ in action_mat]
    indices = [j for _, j in action_mat]
    indices += [k for v in action_mat.values() for k in v]
    n_cols = max(indices + [0]) + 1
    # Overflow rows must not collide with a row addressed by an acting index.
    first_extra = max([dim_n] + [i + 1 for i in acting])

    # An entry with l > 1 terms needs exactly l - 1 overflow rows in column j.
    extra_rows = [0] * n_cols
    for (_, j), v in action_mat.items():
        extra_rows[j] += max(len(v) - 1, 0)
    n_extra_rows = max(extra_rows)

    action_tensor = np.zeros((first_extra + n_extra_rows, n_cols, 3), np.int64)
    next_free = [first_extra] * n_cols  # next unused overflow row, per column
    for (i, j), v in action_mat.items():
        terms = list(v.items())  # list() so this also works on Python 3
        row = i
        for pos, (k, C_ijk) in enumerate(terms):
            if pos == len(terms) - 1:
                link = -1  # last term of this chain
            else:
                # Only allocate a row when a further term will actually use it.
                link = next_free[j]
                next_free[j] += 1
            action_tensor[row, j] = (link, k, C_ijk)
            row = link
    return action_tensor

In [4]:
# Precompute the action tensor of every module component once and attach the
# mapping to `module`; compute_action looks tensors up there by component type.
action_tensor_dic = {
    key: get_action_tensor(mod.action)
    for key, mod in module.component_dic.items()
}
module.action_tensor_dic = action_tensor_dic

In [95]:
%%cython
#cython: language_level = 2 
import numpy as np
cimport numpy as np

cdef compute_action(acting_element, action_source, module):
    """Act with one basis element on every row of ``action_source``.

    The first columns of each row hold one basis index per factor listed in
    ``module.type_lists[0]``; the last column is the coefficient, and any other
    columns are copied through unchanged.

    For every factor column ``col`` and every row, the terms of the action on
    that column's index ``j`` are read from
    ``module.action_tensor_dic[type][acting_element, j]`` as triples
    ``(link, k, C)``: ``link == 0`` means no term, ``link == -1`` marks the
    last term, and any other ``link`` is the tensor row holding the next term.
    Each term yields one output row equal to the input row with ``col`` set to
    ``k`` and the coefficient multiplied by ``C``.

    Returns a new array holding exactly the produced rows, in order.
    """
    cdef int row, j
    cdef Py_ssize_t n_written = 0
    cdef Py_ssize_t capacity

    image = np.zeros_like(action_source)
    capacity = len(image)

    for col, mod_type in enumerate(module.type_lists[0]):
        tensor = module.action_tensor_dic[mod_type]
        for row in range(len(action_source)):
            j = action_source[row, col]
            link, k, Cijk = tensor[acting_element, j]
            while link != 0:
                term_row = action_source[row].copy()
                term_row[col] = k
                term_row[-1] *= Cijk
                image[n_written] = term_row
                n_written += 1
                if n_written >= capacity:
                    # Out of room: double the buffer.
                    image = np.concatenate([image, np.zeros_like(image)])
                    capacity = len(image)
                if link == -1:
                    break
                link, k, Cijk = tensor[link, j]
    return image[:n_written]

cdef check_equal(long [:] row1,long [:] row2,int num_cols):
    """Return True when the first ``num_cols`` entries of both rows agree."""
    cdef int idx = 0
    while idx < num_cols:
        if row1[idx] != row2[idx]:
            return False
        idx += 1
    return True
    
cdef col_nonzero(long [:] col, int num_rows):
    """Return the positions ``i < num_rows`` where ``col[i]`` is non-zero.

    Hand-rolled because np.nonzero does not work well with typed memoryviews.
    The result is an ``np.int32`` array in increasing order.
    """
    cdef int i
    cdef int count = 0
    positions = np.zeros(num_rows, np.int32)
    for i in range(num_rows):
        if col[i] == 0:
            continue
        positions[count] = i
        count += 1
    return positions[:count]

cdef merge_sorted_image(long [:,:] action_image):
    """Combine adjacent duplicate rows of a sorted action image.

    Rows are compared on every column except the last (the coefficient).
    Consecutive equal rows have their coefficients summed. Rows whose input
    coefficient is zero, or whose merged coefficient cancels to zero, are
    dropped.

    :param action_image: 2D integer memoryview, sorted so that rows equal on
        all but the last column are adjacent (sort_merge lexsorts first).
    :returns: a new ndarray containing the merged rows with non-zero
        coefficient; it has zero rows when the input is empty or all terms cancel.
    """
    merged_image = np.zeros_like(action_image)

    cdef long[:] old_row
    cdef long[:] row

    cdef int row_number = -1  # index of the last row written to merged_image
    cdef int num_cols = action_image.shape[1]-1  # compare all but the coefficient
    cdef int num_rows = action_image.shape[0]
    cdef int n_merged
    cdef int i

    if num_rows == 0:
        # Nothing to merge; also avoids indexing action_image[0] below.
        return merged_image

    # All -1 sentinel; assumes real index columns are non-negative.
    old_row = np.zeros_like(action_image[0])-1

    for i in range(num_rows):
        row = action_image[i]
        if row[-1]!=0:
            if check_equal(row,old_row,num_cols):
                merged_image[row_number,-1] += row[-1]
            else:
                row_number+=1
                merged_image[row_number]=row
                old_row = row
    # row_number is an index, so row_number + 1 rows were written. Slicing with
    # [:row_number] would drop the last merged row, and would crash when
    # every coefficient is zero (row_number == -1).
    n_merged = row_number + 1
    non_zero_inds = col_nonzero(merged_image[:n_merged,-1],n_merged)
    return merged_image[non_zero_inds,:]
    
def sort_merge(action_image):
    """Sort rows so duplicates (ignoring the coefficient column) are adjacent,
    then merge them with merge_sorted_image."""
    sort_keys = action_image[:, :-1].T
    order = np.lexsort(sort_keys)
    return merge_sorted_image(action_image[order])

cdef permutation_sign(long [:] row,int num_cols):
    """Sign of the permutation that sorts ``row[:num_cols]`` ascending.

    Returns 0 if any entry repeats, otherwise +1 or -1 according to the parity
    of the number of inversions.
    """
    cdef int a, b
    cdef int inversions = 0
    for a in range(num_cols):
        for b in range(a + 1, num_cols):
            if row[a] == row[b]:
                return 0
            if row[a] > row[b]:
                inversions += 1
    return -1 if inversions % 2 else 1

cdef sort_cols(module, action_image):
    """Put each tensor-factor column block of every row in sorted order, in place.

    ``module.components[0]`` lists ``(name, width, type)`` per factor; the
    factors occupy consecutive column blocks of ``action_image``. For a
    ``'wedge'`` block, the row's coefficient (last column) is first multiplied
    by the sign of the sorting permutation (0 if an index repeats). Every block
    is then sorted ascending within each row. Returns None.

    NOTE(review): blocks of types other than 'wedge' are sorted without a sign
    change (as suits 'sym'); confirm this is intended for any other types.
    """
    cdef int num_rows = len(action_image)
    cdef int start = 0
    cdef int i
    cdef long[:] block

    for _, width, mod_type in module.components[0]:
        stop = start + width
        if mod_type == 'wedge':
            for i in range(num_rows):
                block = action_image[i, start:stop]
                action_image[i, -1] *= permutation_sign(block, width)
        action_image[:, start:stop] = np.sort(action_image[:, start:stop])
        start = stop
        
def action_on_basis(pbw_elt,wmbase,module,factory):
    """Apply a PBW element to every basis vector listed in ``wmbase``.

    Each row of ``wmbase`` holds the factor indices of one basis vector. The
    returned int64 array has one row per surviving term, with this layout:

    * columns ``[:num_cols]`` are the factor indices of the image term;
    * column ``num_cols`` is the row of ``wmbase`` the term came from;
    * the last column is the coefficient.

    Factor blocks are normalised by sort_cols, and duplicate terms are merged,
    with zero terms dropped, by sort_merge.

    :param pbw_elt: element exposing ``monomial_coefficients()``, whose
        monomials expose ``to_word_list()``.
    :param factory: object whose ``root_to_index`` maps each word letter to
        the acting-element index used by compute_action.
    """
    num_cols = wmbase.shape[1]
    action_source = np.zeros((wmbase.shape[0], num_cols+2),np.int64)
    action_source[:,:num_cols] = wmbase
    action_source[:,num_cols] = np.arange(len(wmbase))
    action_source[:,-1] = 1

    action_list = []
    for monomial,coefficient in pbw_elt.monomial_coefficients().items():
        action_image = action_source.copy()
        action_image[:,-1]*=coefficient
        # The rightmost letter of the monomial acts first.
        for term in monomial.to_word_list()[::-1]:
            index = factory.root_to_index[term]
            action_image = compute_action(index, action_image,module)
        action_list.append(action_image)

    if not action_list:
        # Zero PBW element: np.concatenate([]) would raise ValueError.
        return np.zeros((0, num_cols+2), np.int64)
    action_image = np.concatenate(action_list)
    if len(action_image) == 0:
        # Every term vanished; nothing to normalise or merge.
        return action_image
    sort_cols(module,action_image)
    return sort_merge(action_image)

In [97]:
# Time action_on_basis for one PBW element from the BGG maps at the zero root,
# acting on a basis drawn from module.weight_components.
# NOTE(review): the ('1','21') key, the (0,0,0) weight and the [-1][-1]
# indexing are hard-coded selections -- confirm what each one picks out.
pbw_elt=BGG.compute_maps(BGG.zero_root)[('1','21')]
wmbase = module.weight_components[(0,0,0)][-1][-1]
print(pbw_elt)
%timeit action_on_basis(pbw_elt,wmbase,module,f_factory)

PBW[-alpha[2]]^2
100 loops, best of 3: 17.5 ms per loop
