In [16]:
import math

def long_divide(a, b):
    """Divide a by b, rounding the quotient to the nearest integer.

    Returns ``(q, r)`` with ``a == q*b + r``, choosing between the floored
    quotient ``a // b`` and ``a // b + 1`` whichever leaves the smaller
    ``|r|``. On a tie the larger quotient ``a // b + 1`` is returned.

    Raises:
        ZeroDivisionError: if b == 0.
    """
    q = a // b
    rem_floor = a - q * b     # remainder for the floored quotient
    rem_next = rem_floor - b  # remainder if the quotient is bumped by one
    if abs(rem_floor) < abs(rem_next):
        return q, rem_floor
    return q + 1, rem_next


def gcd(a, b, gcd_only=True):
    """Return gcd(a, b), optionally with Bezout coefficients.

    Runs the extended Euclidean algorithm, using ``long_divide`` (nearest
    quotient) for each division step.

    Args:
        a, b: integers.
        gcd_only: if True (default) return only the gcd; otherwise also
            return coefficients ``(x, y)`` with ``x*a + y*b == g``.

    Returns:
        The non-negative gcd ``g``, or ``(g, (x, y))`` when ``gcd_only`` is
        False. ``gcd(a, 0)`` is ``abs(a)`` and ``gcd(0, 0)`` is 0.
    """
    # Each "card" expresses the current value as a combination of the
    # original inputs: value == card[0]*a0 + card[1]*b0.
    card_a = (1, 0)
    card_b = (0, 1)

    # A single loop guarded by b != 0 (instead of one unrolled step followed
    # by the loop) so that b == 0 no longer raises ZeroDivisionError.
    while b != 0:
        q, r = long_divide(a, b)
        card_a, card_b = card_b, (card_a[0] - q * card_b[0],
                                  card_a[1] - q * card_b[1])
        a, b = b, r

    # Signed remainders can leave a negative gcd; normalize its sign and the
    # coefficients together so the Bezout identity still holds.
    if a < 0:
        a = -a
        card_a = (-card_a[0], -card_a[1])

    if gcd_only:
        return a
    return a, card_a
    
