# Problem statment

given an array 'arr' of elements find the subset with maximum sum, less than 'max_sum'

In [91]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
import random
import time
import numpy as np

In [84]:
def segmoid(x): 
    return 1 / (1 + jnp.exp(-x))

def tanh(x):
    return (jnp.tanh(x) + 1) / 2

def get_sum(arr, str):
    # softmax the strength to 0-1
    strength = jnp.array(str)
    strength = tanh(strength)
    # strength = [1 if s > 0.5 else 0 for s in strength]
    return jnp.sum(jnp.array(strength) * jnp.array(arr))

def loss_fn(arr, str, max_sum):
    return jnp.abs(get_sum(arr, str) - max_sum + 0.0)

In [113]:
def lr_schduler(i):
    return 0.1 

def solve(arr, max_sum):
    arr = jnp.array(arr)
    iterations_num = 1000
    strength = jnp.zeros(len(arr))
    loss_value_and_grad = value_and_grad(loss_fn, 1)

    for i in range(iterations_num):
        loss, grad = loss_value_and_grad(arr, strength, max_sum)
        strength -= lr_schduler(i) * grad
        if grad.max() < 1e-9:
            break
        if i % 100 == 0 and False:
            print(f'iteration {i}, loss {loss}')

    # print(f'final loss {loss}')
    strength = tanh(strength)
    
    assert len(arr) == len(strength)

    ids = [i for i in range(len(arr))]
    # sort ids by strength
    ids = [i for _, i in sorted(zip(strength, ids))]
    cur_sum = 0
    new_strength = [0 for _ in range(len(arr))]
    for i in ids:
        if cur_sum + arr[i] <= max_sum:
            cur_sum += arr[i]
            new_strength[i] = 1
    
    strength = jnp.array(new_strength)

    return sum([a * s for a, s in zip(arr, strength)])    
   


In [114]:
random.seed(420)
max_sum = 5
arr = [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
arr = random.sample(arr, len(arr))

solve(arr, max_sum)

Array(5, dtype=int32)

In [115]:
test_count = 1000

start = time.time()

avg_loss = 0

for _ in range(test_count):
    arr = jnp.array(np.random.randint(1, 10, 100))
    max_sum = random.randint(sum(arr) // 2, sum(arr))

    total_sum = solve(arr, max_sum)

    assert int(total_sum) <= int(max_sum)
    
    avg_loss += max_sum - total_sum

    if _ % 100 == 0:
        print(f"test {_}")
    
print(f"avg_loss: {avg_loss / test_count}")
print(f"total time: {time.time() - start}")
print(f"Average time: {(time.time() - start) / test_count}")

test 0
test 100
test 200
test 300
test 400
test 500
test 600
test 700
test 800
test 900
avg_loss: 3.4260001182556152
total time: 18.957711219787598
Average time: 0.01895773434638977
