In [1]:
import numpy as np
from numba import jit, njit, prange, void, int8, float32
import scipy
from tqdm import tqdm

## Timing setting up the state space

In [2]:
def to_binary(n, width, dtype=np.int8):
    """
    Returns a binary rep of the int n as an array of size width, e.g. Assuming N = 5, 3 -> np.array([0,0,0,1,1]) 
    Not particularly efficient, but since it is only used once at the start for small N, this is okay
    """
    b = np.zeros(width,dtype=dtype)
    for i in range(width):
        if n % 2 == 1: 
            b[width-1-i]=1 # index N-1-i otherwise numbers are reversed
        else:
            b[width-1-i]=0
        n//=2
        if n==0: break
    return b

def get_state_space(width, dtype=np.int8):
    """
    Sets up the state space, but only if analytic expectations are needed
    """
    return np.array([to_binary(n, width, dtype) for n in range(2**width)],dtype=dtype) 

In [3]:
@njit
def n_to_binary(n, width):
    """
    Returns a binary rep of the int n as an array of size width, e.g. Assuming N = 5, 3 -> np.array([0,0,0,1,1]) 
    Not particularly efficient, but since it is only used once at the start for small N, this is okay
    """
    b = np.zeros(width)
    for i in range(width):
        if n % 2 == 1: 
            b[width-1-i]=1 # index N-1-i otherwise numbers are reversed
        else:
            b[width-1-i]=0
        n//=2
        if n==0: break
    return b

@njit
def n_get_state_space(width):
    """
    Sets up the state space, but only if analytic expectations are needed
    """
    space = np.zeros((2**width,width))
    for n in range(2**width):
        space[n] = n_to_binary(n, width)
    return space

In [4]:
%timeit get_state_space(10)
%timeit n_get_state_space(10)

10.2 s ± 438 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
417 ms ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Timing generating Gibbs samples

In [5]:
@njit
def gibbs_sampling(p           : "numba function which returns unnormalized probability of state s",
                   init_states : "CxN matrix of initial states for each Markov chain ",
                   M           : "number of samples per chain",
                   avg_every   : "how many transitions before we take the next sample",
                   burn_in     : "how many samples we initially discard"):
    """
    Returns M x C samples based on the given unnormalised probability mass function p
    Assumes 0,1 notation
    """
    C, N = init_states.shape
    its = M*avg_every+burn_in # number of iterations per chain
    samples = np.zeros((C,its,N)) #initialise samples
    for c in range(C):
        samples[c,0,:] = init_states[c] #set the initial state 
        
    # do gibbs sampling
    for c in range(C):
        for t in range(1,its): 
            samples[c,t] = samples[c,t-1] #copy previous state
            i = t % N #which dimension to work on

            state_off = np.copy(samples[c,t])
            state_on = np.copy(samples[c,t])
            
            state_off[i] =  0 #state with neuron i set to off
            state_on[i] = 1 #state with neuron i set to on
            
            p_off = p(state_off)
            denom = (p_off + p(state_on) )
            if denom == 0: # implies p_off and p_on is zero, so some numerical error
                p_cond_off = np.random.rand(1)[0]
            else:
                p_cond_off = p_off / denom #calc cond prob that spin i is on given other spin vals
            
            if np.random.binomial(1,p_cond_off): #draw number from unif distribution to determine whether we update i
                samples[c,t]=state_off
                continue
            samples[c,t]=state_on
            
    samples = np.copy(samples[:,burn_in::avg_every]) # discard burn in and take every `avg_every` sample
    
    return samples.reshape((C*M,N))

## Helper functions for MC gradient ascent

In [6]:
@njit
def get_ind_samples(avgs,C):
    N = avgs.shape[0]
    states = np.zeros((C,N))
    for c in range(C):
        for n in range(N):
            states[c,n] = np.random.binomial(1,avgs[n]) 
    return states
    

In [7]:
@njit
def p(s, h, J):
    """
    Returns the unnormalized probability (not divided by Z) of the state/states s 
    """
    if s.ndim==1:
        return np.exp( - s.dot(h) - s.dot(J).dot(s) )
    return np.exp(-s.dot(h) - np.sum(s.dot(J)*s, axis=1))

In [8]:
@njit
def sample_corrs(X,Y):
    """
    A convenience method used above
    X is an M x N matrix of states
    Y is an M vector of values for each state
    """
    N = X.shape[1]
    corrs = np.zeros((N,N))
    for i in range(N-1):
        for j in range(i+1,N):
            corrs[i,j] = np.sum( Y[ X[:,i]*X[:,j] == 1 ] )
    return corrs

