In [1]:
import time

from dns.dnssec import validate
from ramanujantools.cmf import pFq
import sympy as sp
from sympy.solvers import linsolve
from typing import List, Dict, FrozenSet, Set

from sympy.unify import unify


In [2]:
pi = pFq(2, 1, sp.Rational(1, 2))
pi

pFq((2, 1, 1/2, True, True))

In [3]:
def zero_det_solve(mat: sp.Matrix) -> Set[FrozenSet]:
    return freeze(sp.solve(mat.det()))

def undefined_solve(mat: sp.Matrix) -> (list, list):
    l = []
    for v in mat.iter_values():
        if (den := v.as_numer_denom()[1]) == 1:
            continue
        for sym in den.free_symbols:
            for sol in sp.solve(den, sym):
                l.append({sym: sol})
    return freeze(l)

def freeze(l):
    sols = set()
    for sol in l:
        sols.add(tuple(set(sol.items()))[0])
    return sols

def unfreeze(sols):
    clean = [set(sol) for sol in sols]
    solutions = set(tuple(tup) for tup in clean)
    return list(solutions)

shard_data = {}
data = set()
for sym, mat in pi.matrices.items():
    z_det = zero_det_solve(mat)
    undef = undefined_solve(mat)
    uni = z_det.union(undef)
    # print(undef)
    data = data.union(uni)
    # shard_data[sym] = {'0_det': unfreeze(z_det), 'undef': unfreeze(undef) , 'unified': unfreeze(uni)}
print(data)
data = unfreeze(data)
data


{(x1, 0), (x0, y0), (x1, y0), (y0, 0), (y0, x1), (y0, x0), (x0, y0 - 1), (x1, y0 - 1), (x0, 0)}


[(0, y0), (0, x1), (y0, x1), (0, x0), (y0, x0), (x0, y0 - 1), (x1, y0 - 1)]

In [4]:
def __extract_shards_data(cmf):
    def solve_shards(mat: sp.Matrix):
        l = []
        for v in mat.iter_values():
            if (den := v.as_numer_denom()[1]) == 1:
                continue
            l += [{sym: sol} for sym in den.free_symbols for sol in sp.solve(den, sym)]
            # for sym in den.free_symbols:
            #     for sol in sp.solve(den, sym):
            #         l.append({sym: sol})
        return freeze(l), freeze(sp.solve(mat.det()))

    def freeze(l):
        return set(tuple(set(sol.items()))[0] for sol in l)

    def unfreeze(sols):
        clean = [set(sol) for sol in sols]
        solutions = set(tuple(tup) for tup in clean)
        return list(solutions)  # type checker shouts, but this is correct!

    data = set()
    for mat in cmf.matrices.values():
        undef, z_det = solve_shards(mat)
        data = data.union(z_det.union(undef))
    return unfreeze(data), list(cmf.matrices.keys())

In [5]:
__extract_shards_data(pi)[0]

[(0, y0), (0, x1), (y0, x1), (0, x0), (y0, x0), (x0, y0 - 1), (x1, y0 - 1)]

In [7]:
def format_expression(l):
    return [exp2 - exp1 for exp1, exp2 in l]

In [8]:
format_expression(__extract_shards_data(pFq(3, 2, 1))[0])

ERROR! Session/line number was not unique in database. History logging moved to new session 248


[x2,
 y0,
 -x0 - x1 - x2 + y0 + y1 - 2,
 x0 - y0,
 x2 - y0 + 1,
 x1,
 x0 + x1 + x2 - y0 - y1 + 2,
 x1 - y1,
 x0 + x1 + x2 - y0 - y1 + 2,
 -x0 + y1 - 1,
 x2 - y0,
 x1 - y1 + 1,
 -x1 + y0 - 1,
 x0,
 x0 + x1 + x2 - y0 - y1 + 2,
 x0 - y1,
 x2 - y1 + 1,
 -x0 + y0 - 1,
 y1,
 x0 + x1 + x2 - y0 - y1 + 2,
 x0 + x1 + x2 - y0 - y1 + 1,
 x1 - y0,
 x2 - y1]

In [9]:
from utils.util_types import *
from ramanujantools.cmf import CMF
from functools import lru_cache
from analysis_stage.searchable import Searchable
from utils.util_types import *


