# Ising model

The Hamiltonian of the Ising chain is given by:

$$\hat{H}_I = J\sum_{\braket{i,j}}\sigma_i \sigma_j$$

In [2]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state
from collections import deque
import random

# Class that creates an Ising environment
class IsingEnv:
    def __init__(self, N, T, key=jax.random.key(0)):
        self.N = N # Number of spins in the Ising chain
        self.T = T # Temperature of the system
        self.key, _key = jax.random.split(key)
        self.spins = jax.random.choice(_key, jnp.array([-1, 1]), shape=(self.N,)) # Array that keeps the spin states
        
    
    # Function that resets the spin state to a random state
    def reset(self):
        self.key, _key = jax.random.split(self.key)
        self.spins = jax.random.choice(_key, jnp.array([-1, 1]), shape=(self.N,))
        return self.spins
    
    # Returns the energy of the current state
    def energy(self, spins):
        return -jnp.sum(spins * jnp.roll(spins, 1))
    
    # Flips spin at random
    def step(self, action):
        new_spins = self.spins.at[action].multiply(-1)
        delta_E = self.energy(new_spins) - self.energy(self.spins)
        
        # Boltzmann reward (negative free energy difference)
        reward = jnp.exp(-delta_E / self.T)
        self.reward = reward
        self.spins = new_spins
        return new_spins, reward  # (state, reward)

In [3]:
class QNetwork(nn.Module):
    num_actions: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        q_values = nn.Dense(features=self.num_actions)(x)
        return q_values

In [4]:
# Implements the DQN algorithm
class DQNAgent:
    def __init__(self, state_dim, key=jax.random.key(0), epsilon=0.1, learning_rate=1e-3, gamma=0.99, batch_size=32, buffer_size=10000):
        self.state_dim = state_dim # Dimensionality of the state space (number of Isning spins)
        self.num_actions = state_dim # Number of actions in the action space
        self.learning_rate = learning_rate
        self.gamma = gamma # Discount factor of future rewards
        self.batch_size = batch_size # Number of samples to take from the replay buffer
        self.buffer = deque(maxlen=buffer_size) # Replay buffer
        self.epsilon = epsilon
        self.key, _key = jax.random.split(key)
        
        self.q_network = QNetwork(num_actions=state_dim) # Q-network
        self.target_network = QNetwork(num_actions=state_dim) # Target network
        
        self.params = self.q_network.init(_key, jnp.ones(state_dim)) # Initialize the Q-network parameters
        self.target_params = self.params # Initialize the target network parameters
        
        self.optimizer = optax.adam(learning_rate) # Adam optimizer
        self.state = train_state.TrainState.create(apply_fn=self.q_network.apply, params=self.params, tx=self.optimizer) # Training state
    
    # Epsilon-greedy policy to select actions
    def select_action(self, state):
        self.key, _key, _key2 = jax.random.split(self.key, 3)
        if jax.random.uniform(_key) < self.epsilon:
            return jax.random.choice(_key2, jnp.arange(self.num_actions))
        q_values = self.q_network.apply(self.state.params, state)
        return jnp.argmax(q_values)
    
    # Store transition in the replay buffer
    def store_transition(self, state, action, reward, next_state):
        self.buffer.append((state, action, reward, next_state))
    
    # Update the Q-network parameters (given that there is more than batch_size samples in the replay buffer)
    def update(self):
        if len(self.buffer) < self.batch_size:
            return
        
        batch = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states = zip(*batch)
        
        states = jnp.array(states)
        actions = jnp.array(actions)
        rewards = jnp.array(rewards)
        next_states = jnp.array(next_states)
        
        # Loss function
        def loss_fn(params):
            q_values = self.q_network.apply(params, states)
            q_values = jnp.take_along_axis(q_values, actions[:, None], axis=1).squeeze()
            
            next_q_values = self.target_network.apply(self.target_params, next_states)
            max_next_q_values = jnp.max(next_q_values, axis=1)
            targets = rewards + self.gamma * max_next_q_values
            
            loss = jnp.mean((q_values - targets) ** 2)
            return loss
        
        grads = jax.grad(loss_fn)(self.state.params)
        self.state = self.state.apply_gradients(grads=grads)

    # Update the target network parameters 
    def update_target_network(self):
        self.target_params = self.state.params

