# Graph matrix subspace checking

In [3]:
import cvxpy as cp
import multiprocessing as mp
from tqdm import tqdm 
from tqdm.contrib.itertools import product
import numpy as np
import os
import datetime

from graph_utils import *

In [155]:
def get_rank_and_nullspace(A):
    """Return the numerical rank of ``A`` and an orthonormal basis of its null space.

    The rank is the number of singular values strictly greater than
    ``S.max() * max(M, N) * eps``. This is the same default tolerance used by
    MATLAB's ``rank`` and by ``numpy.linalg.matrix_rank``.

    Parameters
    ----------
    A : array_like, shape (M, N)
        Matrix to analyse. Non-floating input (int/bool) is converted to
        float64 first, because ``np.finfo`` is undefined for integer dtypes.

    Returns
    -------
    rank : int
        Numerical rank of ``A``.
    nullspace : ndarray, shape (N, N - rank)
        Its columns form an orthonormal basis of the null space of ``A``.
    """
    A = np.asarray(A)
    if not np.issubdtype(A.dtype, np.inexact):
        A = A.astype(np.float64)

    # full_matrices=True (the default) makes vh square (N x N), so its trailing
    # rows span the null space even when M < N.
    _, s, vh = np.linalg.svd(A)

    # This choice of tolerance is the one used by MATLAB and NumPy:
    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.matrix_rank.html
    s_max = s.max() if s.size else 0.0
    tol = s_max * max(A.shape) * np.finfo(A.dtype).eps
    # Strict ">" as in matrix_rank. With ">=" a zero matrix gives tol == 0, and
    # every (zero) singular value would be counted, reporting full rank.
    rank = int((s > tol).sum())

    return rank, vh[rank:].T

In [169]:
def a_is_subspace_of_b(basis_a, basis_b):
    """Check whether the column span of ``basis_a`` lies inside that of ``basis_b``.

    Both matrices must have the same number of rows. Appending the columns of
    ``basis_a`` to ``basis_b`` leaves the numerical rank unchanged exactly when
    every column of ``basis_a`` is already in the column space of ``basis_b``.
    """
    rank_b = get_rank_and_nullspace(basis_b)[0]
    stacked = np.hstack([basis_a, basis_b])
    rank_stacked = get_rank_and_nullspace(stacked)[0]
    return rank_b == rank_stacked

In [174]:
def find_chain_starts(basis_set):
    """Select entries whose null space is not nested in a previously kept one.

    Parameters
    ----------
    basis_set : list of [index, rank, nullspace]
        ``rank`` is the rank of a constraint matrix (sort key). ``nullspace``
        is a matrix whose columns span that matrix's null space.

    Returns
    -------
    list
        The kept entries, in visiting order. Entries are visited in ascending
        ``rank``. When all constraint matrices have the same number of columns,
        this means descending null-space dimension, so larger spaces are kept
        first. An entry is kept only if its null space is not a subspace of any
        already-kept entry's null space. Empty input gives an empty list.
    """
    # sorted() is stable and returns a copy, so the caller's list is untouched.
    ordered = sorted(basis_set, key=lambda entry: entry[1])

    chain_starts = []
    for entry in ordered:
        # The first entry is always kept: any() over an empty list is False.
        if not any(a_is_subspace_of_b(entry[2], start[2]) for start in chain_starts):
            chain_starts.append(entry)

    return chain_starts

In [193]:
# Load the graph collection (get_graphs presumably comes from the
# `from graph_utils import *` above) and exclude the 'cycle' graph from the
# analysis that follows.
# NOTE(review): if get_graphs() returns a dict, `del` raises KeyError when
# there is no 'cycle' entry — confirm the key is always present.
ALL_GRAPHS = get_graphs()
del ALL_GRAPHS['cycle']

In [182]:
# Exploratory run on the single 'cell' graph: compute find_stable_sets(g),
# then the base equality system from build_triple_equalities.
# NOTE(review): g.shape[0] is presumably the vertex count (g looks like a
# matrix) — confirm. The per-graph loop below rebinds g, I, A, var_matr and
# num_params.
g = ALL_GRAPHS['cell']
I = find_stable_sets(g)
A, var_matr, num_params = build_triple_equalities(I, g.shape[0])

In [None]:
# For every graph:
#   1. Build the base equality system A.
#   2. For each system yielded by build_variance_equalities_iterator
#      (only_nonequivalent=True), compute the rank and null space of A
#      extended with that system.
#   3. Report how many chain starts find_chain_starts selects, and the
#      null-space dimension of each.
for name, g in ALL_GRAPHS.items():

    print(name)
    I = find_stable_sets(g)
    A, var_matr, num_params = build_triple_equalities(I, g.shape[0])

    basis_set = []
    for i, addA in enumerate(build_variance_equalities_iterator(I, var_matr, num_params, only_nonequivalent=True)):
        # NOTE(review): assumes A and addA[0] are lists of rows, so "+"
        # concatenates them rather than adding element-wise — confirm.
        basis_set.append([i, *get_rank_and_nullspace(np.array(A + addA[0], dtype=np.float64))])

    chain_starts = find_chain_starts(basis_set)
    # The null-space dimension is the number of basis columns (shape[1]).
    # Measuring the first row's length would fail for a zero-row basis.
    print(len(chain_starts), "dimensions:", ",".join(str(start[2].shape[1]) for start in chain_starts))

dupl
1 dimensions: 8
zigzag
1 dimensions: 6
fork
1 dimensions: 7
big_zig
1 dimensions: 7
big_triag
1 dimensions: 7
ya_big_triag
1 dimensions: 7
ya_big_triag_2
1 dimensions: 7
tang_triag
1 dimensions: 7
square
1 dimensions: 8
center
1 dimensions: 6
ship


In [181]:
# Number of chain starts for the most recently built `basis_set` (a list of
# [index, rank, nullspace] entries). The name `res` formerly used here is not
# defined in any visible cell, so a fresh top-to-bottom run raised NameError.
len(find_chain_starts(basis_set))

1