In [1]:
import os

# JAX may read these environment variables as soon as it is imported (or when
# it first initialises its backends), so set them before *any* `jax` import,
# including indirect ones such as `jax.example_libraries` below.
os.environ["JAX_PLATFORMS"] = "cpu"
# Ask XLA to expose 8 host (CPU) devices.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# NOTE(review): `optimizer` does not appear to be used anywhere in this
# notebook (the optimizer stack comes from `santa.optimizers`); consider
# dropping this import.
from jax.example_libraries.optimizers import optimizer

In [2]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from santa import tree_packing

from santa import optimizers

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Problem size and the single seed that drives this notebook's randomness.
NUM_TREES = 32
SEED = 42  # reuse for other seeded components so runs stay reproducible
RNG = jax.random.PRNGKey(SEED)

In [4]:
# Build the tree-packing problem for the configured number of trees.
problem = tree_packing.create_tree_packing_problem(
    NUM_TREES
)

In [None]:
# Initial layout, passed through `problem.eval` before optimisation starts.
solution = problem.eval(tree_packing.init_solution(NUM_TREES, RNG, length=2))

# Noise operators with increasing `noise_level`, handed to ALNS below.
small_noise = optimizers.noise.NoiseOptimizer(noise_level=0.001)
noise = optimizers.noise.NoiseOptimizer(noise_level=0.1)
large_noise = optimizers.noise.NoiseOptimizer(noise_level=0.2)

# Optimizer stack, innermost first: ALNS over the ruin / recreate / noise
# operator lists, wrapped by simulated annealing, wrapped by RestoreBest.
opt = optimizers.combine.RestoreBest(
    optimizers.sa.SimulatedAnnealing(
        optimizers.alns.AdaptiveLargeNeighborhoodSearch(
            [
                optimizers.ruin.RandomRuin(n_remove=1),
                optimizers.ruin.SpatialRuin(n_remove=2),
            ],
            [optimizers.recreate.RandomRecreate(max_recreate=2)],
            [noise, small_noise, large_noise],
        ),
        initial_temp=1e6,
        cooling_rate=0.99995,
        patience=100_000,
    ),
    patience=10_000,
)

# Bind the problem, then create the state threaded through every step.
opt.set_problem(problem)
opt_state = opt.init_state(solution)
# NOTE(review): literal seed duplicates the one used for RNG; keep them in sync.
global_state = problem.init_global_state(seed=42)


@jax.jit
def iterations(sol, opt_state, global_state):
    """Run 10_000 optimizer steps inside one compiled `jax.lax.scan`.

    Each step calls the notebook-level `opt.step` and then advances the
    global state with `.next()`. `opt` is read when the function is traced,
    so reassigning `opt` afterwards does not affect already-compiled calls.

    Returns:
        The `(sol, opt_state, global_state)` tuple after the final step.
    """

    def _step(state, _unused):
        cur_sol, cur_opt_state, cur_global = state
        cur_sol, cur_opt_state, cur_global = opt.step(cur_sol, cur_opt_state, cur_global)
        # No per-step outputs are collected, only the carried state.
        return (cur_sol, cur_opt_state, cur_global.next()), None

    final_state, _ = jax.lax.scan(_step, (sol, opt_state, global_state), length=10_000)
    return final_state


# XLA's static cost estimate for one compiled call of `iterations`.
# NOTE(review): XLA cost analysis may count a loop body once instead of
# multiplying it by the scan length; confirm before reading this as total work.
flops = iterations.lower(solution, opt_state, global_state).compile().cost_analysis()["flops"]
print(f"Flops: {int(flops)}")

# Number of calls to `iterations`. The recorded run took ~13.8 s per call, so
# 1_000_000 calls effectively means "run until interrupted"; lower it for a
# run that completes.
NUM_CALLS = 1_000_000

history = {"opt_state": [opt_state], "global_state": [global_state]}
solutions = [solution]
trange = tqdm(range(NUM_CALLS))
for it in trange:
    solution, opt_state, global_state = iterations(solution, opt_state, global_state)
    temp = opt_state['temperature']
    # `.3g` keeps the annealing temperature readable over its whole range;
    # `.2f` rendered every value below 0.005 as "0.00".
    trange.set_description(f"{global_state.best_feasible_score:.4f} - {solution.objective:.4f}"
                           f"| since={global_state.iters_since_last_improvement:6d}"
                           f"| temp={temp:.3g}")
    # One snapshot per call; these lists grow for the whole run and stay in memory.
    solutions.append(solution)
    history["opt_state"].append(opt_state)
    history["global_state"].append(global_state)

Flops: 56685


0.4272 - 0.4272| since=  1087| temp=0.00:   0%|          | 119/1000000 [19:53<3831:31:55, 13.80s/it]   

In [None]:
# Score trajectories: one point per call of `iterations` (index 0 = initial state).
best_score = jnp.array([h.best_score for h in history["global_state"]])
best_feasible_score = jnp.array([h.best_feasible_score for h in history["global_state"]])

fig, ax = plt.subplots()
ax.semilogy(best_feasible_score, color="tab:blue", label="best feasible score")
ax.semilogy(best_score, color="tab:orange", linestyle="--", label="best score")
ax.set(xlabel="call to `iterations`", ylabel="score", title="Optimisation progress")
ax.legend()

# Layout of the `best_solution` stored in the last recorded global state, on a new figure.
plt.figure()
tree_packing.plot.plot_solution(history["global_state"][-1].best_solution)