## Adaptive dynamics of Ising spins in one dimension leveraging Reinforcement Learning

In this work, they focus on the use of RL in active matter systems. Active matter systems are also many-particle systems, consisting of a large number of entities that take up energy from the environment and convert it into drected motion. It is inherently far from equilibrium. One of the interesting characteristics of these systems is the collective coherent motion along a common direction, also called flocking.

### Model

We start with $N$ active Ising spins ($s = \pm 1$) that are randomly distributed along a length $L$ with periodic boundary conditions. To incorporate the RL approach, we define the state and action for each $i^{th}$ spin. The state of each $i^{th}$ spin is determined based on the direction of the spin with respect to its neighbours within the range $[x_i - \delta x, x_i + \delta x]$, where $x_i$ is the position of $i^{th}$ spin and $\delta x$ is the interaction range and chosen as the unit of length in the system. For each $i^{th}$ spin at any instantaneous time $t$, we have two states: $S_i (t) = \pm 1$. If the spin has the same direction as the majority of spins in the range, then its state $S_i (t) = +1$ and $S_i (t) = -1$ for the opposite direction. Further, the action $a_i (t)$ for each spin is either to flip or keep its orientation.

Position update for the spin at each time step $\Delta t$ is according to:

$$x_i (t + \Delta t) = x_i (t) + \tilde{v}_i (t)s_i(t + \Delta t)\Delta t$$

here, $x_i (t)$ and $\tilde{v}_i (t)$ are the position and instantaneous self-propulsion speed of the $i^{th}$ spin at time $t$, respectively, and $s_i (t + \Delta t)$ is the updated spin orientation of the $i^{th}$ spin, and $\tilde{v}_i (t)$ is taken from a uniform distribution with nonzero positive lower and upper bounds $v_1$ and $v_2$, respectively. At every time step, the magnitude of $\tilde{v}_i (t)$ is chosen randomly; hence, at every instance, each particle can take a random step size obtained from the distribution.

The measure of learning is adopted by maintaining the cohesion among the spins. So, the spin receives feedback when it moves to a new position. The spin pays a cost if it loses the number of neighbours around it. Hence, the cost function $C_i (t + \Delta t)$ for each spin is defined as:

$$C_i (t + \Delta t) = \begin{cases}1, \text{ if }n_i(t + \Delta t)<n_i (t) \\ 0, \text{ otherwise}\end{cases}$$

Here, $n_i(t)$ number of spins within the range $[x_i - \delta x, x_i + \delta x]$ at time $t$. The $Q$ matrix is updated at every time step, and initialize it at zeros. At each time, the matrix is updated as:

$$Q_i [S_i(t), a_i(t)] \leftarrow Q_i [S_i(t), a_i (t)](1-\alpha) + \alpha C_i(t + \Delta t)$$

where $\alpha$ is the learning rate. 

Further, the action is chosen based on $\epsilon$-greedy algorithm:

$$a_i(t) = \begin{cases}\text{random action, with probability } \epsilon \\ \text{argmin}Q_i [S_i(t), a_i(t)],\text{ with probability }(1-\epsilon)\end{cases}$$

We choose argmin for the update because the minimum $Q$ value tells that the cost value is also minimum for the corresponding state-action pair. 

To characterize the ordering in the system, we define average magnetization $\braket{m}$ as the order parameter,

$$\braket{m} = \frac{1}{N}\sum_i s_i$$

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.training import train_state
import optax

In [4]:
# Parameters
N = 10  # Number of spins
L = 10.0  # Length of the system
delta_x = 1.0  # Interaction range
v1, v2 = 0.1, 1.0  # Speed bounds
alpha = 0.1  # Learning rate
epsilon = 0.1  # Exploration rate
delta_t = 0.1  # Time step
num_actions = 2  # Number of actions (flip or keep)

# Initialize positions, spins, and velocities
key = jax.random.PRNGKey(0)
positions = jax.random.uniform(key, (N,), minval=0, maxval=L)
spins = jax.random.choice(key, jnp.array([-1, 1]), (N,))
velocities = jax.random.uniform(key, (N,), minval=v1, maxval=v2)

In [5]:
# Define the Q-network
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=num_actions)(x)
        return x

In [None]:
# Initialize the Q-network
def create_train_state(key, learning_rate):
    q_net = QNetwork()
    params = q_net.init(key, jnp.ones((1, 2)))  # Input shape: (batch_size, num_features)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=q_net.apply, params=params, tx=tx)