In [5]:
# Function to train the DQN agent
def train_dqn(env, agent, num_samples=1_000, num_episodes=100, epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=0.995):
    env.epsilon = epsilon_start
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        
        for _ in range(num_samples):
            action = agent.select_action(state)
            next_state, reward = env.step(action)
            
            agent.store_transition(state, action, reward, next_state)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        agent.update_target_network()
        env.epsilon = max(epsilon_end, env.epsilon * epsilon_decay)
        
        print(f"Episode {episode + 1}, Total Reward: {total_reward}")

# Example usage
N = 5  # Number of spins
T = 1.0  # Temperature
env = IsingEnv(N, T)
agent = DQNAgent(state_dim=N)

train_dqn(env, agent)

Episode 1, Total Reward: 17100.744140625
Episode 2, Total Reward: 24994.0234375
Episode 3, Total Reward: 26518.8828125
Episode 4, Total Reward: 26097.9609375
Episode 5, Total Reward: 26676.728515625
Episode 6, Total Reward: 26362.01953125
Episode 7, Total Reward: 26835.5546875
Episode 8, Total Reward: 26254.82421875
Episode 9, Total Reward: 26204.171875
Episode 10, Total Reward: 26363.001953125
Episode 11, Total Reward: 26625.09375


KeyboardInterrupt: 

In [6]:
def sample_states(env, agent, num_samples=100):
    states = []
    state = env.reset()
    
    for _ in range(num_samples):
        action = agent.select_action(state)
        next_state, _ = env.step(action)
        states.append(next_state)
        state = next_state
    
    return states

# Example usage
sampled_states = sample_states(env, agent, num_samples=100)
print(sampled_states)



[Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1,  1, -1, -1, -1], dtype=int32), Array([-1,  1, -1,  1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1, -1, -1, -1], dtype=int32), Array([-1, -1, -1,  1, -1], dtype=int32), Array([-1, -1,  1,  1, -1], dtype=int32), Array([ 1, -1,  1,  1, -1], dtype=int32), Array([ 1,  1,  1,  1, -1], dtype

In [12]:
def mean_energy(env, sampled_states):
    energies = jnp.array([env.energy(state) for state in sampled_states])
    mean_energy = jnp.mean(energies)
    return mean_energy

# Example usage
mean_energy_value = mean_energy(env, sampled_states)
print(f"Mean Energy: {mean_energy_value}")

Mean Energy: -2.6399998664855957


In [10]:
def exact_ground_energy(N):
    # For the 1D Ising model, the ground state energy is simply -N
    return -N

# Example usage
ground_energy = exact_ground_energy(N)
print(f"Exact Ground Energy: {ground_energy}")

(100, 5)

In [31]:
jax.random.randint(jax.random.PRNGKey(0), minval=0, maxval=10, shape=1)

Array([2], dtype=int32)

In [15]:
def train_dqn(env, agent, num_episodes=1000, epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=0.995):
    epsilon = epsilon_start
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        total_reward = 0
        
        while not done:
            action = agent.select_action(state, epsilon)
            next_state, reward = env.step(action)
            done = jnp.array_equal(next_state, state)  # Example condition for done
            
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        agent.update_target_network()
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        
        print(f"Episode {episode + 1}, Total Reward: {total_reward}")

# Example usage
N = 10  # Number of spins
T = 1.0  # Temperature
env = IsingEnv(N, T)
agent = DQNAgent(state_dim=(N,), num_actions=N)

train_dqn(env, agent)

TypeError: DQNAgent.__init__() got an unexpected keyword argument 'num_actions'

In [10]:
class QNetwork(nn.Module):
    action_dim: int
    
    @nn.compact
    def __call__(self, spins, T):
        x = jnp.concatenate([spins, jnp.array([T])])
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(self.action_dim)(x)

# Initialize model
N = 16  # Example size
model = QNetwork(action_dim=N)
params = model.init(jax.random.PRNGKey(0), jnp.ones(N), 1.0)
optimizer = optax.adam(1e-3)

In [16]:
import jax
import jax.numpy as jnp
import optax
import numpy as np
from collections import deque

# Hyperparameters
BATCH_SIZE = 32
BUFFER_CAPACITY = 10000
GAMMA = 0.99
TARGET_UPDATE = 100
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.1, 0.995

class DQNTrainer:
    def __init__(self, model, params, N):
        self.model = model
        self.params = params
        self.target_params = params
        self.optimizer = optax.adam(1e-3)
        self.opt_state = self.optimizer.init(params)
        self.replay_buffer = deque(maxlen=BUFFER_CAPACITY)
        self.N = N
        self.epsilon = EPS_START
        self.metrics = {
            'episode_rewards': [],
            'avg_energy': [],
            'epsilon': []
        }

    @jax.jit
    def _update(self, params, target_params, opt_state, batch):
        states, Ts, actions, rewards, next_states, dones = batch
        
        def loss_fn(params):
            q = self.model.apply(params, states, Ts)
            q = q[jnp.arange(q.shape[0]), actions]
            
            next_q = self.model.apply(target_params, next_states, Ts)
            max_next_q = jnp.max(next_q, axis=1)
            targets = rewards + GAMMA * max_next_q * (1 - dones)
            
            return jnp.mean((q - targets)**2)
        
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_opt_state = self.optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss

    def train_step(self, transition):
        self.replay_buffer.append(transition)
        
        if len(self.replay_buffer) >= BATCH_SIZE:
            indices = np.random.choice(len(self.replay_buffer), BATCH_SIZE)
            batch = [self.replay_buffer[i] for i in indices]
            
            # Explicit unpacking for clarity
            states = jnp.stack([t[0] for t in batch])
            actions = jnp.array([t[1] for t in batch])
            rewards = jnp.array([t[2] for t in batch])
            next_states = jnp.stack([t[3] for t in batch])
            dones = jnp.array([t[4] for t in batch])
            Ts = jnp.array([t[5] for t in batch])
            
            self.params, self.opt_state, _ = self._update(
                self.params, self.target_params, self.opt_state,
                (states, Ts, actions, rewards, next_states, dones)
            )
            
            self.epsilon = max(EPS_END, self.epsilon * EPS_DECAY)
            
            if self.train_steps % TARGET_UPDATE == 0:
                self.target_params = self.params
            self.train_steps += 1
            
    def record_metrics(self, episode_reward, avg_energy):
        self.metrics['episode_rewards'].append(episode_reward)
        self.metrics['avg_energy'].append(avg_energy)
        self.metrics['epsilon'].append(self.epsilon)

In [17]:
from collections import deque
import csv
from datetime import datetime
from flax import linen as nn

def run_training(num_episodes=1000):
    N = 16
    model = QNetwork(action_dim=N)
    params = model.init(jax.random.PRNGKey(0), jnp.ones(N), 1.0)
    trainer = DQNTrainer(model, params, N)
    max_steps = 100  # Steps per episode
    
    with open(f'metrics_{datetime.now().strftime("%Y%m%d-%H%M")}.csv', 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['Episode', 'AvgReward', 'AvgEnergy', 'Epsilon'])
        
        for ep in range(num_episodes):
            key = jax.random.PRNGKey(ep)
            T = np.random.uniform(0.5, 2.0)
            env = IsingEnv(N, T)
            state = env.reset(key)
            total_reward = 0.0
            energy_accum = 0.0
            
            for _ in range(max_steps):
                if np.random.rand() < trainer.epsilon:
                    action = np.random.randint(N)
                else:
                    q_vals = model.apply(trainer.params, state, T)
                    action = jnp.argmax(q_vals).item()
                
                next_state, reward, _, _ = env.step(action)
                energy = env.energy(next_state)
                trainer.train_step((state, action, reward, next_state, False, T))
                
                total_reward += reward
                energy_accum += energy
                state = next_state

            avg_reward = total_reward / max_steps
            avg_energy = energy_accum / max_steps
            trainer.record_metrics(avg_reward, avg_energy)
            
            # Log to CSV every episode
            writer.writerow([ep, avg_reward, avg_energy, trainer.epsilon])
            
            # Print summary every 50 episodes
            if ep % 50 == 0:
                print(f"Ep {ep}: Reward {avg_reward:.3f} | Energy {avg_energy:.1f} | ε {trainer.epsilon:.3f}")

if __name__ == "__main__":
    run_training(num_episodes=1000)

TypeError: Error interpreting argument to <function DQNTrainer._update at 0x17f8cef20> as an abstract array. The problematic value is of type <class '__main__.DQNTrainer'> and was passed to the function at path self.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [15]:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('metrics_*.csv')
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.plot(df['AvgReward'])
plt.title('Average Reward')
plt.subplot(132)
plt.plot(df['AvgEnergy'])
plt.title('System Energy')
plt.subplot(133)
plt.plot(df['Epsilon'])
plt.title('Exploration Rate')
plt.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'pandas'