## Problem statement

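Find the **maximum** integer `x` such that the sum of the values from 1 to `x`, $1 + 2 + \dots + x = \frac{x(x+1)}{2}$, is less than or equal to `max_sum`. The approach below uses JAX gradient descent to minimize the distance between $\frac{x(x+1)}{2}$ and `max_sum`, then corrects the estimate to the exact integer answer.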

In [78]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
import random
import time


In [50]:
def get_value(x):
    """Return the triangular number 1 + 2 + ... + x, i.e. x * (x + 1) // 2.

    Exact for Python ints. Because of the floor division, the result is
    piecewise constant in ``x`` for float inputs, so its derivative is zero
    wherever it is defined.
    """
    product = x * (x + 1)
    return product // 2

def loss(x, max_sum):
    """Absolute gap between the triangular number of ``x`` and ``max_sum``.

    This is the objective that ``solve`` minimizes by gradient descent. Over
    x >= 0 its minimizer is the real root of x(x+1)/2 = max_sum.

    Args:
        x: The current continuous guess (``solve`` passes a scalar float array).
        max_sum: Target sum.

    Returns:
        ``|x(x+1)/2 - max_sum|`` as a JAX array.
    """
    # Use true division here, not ``get_value``'s floor division: ``//`` makes
    # the objective piecewise constant in x, so autodiff returns a zero
    # gradient everywhere and gradient descent would never move.
    return jnp.abs(x * (x + 1) / 2 - max_sum)

In [86]:
def lr_schduler(i):
    """Learning rate for gradient-descent iteration ``i``.

    The schedule is currently constant, so ``i`` is ignored. The function name
    keeps its original spelling because ``solve`` calls it by that name.
    """
    constant_lr = 0.1
    return constant_lr

def solve(max_sum, num_iterations=500, verbose=False):
    """Return the largest integer x >= 0 with get_value(x) <= max_sum.

    Runs gradient descent on ``loss`` to get a continuous estimate of the root
    of x(x+1)/2 = max_sum and keeps the iterate with the smallest loss. It then
    snaps that estimate to the exact integer answer using integer arithmetic.

    Args:
        max_sum: Non-negative bound on the sum 1 + 2 + ... + x.
        num_iterations: Number of gradient-descent steps (default 500).
        verbose: If True, print progress every 100 iterations.

    Returns:
        The answer as a Python ``int``.

    Raises:
        ValueError: If ``max_sum`` is negative. Every x >= 0 has a
            non-negative sum, so no answer exists.
    """
    if max_sum < 0:
        raise ValueError(f"max_sum must be non-negative, got {max_sum}")

    # jit compiles loss + gradient once per call instead of dispatching every
    # JAX op individually on each iteration. max_sum is passed as a float so
    # large ints cannot overflow the int32 conversion of jit arguments.
    loss_value_and_grad = jax.jit(value_and_grad(loss))
    target = float(max_sum)

    x = jnp.array(0.0)
    best_x = x
    best_loss = float("inf")  # any finite loss beats this, however large max_sum is

    for i in range(num_iterations):
        lr = lr_schduler(i)
        loss_value, grad_value = loss_value_and_grad(x, target)

        if verbose and i % 100 == 0:
            print(f"Iteration {i}: x = {x}, get_value(x) = {get_value(int(x))}, loss = {loss_value}, derivative = {grad_value}")

        if loss_value < best_loss:
            best_loss = loss_value
            best_x = x

        next_x = x - lr * grad_value
        if not jnp.isfinite(next_x):
            # Steps from a nan/inf iterate are meaningless; keep the best so far.
            print(f"Warning: x became non-finite at iteration {i}")
            print(x, grad_value)
            break
        x = next_x

    # Descent only lands *near* the root (with a constant learning rate it
    # oscillates around it), so walk to the exact answer in both directions.
    # Python ints keep this exact; float32 loses integer precision above 2**24.
    answer = max(int(jnp.round(best_x)), 0)
    while get_value(answer + 1) <= max_sum:
        answer += 1
    while get_value(answer) > max_sum:
        answer -= 1
    return answer

In [83]:
# Seed Python's RNG so that a Restart & Run All draws the same random values
# every time, making results and any test failures reproducible.
random.seed(0)
max_sum = random.randint(1, 100000)

In [84]:

# perf_counter is monotonic and high-resolution, so it is the right clock for
# elapsed time (time.time can jump if the system clock is adjusted).
start = time.perf_counter()
res = solve(max_sum)
end = time.perf_counter()

print(f"Answer: {res}")
print(f"Time: {end - start} seconds")


Iteration 0: x = 0.0, get_value(x) = 0.0, loss = 65243.0, derivative = -0.5
Iteration 100: x = 339.11419677734375, get_value(x) = 57630.0, loss = 7574.22265625, derivative = -339.61419677734375
Iteration 200: x = 374.64227294921875, get_value(x) = 70125.0, loss = 5122.734375, derivative = 375.14227294921875
Iteration 300: x = 338.5440673828125, get_value(x) = 57291.0, loss = 7767.68359375, derivative = -339.0440673828125
Iteration 400: x = 374.0125427246094, get_value(x) = 70125.0, loss = 4886.6953125, derivative = 374.5125427246094
Answer: 360
Time: 0.8963940143585205 seconds


In [85]:
# Slack left at res, res + 1 and res - 1; a correct answer gives
# >= 0, < 0 and > 0 respectively.
for offset in (0, 1, -1):
    print(max_sum - get_value(res + offset))

263.0
-98.0
623.0


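## Testing

Check `solve` on random inputs: each result must be the largest `x` with $\frac{x(x+1)}{2} \le$ `max_sum`.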

In [88]:
total_tests = 1000

# perf_counter is the monotonic clock intended for measuring elapsed time.
start = time.perf_counter()
# Use a named loop variable: `_` conventionally means "unused", and in IPython
# it also holds the last cell output, which this loop would silently clobber.
for test_index in range(total_tests):
    max_sum = random.randint(1, 100000)
    res = solve(max_sum)
    # res must be the *largest* x whose triangular number does not exceed max_sum.
    assert get_value(res) <= max_sum, f"max_sum={max_sum}: get_value({res}) exceeds it"
    assert get_value(res + 1) > max_sum, f"max_sum={max_sum}: {res} is not the largest valid x"
    assert get_value(res - 1) < max_sum, f"max_sum={max_sum}: get_value({res - 1}) is not below it"

    if test_index % 100 == 0:
        print(f"Test {test_index} passed")

end = time.perf_counter()
print(f"Time: {end - start} seconds")
print(f"Average time: {(end - start) / total_tests} seconds")

print("All tests passed!")

Test 0 passed
Test 100 passed
Test 200 passed
Test 300 passed
Test 400 passed
Test 500 passed
Test 600 passed
Test 700 passed
Test 800 passed
Test 900 passed
Time: 883.6656148433685 seconds
Average time: 0.8836656148433686 seconds
All tests passed!
