In [5]:
from src.evolution.nb_individual_pure import initialize_individual_weights, create_individual_nn_state, \
    get_individual_action

weights = initialize_individual_weights(1, 1, 3, seed=42)
print(f"\nInitialized weights: {len(weights)} values")

# Test neural network operations
nn_state = create_individual_nn_state(1, 3)
action, nn_state = get_individual_action(weights, 0.5, 1.0, nn_state, 1, 1, 3, 1.0)


Initialized weights: 8 values


In [6]:
nn_state

(array([0.03301583]), 1.0, array([-0.06373937,  0.45330248,  0.22404532]), 1.0)

In [7]:
weights

array([ 0.14901425, -0.04147929,  0.19430656,  0.45690896, -0.07024601,
       -0.07024109,  0.47376384,  0.23023042])

In [9]:
import numpy as np
from numba import njit, prange
import time

# Test 1: Sequential version (should always work)
@njit(fastmath=True, cache=True)
def test_sequential(arr):
  result = np.empty(len(arr), dtype=np.float64)
  for i in range(len(arr)):
      result[i] = arr[i] * arr[i] + np.sin(arr[i])
  return result

# Test 2: Parallel version (might crash with SIGSEGV)
@njit(fastmath=True, cache=True, parallel=True)
def test_parallel(arr):
  result = np.empty(len(arr), dtype=np.float64)
  for i in prange(len(arr)):
      result[i] = arr[i] * arr[i] + np.sin(arr[i])
  return result

# Test data
print("Creating test data...")
test_array = np.random.randn(10000).astype(np.float64)

# Test sequential
print("Testing sequential version...")
start = time.time()
result_seq = test_sequential(test_array)
seq_time = time.time() - start
print(f"Sequential: {seq_time:.4f}s - SUCCESS")

# Test parallel (this might crash)
print("Testing parallel version...")
try:
  start = time.time()
  result_par = test_parallel(test_array)
  par_time = time.time() - start
  print(f"Parallel: {par_time:.4f}s - SUCCESS")

  # Check if results match
  if np.allclose(result_seq, result_par):
      print("✅ Results match - Parallel Numba works!")
  else:
      print("❌ Results don't match - Race condition detected")

except Exception as e:
  print(f"❌ Parallel version CRASHED: {e}")

# Check Numba threading info
print("\nNumba threading info:")
try:
  import numba
  print(f"Threading layer: {numba.threading_layer()}")
except:
  print("Could not get threading layer info")


Creating test data...
Testing sequential version...
Sequential: 0.0753s - SUCCESS
Testing parallel version...
Parallel: 0.0788s - SUCCESS
✅ Results match - Parallel Numba works!

Numba threading info:
Threading layer: workqueue
