<a href="https://colab.research.google.com/github/andylehti/Polyhedral-Index-Partition/blob/main/Polyhedral_Index_Partition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
from functools import lru_cache
from math import log10, ceil
import math
from mpmath import mp
import mpmath as mpm
sys.set_int_max_str_digits(0)
sys.setrecursionlimit(10000)

@lru_cache(maxsize=1 << 16)
def polyHedralIndex(x, y):
    """Return C(x + y - 1, y), or 0 when x < 1 (or y < 0).

    For y >= 1 this equals the number of y-tuples of non-negative integers
    whose component sum is less than x. Equivalently, it is the index of the
    first y-tuple whose sum is exactly x.

    The value is computed with exact integer arithmetic (math.comb). The
    previous mp.binomial version depended on the global mp.dps at the moment
    of the first call, and lru_cache then served that possibly-rounded value
    to every later caller, whatever precision they had set.

    Args:
        x: component-sum bound. Any integer-valued number is accepted (an int,
            or an integral mpmath mpf); it is converted with int().
        y: tuple width.

    Returns:
        A Python int.
    """
    if x < 1 or y < 0:
        return 0
    x, y = int(x), int(y)
    return math.comb(x + y - 1, y)

def approximate(n, t, x=1):
    """Return the largest s >= 0 with polyHedralIndex(s, t) <= n.

    polyHedralIndex(s, t) is strictly increasing in s for t >= 1. The result
    is therefore the component sum of the t-tuple at index n.

    Args:
        n: target index. For n < 0 the result is 0.
        t: tuple width. Must be >= 1: for t == 0 the index stops growing
            with s, and the search would not terminate when n > 1.
        x: optional starting guess for the upper bound. Values < 1 are
            treated as 1. (Previously x <= 0 looped forever, and a large
            starting guess could make the bisection start above the answer
            and return a wrong value.)

    Returns:
        A Python int.
    """
    lo, hi = 0, max(int(x), 1)
    # Exponential search. Invariant: polyHedralIndex(lo, t) < n, or lo == 0.
    # Stops once polyHedralIndex(hi, t) >= n, so the answer lies in [lo, hi].
    while polyHedralIndex(hi, t) < n:
        lo, hi = hi, hi * 10
    # Upper-biased bisection for the largest s in [lo, hi] with index <= n.
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if polyHedralIndex(mid, t) <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo

def getPair(k):
    """Decode index ``k`` into the ordered pair ``(a, b)`` of non-negative ints.

    Pairs are listed by increasing sum s = a + b. Within one sum, b runs
    from 0 up to s. So s is the largest integer with s*(s+1)/2 <= k, and
    b = k - s*(s+1)/2. This is the inverse of ``pairInverse``.

    The previous version converted ``k`` to an mpmath mpf and took a float
    square root, so it was only exact when the global mp.dps happened to be
    large enough. This version uses the exact integer square root instead.

    Args:
        k: non-negative integer index (an integral mpf is also accepted).

    Returns:
        A tuple ``(a, b)`` of Python ints.

    Raises:
        ValueError: if ``k`` is negative.
    """
    k = int(k)
    s = (math.isqrt(8 * k + 1) - 1) // 2
    b = k - s * (s + 1) // 2  # position within the diagonal whose sum is s
    return s - b, b

def getPartition(n, y):
    """Split index ``n`` into a ``y``-tuple of non-negative integers.

    The work proceeds from the outermost width inwards. For each width
    ``w`` from ``y`` down to 3:
      * find the largest total ``s`` with ``polyHedralIndex(s, w) <= rest``;
      * subtract that offset from ``rest``.
    The final remainder is decoded as a pair by ``getPair``. The tuple is
    then assembled right to left: each leading component is its level's
    total minus the sum of the components already built.

    Returns ``[n]`` (a list) when ``y < 2``; otherwise a tuple.
    """
    if y < 2:
        return [n]
    level_totals = []
    width, rest = y, n
    while width > 2:
        level_total = approximate(rest, width)
        level_totals.append(level_total)
        rest = int(rest - polyHedralIndex(level_total, width))
        width -= 1
    parts = getPair(rest)
    for level_total in reversed(level_totals):
        parts = (int(level_total - sum(parts)),) + parts
    return parts

def partition(n, y=2):
    """Return index ``n`` decomposed into a list of ``y`` non-negative integers.

    This is intended as the inverse of ``getInverse``: the notebook's checks
    expect ``getInverse(*partition(n, y)) == n``.

    Side effect: sets the global mpmath precision ``mpm.mp.dps`` to
    ``100 + 2 * len(str(n))``, so that mpmath-based helpers have ample digits
    for indices of this size. (An earlier, immediately overwritten
    assignment of 100 was dead code and has been removed.)
    """
    mpm.mp.dps = 100 + len(str(n)) * 2
    return list(getPartition(n, y))