class Shard(Searchable):
    def __init__(self,
                 hps: List[sp.Expr],
                 symbols: List[sp.Symbol],
                 shard_id: Tuple[int | sp.Rational, ...]):
        self.hps = hps
        self.symbols = symbols
        self.shard_id = shard_id

    def __eq__(self, other) -> bool:
        if isinstance(other, Shard):
            return self.shard_id == other.shard_id
        raise NotImplementedError

    def __hash__(self):
        return hash(self.shard_id)

    def in_shard(self, point: Tuple[int | sp.Rational, ...]) -> bool:
        point = {sym: v for sym, v in zip(self.symbols, point)}
        return all(exp.subs(point) == indicator for exp, indicator in zip(self.hps, self.shard_id))

    def __repr__(self):
        return str(self.shard_id)

from itertools import product
from scipy.optimize import linprog
import numpy as np
import time

class ShardExtractor:
    cache: bool = True

    def __init__(self, cmf: CMF, shifts: List[Shift]):
        self.cmf = cmf
        self.shifts = shifts
        self.hps, self.symbols = self.extract_shard_hyperplanes(cmf)

    @staticmethod
    def extract_shard_hyperplanes(cmf: CMF) -> Tuple[List[sp.Expr], List[sp.Symbol]]:
        def solve_shards(mat: sp.Matrix) -> Tuple[Set[EqTup], Set[EqTup]]:
            l = []
            for v in mat.iter_values():
                if (den := v.as_numer_denom()[1]) == 1:
                    continue
                l += [{sym: sol} for sym in den.free_symbols for sol in sp.solve(den, sym)]
            return freeze(l), freeze(sp.solve(mat.det()))

        def freeze(l) -> Set[EqTup]:
            return set(tuple(set(sol.items()))[0] for sol in l)

        def unfreeze(sols: Set[EqTup]) -> List[EqTup]:
            clean = [set(sol) for sol in sols]
            solutions = set(tuple(tup) for tup in clean)
            return list(solutions)  # type checker shouts, but this is correct!

        data = set()
        for mat in cmf.matrices.values():
            undef, z_det = solve_shards(mat)
            data = data.union(z_det.union(undef))
        data = unfreeze(data)
        return [exp1 - exp2 for exp1, exp2 in data], list(cmf.matrices.keys())

    def get_shards(self) -> List[Tuple[int, ...]]:
        @lru_cache(maxsize=128 if self.cache else 0)
        def expr_to_ineq(expr, greater_than_0: bool = True):
            eps = 1e-5
            coeffs = expr.as_coefficients_dict()
            sign = 1 if greater_than_0 else -1
            row = [-sign * coeffs.get(v, 0) for v in self.symbols]
            b = sign * coeffs.get(1, 0) - eps
            return row, b

        def validate_shard(shard: Tuple[int, ...]):
            A, b = [], []
            for ineq, indicator in zip(self.hps, shard):
                row, rhs = expr_to_ineq(ineq, indicator == 1)
                A.append(row)
                b.append(rhs)
            # TODO: I am not sure if we DO ignore false shards we can get by linprog() - i.e. just intersections of many hyperplanes
            return linprog(c=list(np.zeros(len(self.symbols))), A_ub=A, b_ub=b, method="highs").success

        res = [permutation for permutation in product([+1, -1], repeat=len(self.hps)) if validate_shard(permutation)]
        expr_to_ineq.cache_clear()
        return res

    def encode_point(self, point: Tuple[int | sp.Rational, ...]) -> Tuple[int, ...]:
        # TODO: notice that the order of symbols matter to the interpetation of the point!
        #  this is importent in maybe other few cases!
        point = {sym: val for sym, val in zip(self.symbols, point)}
        return tuple((1 if exp.subs(point) > 0 else -1) for exp in self.hps)



In [10]:
# ShardExtractor.extract_shard_hyperplanes(pi)
extractor = ShardExtractor(pi, [None])
extractor.get_shards()

[(1, 1, 1, -1, -1, 1, -1),
 (1, 1, 1, -1, -1, -1, -1),
 (1, 1, -1, -1, -1, 1, -1),
 (1, 1, -1, -1, -1, -1, -1),
 (1, -1, 1, -1, -1, 1, -1),
 (1, -1, -1, -1, -1, 1, -1),
 (-1, 1, 1, -1, -1, 1, -1),
 (-1, 1, 1, -1, -1, -1, -1),
 (-1, -1, 1, -1, -1, 1, -1)]

In [12]:
pFq(3,2,1).matrices.keys()

dict_keys([x0, x1, x2, y0, y1])