# Multinomial RANCH Project Notebook

#### Author: Guanpeng (Andy) Xu
#### 9.660 Final Project, Fall 2023

## Imports

In [None]:
import numpy as np

import pandas as pd

from scipy.special import gammaln as loggamma, digamma
from matplotlib import pyplot as plt

from tqdm import trange

## Exemplars and Priors

In [None]:
# Stimuli: one row per exemplar, three features each, every feature taking one
# of the values 0, 1, 2, 3, 4.
exemplars = np.array(
    [[0, 0, 0],
     [0, 0, 1],
     [0, 1, 2],
     [1, 2, 3]],
    dtype=float,
)
# Default test sequence: six presentations of the [0, 0, 1] exemplar.
default_exemplars = np.tile(np.array([0, 0, 1], dtype=float), (6, 1))

# Deviant stimulus; differs from [0, 0, 1] on the first and last feature.
deviant_exemplar = np.array([4, 0, 0], dtype=float)

# Dirichlet pseudo-counts: one row per feature, one column per feature value.
# Prior 1 puts extra mass on value 0 for every feature.
full_alphas_1 = np.tile(np.array([4, 1, 1, 1, 1], dtype=float), (3, 1))
# Prior 2 is flat: 3.6 - 2.0 pseudo-counts for every value.
full_alphas_2 = np.full((3, 5), 3.6) - 2.0

# Prior 3 puts almost no mass on value 0 for every feature.
full_alphas_3 = np.tile(np.array([0.04, 1.99, 1.99, 1.99, 1.99], dtype=float), (3, 1))

full_alphas_1

## RANCH implementation

In [None]:
# Default noise level for EIG.
epsilon_ = 0.065
# Noise levels swept in the "Multinomial RANCH Properties" experiment.
epsilons = [0.001, 0.01, 0.1, 0.2, 0.5, 0.7]
# Default noise level for perturb and ranch_sample.
epsilon_hat = 0.065
# Default EIG of the environment; ranch_sample's stopping rule weighs it
# against the EIG of the current stimulus.
EIG_env_ = 0.01
# Problem size comes from the prior matrix: rows are features, columns are values.
numft = full_alphas_1.shape[0]
numval = full_alphas_1.shape[1]


def perturb(exemplar, epsilon = epsilon_hat):
    """Return a noisy copy of the first ``numft`` features of *exemplar*.

    Each feature is kept unchanged when its uniform draw is >= ``epsilon``.
    Otherwise it is shifted by 1 .. numval-1 (mod ``numval``), so a corrupted
    feature always lands on one of the *other* ``numval - 1`` values, each
    equally likely.

    Parameters
    ----------
    exemplar : sequence of feature values in ``0 .. numval-1``.
    epsilon : per-feature corruption probability.

    Returns
    -------
    np.ndarray of length ``numft``.
    """
    draws = np.random.rand(numft)
    noisy = []
    for ind in range(numft):
        u = draws[ind]
        if u >= epsilon:
            noisy.append(exemplar[ind])
        else:
            # Here u / epsilon is uniform on [0, 1), so the shift is uniform on
            # 1 .. numval-1 and the feature always moves to a different value.
            # Derived from numval rather than hard-coding 4 and 5.
            shift = 1 + int((numval - 1) * u / epsilon)
            noisy.append((exemplar[ind] + shift) % numval)
    return np.array(noisy)

def compute_posterior_change(zs, epsilon):
    """Return the normalised posterior over one feature's true value.

    Each observation in *zs* is scored with likelihood ``1 - epsilon`` when it
    equals the candidate value. Otherwise it is scored ``epsilon / (numval - 1)``.
    Log-likelihoods are summed over observations and normalised. No prior
    term is included, so this is the posterior under a uniform prior.

    Parameters
    ----------
    zs : list of observed values for one feature (integral floats 0..numval-1).
    epsilon : noise level; must satisfy 0 < epsilon < 1.

    Returns
    -------
    np.ndarray of shape (numval,) summing to 1. With no observations it is
    the uniform distribution.
    """
    if len(zs) == 0:
        return 1.0 * (np.ones((numval, ))/numval).astype(float)

    assert epsilon > 0
    assert epsilon < 1

    # Log-likelihood table: row k is observation zs[k], column v a candidate value.
    ci = np.full((len(zs), numval), np.log(epsilon / (numval - 1)))
    # The +0.001 keeps values such as 2.9999999 from truncating to 2.
    ci[np.arange(len(zs)), [int(x + 0.001) for x in zs]] = np.log(1 - epsilon)
    ci_out = np.sum(ci, axis = 0)

    assert ci_out.shape[0] == numval

    # Shift by the max before exponentiating (log-sum-exp trick). For long
    # observation sequences every entry of ci_out can fall below ~-745. There
    # np.exp underflows to 0, and the unshifted ratio becomes 0/0 = nan, which
    # would poison EIG and the updated priors.
    weights = np.exp(ci_out - np.max(ci_out))
    return weights / np.sum(weights)


def EIG(alpha_priors, zs, epsilon = epsilon_):
    """Expected information gain of one more noisy sample of a single feature.

    For every possible next observation ``v``:
    - compute the KL divergence from ``alpha_priors + posterior(zs + [v])`` to
      ``alpha_priors + posterior(zs)``;
    - weight it by a predicted probability of observing ``v``.

    That probability is the noise channel applied to the current posterior,
    plus the raw prior pseudo-counts, renormalised.

    Parameters
    ----------
    alpha_priors : 1-D array (length numval) of Dirichlet pseudo-counts.
    zs : list of observations so far for this feature (a list, since
        ``zs + [v]`` is used to extend it).
    epsilon : noise level.
    """
    # Noise channel: diagonal 1 - epsilon, off-diagonal epsilon / (numval - 1).
    off_diagonal = epsilon/(numval - 1)
    noise = (1 - (numval * epsilon/(numval-1))) * np.eye(numval) + off_diagonal * np.ones((numval, numval))

    current = compute_posterior_change(zs, epsilon)

    predictive = np.matmul(noise, current) + alpha_priors
    predictive = predictive/np.sum(predictive)

    baseline = alpha_priors + current
    return sum(
        KL(alpha_priors + compute_posterior_change(zs + [v], epsilon), baseline) * predictive[v]
        for v in range(numval)
    )


def KL(alphas_1, alphas_2):
    """Return KL(Dir(alphas_1) || Dir(alphas_2)) between two Dirichlet distributions.

    Closed form, with a0 = sum(alphas_1) and b0 = sum(alphas_2):

        lnG(a0) - lnG(b0) + sum(lnG(alphas_2)) - sum(lnG(alphas_1))
        + sum((alphas_1 - alphas_2) * (digamma(alphas_1) - digamma(a0)))

    Written for 1-D arrays of positive concentration parameters of equal
    length (np.dot on 2-D input would be a matrix product instead).
    """
    total_1 = np.sum(alphas_1)
    total_2 = np.sum(alphas_2)
    norm_term = loggamma(total_1) - loggamma(total_2)
    gamma_term = np.sum(loggamma(alphas_2)) - np.sum(loggamma(alphas_1))
    cross_term = np.dot(alphas_1 - alphas_2, digamma(alphas_1) - digamma(total_1))
    return norm_term + gamma_term + cross_term


def ranch_sample(alpha_priors = full_alphas_1, exemplar_sequence = default_exemplars[:1], epsilon = epsilon_hat, EIG_env = EIG_env_, max_iters = 5000):
    """
    Return a sequence of looking times for each exemplar in the sequence, given a set of input priors and epsilon.

    For each exemplar, noisy samples are drawn one at a time with ``perturb``.
    Before each draw is kept:
    - the summed EIG over all features is computed from the samples already
      taken for this exemplar;
    - sampling stops with probability ``EIG_env / (EIG_env + EIG_next)``.

    After an exemplar is finished, each feature's pseudo-counts are incremented
    by the posterior from that exemplar's samples. Later exemplars are
    therefore evaluated against the updated priors.

    Parameters
    ----------
    alpha_priors : (numft, numval) array of Dirichlet pseudo-counts; copied,
        never modified.
    exemplar_sequence : iterable of exemplars (each indexable by feature).
    epsilon : noise level, used both to generate samples and inside the model.
    EIG_env : EIG of the environment in the stopping rule.
    max_iters : cap on samples per exemplar. A progress message is printed at
        ``max_iters // 2`` and a 'Fail' message when the cap is reached.

    Returns
    -------
    np.ndarray of ints: the number of samples (looking time) per exemplar.
    """
    # Private float copy so the caller's priors are never mutated.
    priors_modify = np.array(alpha_priors, dtype=float)
    out_array = []

    for exemplar in exemplar_sequence:
        samples = []  # noisy samples taken for the current exemplar
        r = 0
        sample = True

        while sample and r < max_iters:
            z = perturb(exemplar, epsilon)

            # EIG is judged on the samples taken *before* z.
            EIG_next = 0
            for ft in range(numft):
                zs_ft = [s[ft] for s in samples]
                EIG_next += EIG(priors_modify[ft, :].copy(), zs_ft, epsilon)

            # Stop with probability EIG_env / (EIG_env + EIG_next).
            if (EIG_env + EIG_next) * np.random.rand() > EIG_next:
                sample = False
            r += 1
            if r == max_iters // 2:
                print(f'{max_iters // 2} iters done. EIG: ', EIG_next)

            if r == max_iters:
                print('Fail ', EIG_next)

            samples.append(z)

        out_array.append(r)
        # Fold this exemplar's evidence into the priors for the next exemplar.
        for ft in range(numft):
            zs_ft = [s[ft] for s in samples]
            priors_modify[ft] += compute_posterior_change(zs_ft, epsilon)

    return np.array(out_array)




## Multinomial RANCH Properties

In [None]:
def _mean_looking_times(n_runs=1000, **ranch_kwargs):
    """Element-wise mean of ``ranch_sample(**ranch_kwargs)`` over *n_runs* runs."""
    weight = 1 / n_runs
    total = 0
    for _ in trange(n_runs):
        total = total + weight * ranch_sample(**ranch_kwargs)
    return total


seq_dict_e = {}
seq_dict_EIG = {}
seq_dict_alphas = {}

# Sweep the modeled noise level.
for epsilon__ in epsilons:
    print('EPS: ', epsilon__)
    seq_dict_e[epsilon__] = _mean_looking_times(epsilon = epsilon__)

print('')
print('')
# Sweep the environmental EIG.
for EIG_env__ in [0.001, 0.01, 0.1]:
    print('EIG_env: ', EIG_env__)
    seq_dict_EIG[EIG_env__] = _mean_looking_times(EIG_env = EIG_env__)
print('')
print('')
# Compare the three priors; keys 0, 1, 2 correspond to full_alphas_1, _2, _3.
for i, alpha_arr in enumerate([full_alphas_1, full_alphas_2, full_alphas_3]):
    print('ALPHA_ARR: ', alpha_arr[0])
    seq_dict_alphas[i] = _mean_looking_times(alpha_priors = alpha_arr.copy())


In [None]:
# Mean looking time per exemplar (averaged over 1000 runs) for each modeled noise level.
seq_dict_e

In [None]:
# Mean looking time per exemplar (averaged over 1000 runs) for each environmental EIG.
seq_dict_EIG

In [None]:
# Mean looking time per exemplar for each prior (0 -> full_alphas_1, 1 -> full_alphas_2, 2 -> full_alphas_3).
seq_dict_alphas

### Plotting

In [None]:
def plot_dict(dict_):
    """Split a results dict into bar labels and bar heights.

    Returns ``(names, values)``. *names* is each key converted with ``str``;
    *values* is the first entry of each value. For the ranch_sample averages in
    this notebook, that first entry is the looking time for the first exemplar.
    Both lists follow the dict's iteration order.
    """
    names = list(map(str, dict_))
    values = [entry[0] for entry in dict_.values()]
    return names, values

# Bar chart 1: mean looking time for the first exemplar vs. modeled noise level.
eps_names, eps_values = plot_dict(seq_dict_e)
plt.bar(eps_names, eps_values)
# Raw string: '\e' is not a valid Python escape sequence (SyntaxWarning on 3.12+).
plt.xlabel(r'Modeled Noise $(\epsilon)$', fontsize=13)
plt.ylabel('Mean Iterations sampled', fontsize=13)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

# Bar chart 2: mean looking time for the first exemplar vs. environmental EIG.
env_names, env_values = plot_dict(seq_dict_EIG)
plt.bar(env_names, env_values)
plt.xlabel('Modeled EIG(env)', fontsize=13)
plt.ylabel('Mean Iterations sampled', fontsize=13)

plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

# Bar chart 3: one bar per prior. Labels follow the order full_alphas_1, _2, _3
# used to fill seq_dict_alphas.
plt.bar(['Control-Biased Prior', 'Neutral Prior', 'Deviant-Biased Prior'], plot_dict(seq_dict_alphas)[1])

# Fixed y-range to make small differences visible; bars outside [3, 4] are clipped.
plt.ylim(3, 4)

plt.ylabel('Mean Iterations sampled', fontsize=13)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

## Complexity/Habituation Analysis

### WARNING:
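The analysis below does not reproduce the results I obtained in my scratch notebook, and I do not know why. Note that no random seed is set in the cells above, so the numbers will differ from run to run.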

In [None]:
compl = {}

for exemplar in exemplars:
    # Six identical presentations of this exemplar.
    arr = np.tile(exemplar, (6, 1))
    print(arr)

    # Key is the exemplar with every feature value shifted by +1, e.g. '[1. 1. 2.]'.
    key = str(arr[0] + 1)
    for l in trange(200):
        compl[key] = compl.get(key, 0) + 0.005 * ranch_sample(exemplar_sequence = arr)
    

In [None]:
# Looking time vs. repetition for each exemplar condition.
for key in compl:
    plt.plot(range(1, len(compl[key]) + 1), compl[key], label = key)

# Average across however many conditions were run. The earlier hard-coded
# 0.25 weight was only correct for exactly four conditions, and range(6) only
# for six repetitions.
if compl:
    means = np.mean([compl[key] for key in compl], axis = 0)
    plt.plot(range(1, len(means) + 1), means, label = 'Mean', color = 'black', linewidth = 2)
plt.legend()
plt.xlabel('Repetition', fontsize = 13)
plt.ylabel('Looking Iterations', fontsize = 13)

plt.xticks(fontsize = 12)
plt.yticks(fontsize = 12)
plt.show()

## Dishabituation Analysis

In [None]:
dishab = {}

control = exemplars[1]
# (condition name, 1-based position of the single deviant among six
# presentations; None means no deviant).
conditions = [
    ('Control', None),
    ('Deviant #2', 2),
    ('Deviant #4', 4),
    ('Deviant #6', 6),
]

for name, deviant_pos in conditions:
    arr = [control.copy() for t in range(6)]
    if deviant_pos is not None:
        arr[deviant_pos - 1] = deviant_exemplar.copy()
    print(arr)

    for l in trange(1000):
        dishab[name] = dishab.get(name, 0) + 0.001 * ranch_sample(alpha_priors = full_alphas_1, exemplar_sequence = arr)


In [None]:
# Looking time per presentation for each dishabituation condition.
for key in dishab:
    plt.plot(range(1, len(dishab[key]) + 1), dishab[key], label = key)

# Build the legend once, after all lines exist (it was rebuilt on every loop
# iteration before). Skip it when there is nothing to label.
if dishab:
    plt.legend(fontsize = 12)

plt.xlabel('Repetition', fontsize = 13)
plt.ylabel('Looking Iterations', fontsize = 13)

plt.xticks(fontsize = 12)
plt.yticks(fontsize = 12)
plt.show()