def pairInverse(a, b):
    """Return the index of the ordered pair ``(a, b)``; the inverse of ``getPair``.

    All pairs with sum below n = a + b come first; there are n*(n+1)/2 of
    them. Within sum n, the pair's offset is b.

    The formula is evaluated with exact integer arithmetic. The previous mpf
    version silently rounded large results whenever the global mp.dps was
    too small (e.g. at the default 15 digits).

    Args:
        a, b: non-negative integers (integral floats or mpf are converted
            with int()).

    Returns:
        A Python int.
    """
    a, b = int(a), int(b)
    n = a + b
    return n * (n + 1) // 2 + b

def getInverse(*a):
    """Return the index whose partition is the tuple ``a`` (inverse of ``partition``).

    A single value is returned unchanged. Two or more values are delegated
    to ``inversePartition``.

    Side effect: sets the global mpmath precision ``mpm.mp.dps`` so that
    mpmath-based helpers can represent the result exactly.

    Raises:
        IndexError: when called with no arguments.
    """
    a = list(a)
    digits = len(''.join(map(str, a)))
    total, width = sum(a), len(a)
    # The index is below C(total + width, width) = C(total + width, k), with
    # k = min(total, width). That bound has at most k * len(str(total + width))
    # decimal digits. The old heuristic (2 * digits of the input) undershoots:
    # for (10**10, 0, 0) it gave 26 digits for a 30-digit index.
    needed = min(total, width) * len(str(total + width))
    mpm.mp.dps = max(digits * 2, needed + 10)
    if len(a) <= 1:
        return a[-1]
    return inversePartition(*a)

def inversePartition(*c):
    """Return the index of the tuple ``c``.

    The definition is recursive:
        inversePartition(c0, *rest)
            = polyHedralIndex(sum(c), len(c)) + inversePartition(*rest)
    with ``pairInverse`` as the two-element base case. Here it is evaluated
    iteratively from the right, in the same call order, so long tuples do
    not hit the recursion limit. A single value is returned unchanged,
    matching ``getInverse``. (Previously, one or zero arguments recursed
    until RecursionError.)

    Raises:
        ValueError: when called with no arguments.
    """
    if not c:
        raise ValueError("inversePartition() requires at least one value")
    if len(c) == 1:
        return c[0]
    index = pairInverse(c[-2], c[-1])
    for width in range(3, len(c) + 1):
        # Add the offset of the suffix of length `width`.
        index = int(polyHedralIndex(sum(c[-width:]), width) + index)
    return index

In [None]:
p = 320  # change for the number of partitions (tuple width)
n = 123456789  # change for the index

a = partition(n, p)  # index -> p-tuple
r = getInverse(*a)  # p-tuple -> index; expected to round-trip back to n
for line in (a, r, r == n):
    print(line)

In [None]:
def calculateIndices(n, y):
    """Summarize where index ``n`` falls once split into ``y`` parts.

    Keys of the returned dict:
      * integer_partitions: ``partition(n, y)``.
      * cumulative_index: ``n`` itself.
      * integer_index: the sum of the parts.
      * set_index: the number of parts.
      * partition_index_start: the index of ``(total, 0, ..., 0)``.
      * partition_index_end: the index of ``(0, ..., 0, total)``.
      * partition_index: ``n`` minus partition_index_start.

    The start and end values are presumably the first and last indices
    sharing this component sum.
    """
    parts = partition(n, y)
    total = sum(parts)
    zeros = [0] * (len(parts) - 1)
    first = getInverse(total, *zeros)
    last = getInverse(*zeros, total)
    return {
        "integer_partitions": parts,
        "cumulative_index": n,
        "integer_index": total,
        "set_index": len(parts),
        "partition_index_start": first,
        "partition_index_end": last,
        "partition_index": n - first,
    }

# NOTE(review): 123456879 differs from the 123456789 used in the earlier cell
# (digits 7 and 8 swapped) — confirm this is intended.
sample_index = 123456879
result = calculateIndices(sample_index, 3)
print(result)

In [None]:
# increasing complexity test
# Round-trips the indices 0, 2!, 2!+3!, ..., sum(k! for k in 2..99) through
# partition/getInverse. Each line prints the index, its partition, the
# recovered index, and whether they match.
# Relies on `p` from the earlier demo cell.
n = 0
f = 1
for c in range(2, 101):
    a = partition(n, p)
    r = getInverse(*a)
    print(n, a, r, n == r)
    f *= c
    n += f