In [1]:
import math
from decimal import *

In [8]:
def iterative(n):
    r = math.sqrt(2)
    s = 2+math.sqrt(2)
    p = 1
    total = 0
    while n > 0:
        total += p * (math.floor(n*r)*(math.floor(n*r)+1))//2
        n = math.floor(math.floor(n*r)/s)
        total -= p * n * (n+1)
        p = -p
    return math.trunc(total)

def iterativeSolution(str_n):
    return str(iterative(int(str_n)))

def recursiveSolution(str_n):
    r = math.sqrt(2)
    s = 2 + r
    return str(recursive(int(str_n), r, s))

def recursive(n,r,s):
    if n <= 1:
        return n
    N = math.floor(2*n/r)
    m = math.floor((N+1)/s)
    return (N*(N+1))//2 - m*(m+1) - recursive(m, r, s)

In [9]:
#TEST CASES
inputs = ["5", "77", "100", "1000000"]
expected = ["19", "4208", "7092", "707106988293"]
for i,e in zip(inputs, expected):
    a = recursiveSolution(i)
    print(f"SOLUTION({i}) = {a}")
    if a == e:
        print("\tTEST PASSED")
    else:
        print(f"\tTEST FAILED, EXPECTED {e}")

SOLUTION(5) = 19
	TEST PASSED
SOLUTION(77) = 4208
	TEST PASSED
SOLUTION(100) = 7092
	TEST PASSED
SOLUTION(1000000) = 707106988293
	TEST PASSED


In [10]:
i = "10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
print(recursiveSolution(i))

70710678118654750137209464893756489664282528908927448077495338816220205099787725059874554667652483864909496054855492044098629130629740634370082827683252053667898210120735991631921374343849262239831454


#The problem is asking us to efficiently calculate partial sums of beatty sequences with r = sqrt(2).
#An efficient solution stems from the Rayleigh-Beatty theorem stating that all positive numbers belong
#to either B_r or B_s. B_r is the Beatty sequence generated from sqrt(2). B_s is the Beatty sequence
#generated from s = r/(r-1). As a result, sum_i=1^n(i) (or the triangular numbers) can be broken down
#into two sums, one over terms of B_s and one over terms of B_r. Luckily, in this specific case
#(though I suspect a similar process can be executed for other irrational r values), s = 2+sqrt(2).
#This can be confirmed (or derived) with some basic algebra. This allows us to further describe
#partial sums of B_s in terms of partial sums of B_r since floor(i(2+sqrt(2))) = floor(2i+sqrt(2)i)
#= 2i + floor(sqrt(2)i). Overall, this lets us determine that S_n = (N(N+1))/2 - m(m+1) - S_m
#Where S_n is the partial sum of the beatty sequence of r from i = 1 to n, N = floor(nr) (the largest
#term to be summed in S_n), m = floor(N/s) (the highest value of i in S_m), and S_m is the partial 
#sum of the beatty sequence of r from i = 1 to n.
#
#Now that we have a theoretical solution, we can start to think about a practical solution -
#that is to say a solution function. This is the part that was most annoying for me, though in the
#end, I grew to appreciate the process as it allowed me to apply concepts from my numerical 
#analysis classes that I have only been able to apply sparingly. First, we have to consider that,
#given our limited precision in calculating sqrt(2), it is generally unwise to multiply sqrt(2) by a 
#large value. Instead, we can multiply by 2/sqrt(2) = sqrt(2), which will generally be more accurate.
#Additionally, we can get a more precise value for sqrt(2) by using the decimal library in python,
#though this, without numerical considerations would not be enough. Finally, we have to consider that 
#most of the nubmers computed will be massive, and therefore it would be inappropriate to use floor
#or trunc, and we should instead directly convert to long. All of these considerations will allow
#accurate coputations.

from decimal import *
#Recursive function for calculating partial sum of beatty sequence
def recursive(n, r, s):
    if n <= 1:
        return n
    #Avoid truncating - just convert to long
    N = long((2*n)/r)
    m = long(N/s)
    return (N*(N+1))//2 - m*(m+1) - recursive(m,r,s)
#Iterative function for calculating partial sum of beatty sequence
def iterative(n, r, s):
    #Keep track of parity (negative or positive) when calculating sum
    p = 1
    #Make sure to use long instead of int
    total = long(0)
    while n > 0:
        N = long((2*n)/r)
        total += p * (N*(N+1))//2
        n = long(N/s)
        total -= p * n*(n+1)
        p = -p
    return total
def solution(str_n):
    #Increase precision
    getcontext().prec = 10000
    r = Decimal(2).sqrt()
    #Use 2+sqrt(2) to avoid computation issues associated with r/(r-1)
    s = Decimal(2) + r
    #Call auxilliary function (either recursive or iterative)
    return str(iterative(long(str_n), r, s))