In [9]:
@njit
def update_weights(i, p, samples, h, J, h_old, J_old, lr):
    n_samples = samples.shape[0]
    if i == 0:
        Ps = np.ones(n_samples) 
    else:
        Ps = p(samples, h-h_old, J-J_old) # like the likelihood of each state, but based on samples 

    denom = n_samples * np.mean(Ps) 
    if denom == 0:
        print("Divide by zero issue", Ps)
        denom = 1e-4 

    mod_avgs = samples.T.dot(Ps) / denom
    mod_corrs = sample_corrs(samples,Ps) / denom

    #update h and J
    h = h + lr*(mod_avgs - avgs)
    J = J + lr*(mod_corrs - corrs)
    return h, J, mod_avgs, mod_corrs
    

In [10]:
def num_gradient_ascent(p               :"function that returns unnorm. prob of states",
                        h               :"vector of weights associated with local fields",
                        J               :"matrix of weights associated with interactions",
                        avgs            :"avgs we wish to reproduce",
                        corrs           :"correlations we wish to reproduce",
                        M               :"number of gibbs samples we create",  
                        C               :"number of chains", 
                        N_sets          :"number of times we generate gibbs samples", 
                        updates_per_set :"number of updates per gibbs sample", 
                        avg_every       :"how many samples we skip",
                        burn_in         :"how many samples we discard", 
                        lr              :"learning rate",
                        seed            :"seed for random components"
                       ):
        """
        Performs gradient ascent, but uses gibbs sampling to work out expectations and reuses samples for multiple updates
        """
        np.random.seed(seed) # set seed
        
        N = h.shape[0] #get no neurons
        
        # get inital states using an independent approximation
        init_states = get_ind_samples(avgs,C)
        
        h_old = h
        J_old = J
        
        save_avs = np.zeros((N_sets, N))
        save_corrs = np.zeros((N_sets, N,N))
        
        for u in range(N_sets):
            @njit
            def p_fix(s):
                return p(s,h,J)
            print("Sampling...")
            samples = gibbs_sampling(p_fix, init_states, M, avg_every, burn_in) #generate set of samples with current h and J
            
            for i in tqdm(range(updates_per_set)): #gradient ascent based on the samples
                h, J, mod_avgs, mod_corrs = update_weights(i, p, samples, h, J, h_old, J_old, lr)
            
            #update h and J that generate samples
            h_old = h
            J_old = J
            init_states = samples[-C:] #take the last C samples as the initial states for next sampling
            
            # save model averages
            save_avs[u] = mod_avgs
            save_corrs[u] = mod_corrs
        
        return h, J, save_avs, save_corrs

In [11]:
@njit
def pert_init(avgs,corrs):
    """
    Initialise weights based on estimates from the perturbative results
    Div by 0 issue if any average is 0
    """
    N = avgs.shape[0]
    h = np.log( (1/avgs) - 1)
    prod_avgs = np.outer(avgs,avgs)
    J = -np.log( (corrs / prod_avgs) + np.tril( np.ones((N,N)))  ) 
    return h,J

In [12]:
N = 20
avgs = np.ones(N)*0.5
corrs = np.ones((N,N))*0.02
h, J = pert_init(avgs,corrs)
M = 1000
C = 2
N_sets = 5
updates_per_set = 50
avg_every = N
burn_in = N
lr=0.1
seed=42

In [21]:
h, J, save_avs, save_corrs = num_gradient_ascent(p, h, J, avgs, corrs, M, C, 10, updates_per_set, avg_every, burn_in, 0.1, seed)

Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 354.63it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 429.38it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 432.80it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 446.44it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 397.50it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 421.15it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 437.04it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 476.45it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 462.75it/s]


Sampling...


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 456.59it/s]


In [22]:
save_avs

array([[0.66887833, 0.56276053, 0.        , 0.40323453, 0.        ,
        0.        , 0.        , 0.33299063, 0.33436982, 0.33563848,
        0.        , 0.31362782, 0.        , 0.28313765, 0.66436152,
        1.        , 1.        , 0.        , 0.43723947, 0.        ],
       [0.        , 0.6853146 , 0.        , 1.        , 0.        ,
        0.        , 0.14527321, 0.34626396, 1.        , 1.        ,
        0.65904622, 0.        , 0.        , 0.14723715, 0.        ,
        0.33955248, 0.66044752, 0.        , 0.65564363, 0.        ],
       [0.        , 0.        , 0.        , 1.        , 0.        ,
        0.        , 0.64011338, 0.35200828, 0.76820389, 0.5685315 ,
        0.67114295, 0.        , 0.        , 0.        , 0.40831727,
        0.        , 0.59168273, 0.        , 1.        , 0.        ],
       [0.49833497, 0.29245143, 0.        , 0.32011875, 0.        ,
        0.44602279, 0.26244636, 0.56042509, 0.4240205 , 0.84620204,
        0.        , 0.16978621, 0.        , 0