In [None]:
import torch
import random
from collections import deque

# --- Problem configuration ---------------------------------------------------
batch_size = 4  # number of slots optimized in parallel (maximum concurrent processes)
max_processes = 10  # total processes to run (eventually n_targets * n_random_restarts)
learning_rate = 0.1

device = 'cpu'

# --- Shared optimization state (mutated by the helper functions below) ------
# One optimization variable per slot, drawn uniformly from [0, 10).
# (Eventually this becomes a matrix of shape (n_targets, n_random_restarts).)
x_values = 10 * torch.rand(batch_size, device=device)
# Slot index -> process ID occupying that slot; None marks an idle slot.
active_processes = [None for _ in range(batch_size)]
# FIFO queue of idle slot indices; every slot starts out idle.
available_slots = deque(slot for slot in range(batch_size))

def gradient_descent_step(x):
    """Return ``x`` after one gradient-descent step on f(x) = (x - 2)**2.

    The step size is the module-level ``learning_rate``. Only element-wise
    arithmetic is used, so ``x`` is updated element by element.
    """
    grad = (x - 2) * 2  # f'(x) = 2 * (x - 2)
    step = learning_rate * grad
    return x - step

def is_converged(x):
    """Element-wise test of whether ``x`` lies within 0.01 of the optimum 2.

    Returns the result of the strict comparison ``|x - 2| < 0.01``.
    """
    deviation = torch.abs(x - 2)
    return deviation < 0.01

# Function to start a new process

# Add all the new values to the GP matrix/vector slices (whitened log latent variables and initial latent variables).
# The whitened log latent variables are determined from the initial latent variables
# (the other matrices need to be recomputed after every iteration, as always).
# start_new_process needs to slice in new whitened log latent variables determined from the initial latent variables of that particular restart.
# The initial latent variables are selected by using n_random_restart as an index into the list of initial values (a hyperparameter).
# Thus, in principle, a new process can be started by changing the 3 whitened log latent variables of that slice.
# Need to think about the Woodbury matrix and the MLL, because they depend on the output (target): the correct slice must hold the correct values (this works now because each slice is a different output (target)).
# So it would be wasteful to calculate them outside the batch; select/calculate the correct slices per process? Just store in the GP which slice corresponds to which output (make a batch_output matrix).
# It is also wasteful if the exact same values are calculated multiple times for the same outputs.

def start_new_process(process_id):
    """Place ``process_id`` into the oldest free slot with a fresh random start.

    Pops the leftmost index from ``available_slots``, overwrites that element
    of ``x_values`` with a new value drawn uniformly from [0, 10), and records
    ``process_id`` as the slot's occupant in ``active_processes``.

    Returns:
        The slot index the process was placed in, or ``None`` if no slot is
        free. In that case nothing is modified and the process was NOT
        started; callers can check the return value instead of assuming
        success.
    """
    if not available_slots:
        return None
    slot_idx = available_slots.popleft()
    # Random restart: a new initial value for this slot only (in-place write).
    x_values[slot_idx] = torch.rand(1, device=device) * 10
    active_processes[slot_idx] = process_id
    print(f"Starting Process {process_id} in slot {slot_idx}")
    return slot_idx

# Function to handle completed processes
def remove_converged_process(process_id):
    """Release the slot occupied by ``process_id`` so it can be reused.

    The slot's entry in ``active_processes`` is reset to ``None`` and its
    index is appended to ``available_slots``. IDs that are not currently
    active are ignored. ``None`` is ignored as well. It is the marker for an
    idle slot, and "removing" it would push an already-free slot onto
    ``available_slots`` a second time, so two processes could later be given
    the same slot.

    Returns:
        The freed slot index, or ``None`` if nothing was freed.
    """
    # TODO: add the MLL value and save the GP (for selecting the best models).
    # Open question: saving the full GP is wasteful, since only this slice's
    # information is needed. How? Only add the slice to the final GP if its
    # MLL is larger, and save the new MLL value.
    if process_id is None:
        return None
    try:
        slot_idx = active_processes.index(process_id)
    except ValueError:
        return None  # process_id is not currently running
    active_processes[slot_idx] = None  # Free slot
    available_slots.append(slot_idx)  # Add back to queue
    print(f"Process {process_id} converged, freeing slot {slot_idx}")
    return slot_idx

# Run the optimization loop
current_process_id = 0  # ID the next started process will receive
running_processes = min(batch_size, max_processes)

# Initialize first batch: fill as many slots as there are processes to run.
for _ in range(running_processes):
    start_new_process(current_process_id)
    current_process_id += 1

# Optimization loop: continue while processes remain to be started or any slot
# is still occupied. Occupancy must be tested with `is not None`. Process ID 0
# is falsy, so a plain `any(active_processes)` would stop the loop while
# process 0 is still running (e.g. with max_processes == 1 it never runs).
while current_process_id < max_processes or any(
    pid is not None for pid in active_processes
):
    # One gradient-descent step on every slot at once. Idle slots are updated
    # too, but the convergence check below skips them.
    x_values = gradient_descent_step(x_values)

    # Check each occupied slot for convergence.
    for i, x in enumerate(x_values):
        if active_processes[i] is not None and is_converged(x):
            process_id = active_processes[i]
            remove_converged_process(process_id)

            # Start the next process (it takes a free slot) if more remain.
            if current_process_id < max_processes:
                start_new_process(current_process_id)
                current_process_id += 1

    print(f"Current values: {x_values.cpu().numpy()}")  # Debugging output

Starting Process 0 in slot 0
Starting Process 1 in slot 1
Starting Process 2 in slot 2
Starting Process 3 in slot 3
Current values: [7.217494  2.7382858 6.280594  2.3138256]
Current values: [6.173995  2.5906286 5.424475  2.2510605]
Current values: [5.339196  2.472503  4.73958   2.2008483]
Current values: [4.671357  2.3780024 4.191664  2.1606786]
Current values: [4.137086  2.302402  3.7533314 2.128543 ]
Current values: [3.7096686 2.2419217 3.4026651 2.1028342]
Current values: [3.367735  2.1935372 3.122132  2.0822673]
Current values: [3.094188  2.1548297 2.8977056 2.0658138]
Current values: [2.8753505 2.1238637 2.7181644 2.052651 ]
Current values: [2.7002804 2.099091  2.5745316 2.0421207]
Current values: [2.5602243 2.0792727 2.4596252 2.0336967]
Current values: [2.4481795 2.0634181 2.3677    2.0269573]
Current values: [2.3585436 2.0507345 2.2941601 2.021566 ]
Current values: [2.286835  2.0405877 2.2353282 2.0172527]
Current values: [2.2294679 2.0324702 2.1882625 2.013802 ]
Current values