# In [None]:
from itertools import count
import math
from numba import jit


class PAdicNumbers:
    """A p-adic number truncated to a finite number of digits.

    ``digits[i]`` is the coefficient of ``p**i`` (least significant digit
    first), so the stored value is ``sum(digits[i] * p**i)``.  Results of
    ``+``, ``-`` and ``*`` keep as many digits as the longer operand and drop
    any carry out of the top digit, i.e. they are exact modulo ``p**n``.
    A shorter operand is extended with zero digits, which is only correct for
    non-negative integers (a negated value would need ``p - 1`` padding).
    ``set``, ``form`` and ``precision`` mutate and return ``self`` so they can
    be chained.
    """

    # the digits of the number, len(digits) = accuracy
    # [p^0, p^1, p^2, p^3, ...]
    digits: list[int]
    # the base of the expansion
    p: int

    def __init__(self, p: int = 5, digits=None) -> None:
        """Create a base-``p`` number from ``digits`` (least significant first).

        ``digits`` defaults to a fresh empty list; a list passed in is stored
        as-is, not copied.
        """
        if digits is None:
            digits = []

        self.digits = digits
        self.p = p

    @staticmethod
    def _padded(digits: list[int], size: int) -> list[int]:
        """Return a copy of ``digits`` extended with zero digits to ``size``."""
        return digits + [0] * (size - len(digits))

    # converts a decimal to a p-adic number
    @staticmethod
    def __set(p: int, number: int, limit: int) -> list[int]:
        """Return the base-``p`` digits of ``number``, least significant first.

        Produces just enough digits to hold ``number`` (at least one), but
        never more than ``limit``.  A value that does not fit in ``limit``
        digits is reduced modulo ``p**limit`` and gets exactly ``limit``
        digits.  A negative value is reduced modulo ``p**limit`` first, which
        yields its truncated p-adic expansion (e.g. -1 -> all digits p - 1).
        """
        if number < 0:
            number %= p ** limit

        # first find the highest power of p needed
        power = None
        total = 1
        for n in range(limit):
            total *= p
            if total > number:
                power = n
                break

        # overflow: keep only the lowest `limit` digits
        if power is None:
            power = limit - 1
            number %= total

        # total is now p ** (power + 1); step down to p ** power
        total //= p

        # peel off digits from the most significant one downwards
        digits = [0] * (power + 1)
        for n in range(power, -1, -1):
            quotient, number = divmod(number, total)
            digits[n] = int(quotient)
            total //= p

        return digits

    def set(self, number: int, limit: int = 30):
        """Replace the digits with the base-``p`` expansion of ``number``.

        At most ``limit`` digits are kept; values that do not fit, and
        negative values, are reduced modulo ``p**limit``.  Returns ``self``.
        """
        self.digits = self.__set(self.p, number, limit)
        return self

    def form(self, digits: str):
        """Replace the digits with the ones written in ``digits``.

        Each character (or item) becomes one integer digit, in storage order:
        the first character is the ``p**0`` digit, i.e. the reverse of how
        ``__str__`` prints the number.  Returns ``self``.
        """
        # int() so later arithmetic sees numbers, not characters
        self.digits = [int(d) for d in digits]
        return self

    def precision(self, prec: int = 10):
        """Pad with high-order zero digits up to ``prec`` digits; return ``self``.

        Never truncates: a number that already has ``prec`` or more digits is
        left unchanged.  ``self.digits`` is extended in place.
        """
        if len(self.digits) < prec:
            self.digits.extend([0] * (prec - len(self.digits)))
        return self

    def dec(self):
        """Return the integer ``sum(digits[i] * p**i)`` (0 when there are no digits)."""
        return sum(d * self.p ** i for i, d in enumerate(self.digits))

    def __repr__(self) -> str:
        return f"p-adic({self.p}, {self.digits})"

    def __str__(self) -> str:
        # most significant digit first, with the base as a suffix
        return f"(...{''.join(map(str, reversed(self.digits)))}){self.p}"

    def strpad(self, pad: int = 10):
        """Render like ``__str__`` but with exactly ``pad`` digits.

        Longer numbers show only their lowest ``pad`` digits; shorter ones are
        left-padded with zeros.
        """
        if len(self.digits) > pad:
            numbers = ''.join(map(str, reversed(self.digits[:pad])))
        else:
            numbers = '0' * (pad - len(self.digits)) + ''.join(map(str, reversed(self.digits)))

        return f"(...{numbers}){self.p}"

    # retrieves the magnitude
    def mag(self):
        """Return the p-adic absolute value ``p ** -v`` as a float.

        ``v`` is the number of zero digits below the lowest non-zero digit.
        A number whose stored digits are all zero (or that has no digits) is
        0 to this precision, and ``|0|_p = 0``, so 0.0 is returned.
        """
        zeros = 0
        for d in self.digits:
            if d != 0:
                break
            zeros += 1
        else:
            # no non-zero digit found
            return 0.0

        return 1.0 / self.p ** zeros

    def is_unit(self, ep: float = 0.0001):
        """Return True when ``mag()`` is within ``ep`` of 1 (lowest digit non-zero)."""
        return abs(self.mag() - 1.0) < ep

    # TODO: ADD ERROR CHECKING

    def __add__(self, other):
        """Return ``self + other``, truncated to the longer operand's length."""
        base = self.p
        size = max(len(self.digits), len(other.digits))
        lhs = self._padded(self.digits, size)
        rhs = self._padded(other.digits, size)

        result_digits = []
        carry = 0
        for a, b in zip(lhs, rhs):
            carry, digit = divmod(a + b + carry, base)
            result_digits.append(digit)
        # a final carry would land beyond the precision and is dropped

        return PAdicNumbers(base, result_digits)

    def __neg__(self):
        """Return ``-self`` modulo ``p**len(digits)``.

        Low-order zero digits stay zero, the lowest non-zero digit ``d``
        becomes ``p - d``, and every digit above it becomes ``p - 1 - d``.
        """
        p = self.p
        ndigits = [0] * len(self.digits)
        seen_nonzero = False
        for i, d in enumerate(self.digits):
            if seen_nonzero:
                ndigits[i] = p - 1 - d
            elif d != 0:
                # the first non-zero digit absorbs the "+1" of the complement
                ndigits[i] = p - d
                seen_nonzero = True

        return PAdicNumbers(p, ndigits)

    def __sub__(self, other):
        """Return ``self - other`` as ``self + (-other)``."""
        return self + -other

    def recip(self):
        """Return ``1 / self`` with the same number of digits.

        Builds the inverse one digit at a time: digit ``j`` is chosen so that
        digit ``j`` of ``self * inverse`` becomes 1 (for ``j == 0``) or 0.
        Requires ``self`` to be a unit (non-zero lowest digit) whose lowest
        digit is invertible modulo ``p`` -- always the case for prime ``p``.

        Raises:
            Exception: if ``self`` is not a unit.
            ValueError: if the lowest digit has no inverse modulo ``p``
                (only possible for composite ``p``).
        """
        if not self.is_unit():
            raise Exception("cannot invert a non-unit number")

        p = self.p
        own = self.digits
        dlen = len(own)
        # d0 * d0_inv == 1 (mod p); this is the only multiplication-table
        # lookup the digit-by-digit construction ever needs
        d0_inv = pow(own[0], -1, p)

        digits = [0] * dlen
        # low dlen digits of self * (inverse digits chosen so far); positions
        # >= dlen cannot influence any digit of the answer
        partial = [0] * dlen

        for j in range(dlen):
            target = 1 if j == 0 else 0
            b = (target - partial[j]) * d0_inv % p
            digits[j] = b
            # add b * self, shifted by j places, into the partial product
            carry = 0
            for i in range(dlen - j):
                total = partial[j + i] + b * own[i] + carry
                partial[j + i] = total % p
                carry = total // p

        return PAdicNumbers(p, digits)

    def __mul__(self, other):
        """Return ``self * other``, truncated to the longer operand's length.

        Schoolbook multiplication: every digit of ``other`` times all of
        ``self``, shifted and accumulated, keeping the lowest ``size`` digits.
        """
        p = self.p
        size = max(len(self.digits), len(other.digits))
        lhs = self._padded(self.digits, size)
        rhs = self._padded(other.digits, size)

        digits = [0] * size
        for j, multiplier in enumerate(rhs):
            if multiplier == 0:
                continue  # this row adds nothing
            carry = 0
            # walking all of lhs (including its zero padding) lets the carry
            # keep propagating even when self is shorter than other
            for i in range(size - j):
                total = digits[j + i] + lhs[i] * multiplier + carry
                digits[j + i] = total % p
                carry = total // p

        return PAdicNumbers(p, digits)

    def __floordiv__(self, other):
        """Return ``self / other`` in the p-adic sense (``self * other.recip()``).

        Despite the ``//`` operator this is exact p-adic division, not integer
        floor division; ``other`` must be a unit.
        """
        return self * other.recip()

    # ALSO EXPLORE IMAGINARY NUMBERS

    # TODO: WRITE NEWTON's METHOD FOR THIS


# short alias used by the experiments below
PA = PAdicNumbers


# Earlier scratch experiments, kept for reference (method names updated):
#
# number = PAdicNumbers()
# for i in range(10):
#     number.set(pow(2, pow(5, i)), 30)
#     print(number.strpad(20))
#     # print(number.mag())
#     # print(number.dec())
#
# res = PAdicNumbers().set(12).precision() + PAdicNumbers().set(124)
# print(res)
# print(res.dec())
#
# n1 = PA().set(323154513451).precision()
# n2 = PA().set(234534534).precision()
# print(n1 // n2)
# print(n2.recip())
# print(n1 * n1.recip())
# print((n1 * n2).dec())

# running total, starting at 0 with 100 digits of precision
total = PA().set(0).precision(100)
# for i in range(1000000):
#     total = total + PA().set(i)

# 1 / (-12) in base 5, to 100 digits
t = (-PA().set(12).precision(100)).recip()

print(t)