In [7]:
key, subkey = jax.random.split(key)
state = create_train_state(subkey, alpha)

def get_neighbors(i, positions, spins):
    left = positions[i] - delta_x
    right = positions[i] + delta_x
    mask = (positions >= left) & (positions <= right)
    return spins[mask]

In [8]:
def update_spin(i, spins, Q, epsilon):
    if jax.random.uniform(key) < epsilon:
        return -spins[i]  # Random action: flip the spin
    else:
        state = jnp.array([spins[i], jnp.sign(jnp.sum(get_neighbors(i, positions, spins)))])
        q_values = Q.apply_fn(Q.params, state)
        action = jnp.argmin(q_values)
        return spins[i] if action == 0 else -spins[i]  # Greedy action based on Q values

In [9]:
def update_position(i, positions, spins, velocities):
    return positions[i] + velocities[i] * spins[i] * delta_t

In [10]:
def cost_function(i, positions, spins):
    neighbors_before = get_neighbors(i, positions, spins)
    new_spin = update_spin(i, spins, state, epsilon)
    new_position = update_position(i, positions, spins, velocities)
    neighbors_after = get_neighbors(i, positions.at[i].set(new_position), spins.at[i].set(new_spin))
    return 1 if len(neighbors_after) < len(neighbors_before) else 0

In [11]:
def update_Q(i, state, cost):
    state_input = jnp.array([spins[i], jnp.sign(jnp.sum(get_neighbors(i, positions, spins)))])
    q_values = state.apply_fn(state.params, state_input)
    action = jnp.argmin(q_values)
    target = q_values.at[action].set(cost)
    loss = jnp.mean((q_values - target) ** 2)
    grads = jax.grad(lambda params: jnp.mean((state.apply_fn(params, state_input) - target) ** 2))(state.params)
    state = state.apply_gradients(grads=grads)
    return state

In [12]:
# Simulation loop
for t in range(100):
    print("Voy en: ", t)
    for i in range(N):
        cost = cost_function(i, positions, spins)
        state = update_Q(i, state, cost)
        spins = spins.at[i].set(update_spin(i, spins, state, epsilon))
        positions = positions.at[i].set(update_position(i, positions, spins, velocities))

# Calculate average magnetization
average_magnetization = jnp.mean(spins)
print(f"Average Magnetization: {average_magnetization}")

Voy en:  0
Voy en:  1
Voy en:  2
Voy en:  3
Voy en:  4
Voy en:  5
Voy en:  6
Voy en:  7
Voy en:  8
Voy en:  9
Voy en:  10
Voy en:  11
Voy en:  12
Voy en:  13
Voy en:  14
Voy en:  15
Voy en:  16
Voy en:  17
Voy en:  18
Voy en:  19
Voy en:  20
Voy en:  21
Voy en:  22
Voy en:  23
Voy en:  24
Voy en:  25
Voy en:  26
Voy en:  27
Voy en:  28
Voy en:  29
Voy en:  30
Voy en:  31
Voy en:  32
Voy en:  33
Voy en:  34
Voy en:  35
Voy en:  36
Voy en:  37
Voy en:  38
Voy en:  39
Voy en:  40
Voy en:  41
Voy en:  42
Voy en:  43
Voy en:  44
Voy en:  45
Voy en:  46
Voy en:  47
Voy en:  48
Voy en:  49
Voy en:  50
Voy en:  51
Voy en:  52
Voy en:  53
Voy en:  54
Voy en:  55
Voy en:  56
Voy en:  57
Voy en:  58
Voy en:  59
Voy en:  60
Voy en:  61
Voy en:  62
Voy en:  63
Voy en:  64
Voy en:  65
Voy en:  66
Voy en:  67
Voy en:  68
Voy en:  69
Voy en:  70
Voy en:  71
Voy en:  72
Voy en:  73
Voy en:  74
Voy en:  75
Voy en:  76
Voy en:  77
Voy en:  78
Voy en:  79
Voy en:  80
Voy en:  81
Voy en:  82
Voy en:  83
Vo

In [13]:
spins

Array([-1, -1,  1,  1, -1, -1, -1, -1, -1, -1], dtype=int32)

In [14]:
positions

Array([-0.6450915 , -0.1382886 ,  9.124096  ,  5.3202004 , -0.59076375,
       -0.263117  , -0.74900085, -0.7226927 , -0.2321845 , -0.13659248],      dtype=float32)