def modi(a, n):
    """Return a reduced modulo n, following Python's sign convention.

    For integers the result equals ``a % n`` (it takes the sign of n, so it
    lies in ``[0, n)`` when n > 0). Negative a is handled through the explicit
    identity ``a - n * (a // n)``.
    """
    return a - n * (a // n) if a < 0 else a % n

class p_adic:
    """A number a / p**n arranged as a node in a tree of base-p fractions.

    With ``reduce=True`` (the default) factors of p are stripped from the
    numerator, so a node is ``a / p**n`` with p not dividing a. ``parent()``
    and ``children()`` move one level (exponent) up or down, ``s()`` gives
    same-level neighbours, and ``greedy_search()`` walks toward a target.

    Attributes:
        p: the base (presumably prime; not checked).
        a: the numerator.
        n: the exponent of the denominator; the value is ``a / p**n``.
        print_as: ``'exp'`` or ``'frac'``; selects the ``__repr__`` format.
        denom: ``p**n`` (a float when n is negative).
        par, chld: caches filled lazily by ``parent()`` / ``children()``.
    """

    def __init__(self, p, a, n, reduce=True, print_as='exp'):
        """Build a / p**n.

        Args:
            p: base of the ring.
            a: numerator.
            n: exponent of the denominator.
            reduce: when True, divide factors of p out of a (decrementing n)
                and give zero the canonical exponent n == 0.
            print_as: 'exp' or 'frac'; selects the __repr__ format.
        """
        if reduce:
            if a == 0:
                n = 0  # zero should always look like 0 / p**0
            else:
                while a % p == 0:
                    a //= p
                    n -= 1

        self.p = p
        self.a = a
        self.n = n

        self.print_as = print_as

        self.denom = p**n

        # Lazily filled caches for parent() and children().
        self.par = None
        self.chld = None

    def _check_same_ring(self, other):
        """Raise TypeError unless other uses the same base p as self."""
        if other.p != self.p:
            raise TypeError(
                "{} and {} are not in the same ring".format(self, other))

    def metric(self, other):
        """Return the exponent n of the reduced difference self - other.

        For a nonzero difference a / p**n (p not dividing a) the p-adic
        absolute value is p**n, so this is its base-p logarithm. Equal inputs
        give 0 because zero is normalized to n == 0.
        """
        # The exponent does not depend on the sign; this only fixes the order.
        if self < other:
            return other.metric(self)

        diff = self - other
        # The p-adic absolute value itself would be diff.p ** diff.n.
        return diff.n

    def __repr__(self):
        if self.n == 0:
            return str(self.a)
        if self.print_as == 'frac':
            if self.n < 0:
                # A negative exponent is an integer; p**n would be a float
                # and print as "a/0.333...".
                return str(self.a * self.p ** (-self.n))
            return "{}/{}".format(self.a, self.denom)
        # 'exp', and any unrecognized print_as, use the exponent form so that
        # __repr__ always returns a string.
        return "{}*{}^({})".format(self.a, self.p, -self.n)

    def s(self, pm):
        """Return the same-level neighbour in direction pm (+1 next, -1 previous).

        When a <= 1 the previous neighbour is the reflection -a / p**n; when
        a >= p**n - 1 the next neighbour is (p**n + a) / p**n. Otherwise the
        neighbour's numerator is a + pm, or a + 2*pm when a + pm is a multiple
        of p (such numerators are not nodes at this level).
        """
        if pm == -1 and self.a <= 1:
            return p_adic(self.p, -self.a, self.n, print_as=self.print_as)
        if pm == 1 and self.a >= self.denom - 1:
            return p_adic(self.p, self.denom + self.a, self.n,
                          print_as=self.print_as)

        # a + pm is divisible by p exactly when a == -pm (mod p).
        if modi(self.a, self.p) != modi(-pm, self.p):
            return p_adic(self.p, self.a + pm, self.n, print_as=self.print_as)
        return p_adic(self.p, self.a + (2 * pm), self.n, reduce=False,
                      print_as=self.print_as)

    def is_ancestor_of(self, other):
        """Return True if other is this node or lies in its subtree.

        other qualifies when it is at the same or a deeper level and lies
        strictly between the midpoints of self with s(-1) and with s(1).

        Raises:
            TypeError: if the two numbers have different p.
        """
        self._check_same_ring(other)
        if other.n < self.n:
            # A shallower node is never a descendant; this handles e.g.
            # "1/3 is within the bounds of 2/9".
            return False
        if other.n == self.n and other.a == self.a:
            return True  # a node is its own ancestor

        cousin_prev = self.s(-1)
        cousin_next = self.s(1)

        # Scale to other's level and compare 2*other with the doubled
        # midpoints, which keeps all arithmetic in integers.
        scale = self.p ** (other.n - self.n)
        lower = (cousin_prev.a + self.a) * scale
        upper = (cousin_next.a + self.a) * scale
        doubled = other.a * 2
        return lower < doubled < upper

    def __lt__(self, other):
        return (self - other).a < 0

    def __le__(self, other):
        return (self == other) or (self < other)

    def __ge__(self, other):
        return not (self < other)

    def __eq__(self, other):
        """Compare the stored representation (p, a, n) field by field."""
        if not isinstance(other, p_adic):
            return NotImplemented
        return (self.p == other.p) and (self.a == other.a) and (self.n == other.n)

    def __hash__(self):
        # Consistent with __eq__, which compares exactly these fields.
        return hash((self.p, self.a, self.n))

    def __gt__(self, other):
        return (self >= other) and (not (self == other))

    def __add__(self, other):
        """Return self + other, expressed at the larger of the two exponents.

        Raises:
            TypeError: if the two numbers have different p.
        """
        self._check_same_ring(other)
        if self.n == other.n:
            numerator, exponent = self.a + other.a, self.n
        elif self.n < other.n:
            numerator = (self.p ** (other.n - self.n)) * self.a + other.a
            exponent = other.n
        else:
            numerator = self.a + other.a * (self.p ** (self.n - other.n))
            exponent = self.n
        return p_adic(self.p, numerator, exponent, print_as=self.print_as)

    def __abs__(self):
        return p_adic(self.p, abs(self.a), self.n, print_as=self.print_as)

    def __neg__(self):
        return p_adic(self.p, -self.a, self.n, print_as=self.print_as)

    def __sub__(self, other):
        return self + (-other)

    def __mul__(self, other):
        """Return self * other.

        Raises:
            TypeError: if the two numbers have different p.
        """
        self._check_same_ring(other)
        return p_adic(self.p, self.a * other.a, self.n + other.n,
                      print_as=self.print_as)

    def parent(self):
        """Return (and cache) the parent node at exponent n - 1.

        The parent numerator is floor(a / p) when that is not a multiple of
        p, otherwise floor(a / p) + 1.
        """
        if self.par is not None:
            return self.par
        down = (self.a - (self.a % self.p)) // self.p  # go down: floor(a / p)
        if (down % self.p) != 0:
            self.par = p_adic(self.p, down, self.n - 1, print_as=self.print_as)
        else:
            up = self.a + (self.p - (self.a % self.p))  # go up
            self.par = p_adic(self.p, up // self.p, self.n - 1,
                              print_as=self.print_as)
        return self.par

    def children(self):
        """Return (and cache) the list of child nodes at exponent n + 1.

        The first child is the smallest numerator c whose value
        c / p**(n+1) lies strictly above the midpoint of s(-1) and self.
        Further children follow while they stay below the midpoint of self
        and s(1) and below 1, skipping multiples of p.
        """
        if self.chld is not None:
            return self.chld
        p = self.p
        child_n = self.n + 1
        cousin_prev = self.s(-1)

        # Doubled lower midpoint expressed as a numerator at the child level;
        # ceil(twice_mid / 2) in integer arithmetic (no float rounding).
        twice_mid = (cousin_prev.a + self.a) * p
        first = int(-(-twice_mid // 2))
        # Brute-force the exact start (at most a few steps): the first child
        # must lie strictly above the midpoint.
        double_midpoint = self + cousin_prev
        while p_adic(p, 2 * first, child_n,
                     print_as=self.print_as) <= double_midpoint:
            first += 1
        children = [p_adic(p, first, child_n, print_as=self.print_as)]

        double_upper_midpoint = p * (self.a + self.s(1).a)
        level_size = p ** child_n
        num = first + 1
        while num < level_size and 2 * num < double_upper_midpoint:
            if num % p != 0:
                children.append(p_adic(p, num, child_n, print_as=self.print_as))
            num += 1

        self.chld = children
        return self.chld

    def greedy_search(self, target):
        """Walk the tree from this node toward target; return True on arrival.

        If target is not in this node's subtree, move to the parent;
        otherwise move to the child closest to target. The walk is recursive
        and only ever returns True.
        """
        if self == target:
            return True

        if not self.is_ancestor_of(target):
            return self.parent().greedy_search(target)

        best_child = None
        # Assumed largest possible distance between nodes.
        best_dist = p_adic(self.p, 1, 0, print_as=self.print_as)
        for child in self.children():
            dist = abs(child - target)
            if dist < best_dist:
                best_child = child
                best_dist = dist

        # NOTE(review): if no child is closer than 1, best_child is None and
        # this raises AttributeError.
        return best_child.greedy_search(target)

In [22]:
# Build x = 26/81 (p = 3, exponent 4) and print its metric to each of its
# children; metric() returns the exponent n of the reduced difference.
p = 3
x = p_adic(p,26,4,print_as='frac')
for ch in x.children():
    print(x.metric(ch))

5
5
5


In [77]:
# Count the fractions a/p^(n+1), a = 1 .. p^(n+1)-1, whose metric to
# y = 1/p^n equals k, printing each match.
ct = 0
p = 3
n = 1
a = 1
k = 2
y = p_adic(p,a,n,print_as='frac')
print(y)
# NOTE: this loop rebinds the global `a`; values of a divisible by p are
# reduced by p_adic to a lower level.
for a in range(1,p**(n+1)):
    x = p_adic(p,a,n+1,print_as='frac')
    if x.metric(y) == k:
        ct += 1
        print(x,x.metric(y))
# Bare expression: the notebook displays the count as the cell's output.
ct

1/3
1/9 2
2/9 2
4/9 2
5/9 2
7/9 2
8/9 2


6

In [8]:
# NOTE(review): the recorded output below comes from execution count 8, i.e.
# before the definitions cell (In [16]) was last run, so it presumably reflects
# an older metric(). The current metric() returns an exponent (e.g. 2 for
# x = 8/9, y = 1/3), not 81 — re-run to confirm.
x.metric(y)

81