# In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
from collections import deque
from functools import partial

# Ising model environment
# Following the Markov decision process (MDP) formalism, this class stores the state of the Ising chain;
# an external agent takes actions (spin flips), and the environment returns the reward and the new state.

class t_ising_env:
    """Environment wrapping a periodic 1-D Ising chain for a reinforcement-learning agent.

    The agent acts by proposing a single spin flip. The environment accepts or
    rejects it with the Metropolis rule at inverse temperature ``beta``. It
    returns the new configuration together with a reward equal to the energy
    decrease (0 when the flip is rejected).

    Args:
        N: number of spin sites.
        J: nearest-neighbour exchange coupling.
        g: field strength multiplying the single-spin-flip (off-diagonal)
            term; it only enters ``local_energy`` when a wavefunction ``psi``
            is supplied.
        seed: seed for the internal ``jax.random`` key.
        beta: inverse temperature used by the Metropolis acceptance in
            ``step``. NOTE(review): 1.0 is a chosen default; confirm the
            intended temperature.
    """

    def __init__(self, N, J, g, seed: int = 42, beta: float = 1.0):
        self.J = J  # exchange coupling
        self.g = g  # external field strength
        self.N = N  # number of spin sites
        self.beta = beta  # inverse temperature for step()
        self.key = jax.random.PRNGKey(seed)
        self.state = self.reset()  # random initial configuration

    def reset(self):
        """Draw a uniformly random configuration of +/-1 spins, store it, and return it."""
        self.key, subkey = jax.random.split(self.key)
        state = jax.random.choice(subkey, jnp.array([-1, 1]), shape=(self.N,))
        # Store the configuration so reset() also starts a new episode when an
        # agent calls it directly, not only from __init__.
        self.state = state
        return state

    def energy(self, state):
        """Return the diagonal (classical) energy -J * sum_i s_i * s_{i+1}.

        The boundaries are periodic: jnp.roll pairs the last site with the first.
        """
        return -self.J * jnp.sum(state * jnp.roll(state, shift=-1))

    def local_energy(self, state, psi=None):
        """Return the local energy of ``state``.

        Without ``psi``, this is the diagonal energy ``energy(state)``.

        With ``psi``, the off-diagonal field term is added, as in variational
        Monte Carlo. Here ``psi`` is a callable mapping a spin configuration to
        a scalar wavefunction amplitude. The result is

            E_loc(s) = -J * sum_i s_i s_{i+1} - g * sum_i psi(s^(i)) / psi(s),

        where s^(i) is ``state`` with spin i flipped.
        """
        diagonal = self.energy(state)
        if psi is None:
            return diagonal
        # Amplitudes of every configuration reachable by a single spin flip.
        flipped_amplitudes = jnp.array(
            [psi(state.at[i].multiply(-1)) for i in range(state.shape[0])]
        )
        return diagonal - self.g * jnp.sum(flipped_amplitudes) / psi(state)

    # take a step, i.e. propose flipping one spin; returns the new state and the reward
    def step(self, action):
        """Propose flipping spin ``action`` and accept or reject it with the Metropolis rule.

        Args:
            action: index of the spin to flip, ``0 <= action < N``.

        Returns:
            ``(state, reward)``: the (possibly unchanged) configuration, and
            ``-delta_H`` if the flip was accepted or ``0.0`` otherwise.

        Raises:
            IndexError: if ``action`` is outside ``[0, N)``. Without this
                check, JAX would silently drop the out-of-bounds update and
                clamp the read, producing a fake zero-reward step.
        """
        action = int(action)
        if not 0 <= action < self.N:
            raise IndexError(f"action {action} out of range for {self.N} spins")

        new_state = self.state.at[action].set(-self.state[action])
        delta_H = self.energy(new_state) - self.energy(self.state)

        # Metropolis acceptance: downhill moves are always accepted.
        accept_prob = jnp.minimum(1.0, jnp.exp(-self.beta * delta_H))
        self.key, subkey = jax.random.split(self.key)
        accepted = jax.random.uniform(subkey) < accept_prob

        self.state = jnp.where(accepted, new_state, self.state)
        reward = jnp.where(accepted, -delta_H, 0.0)

        return self.state, reward