# Imports etc.

In [None]:
import ipdb
import numpy as np
import os
import pandas as pd
import plotnine as gg
import scipy
# Make plotnine's black-and-white theme the default for all plots created below.
gg.theme_set(gg.theme_bw)

In [None]:
# Directory containing the fitted-parameter CSV files; simulated data is later saved into a subfolder of it.
fitted_param_dir = 'C:/Users/maria/MEGAsync/SLCN/PShumanData/fitting/mice/'
# Directory into which figures are saved.
plot_dir = 'C:/Users/maria/MEGAsync/SLCN/models/plots'

# Load fitted params

In [None]:
# Read every fitted-parameter file (file name contains both '.csv' and 'params') and stack them into one frame.
# The frames are collected in a list and concatenated once: `DataFrame.append` was removed in pandas 2.0,
# and appending inside the loop re-copied the accumulated frame on every iteration.
modelnames = [f for f in os.listdir(fitted_param_dir) if ('.csv' in f) and ('params' in f)]
model_frames = []
for modelname in modelnames:
    model_params = pd.read_csv(os.path.join(fitted_param_dir, modelname))
    model_params['model'] = modelname.split('_')[1]  # model name = second '_'-separated token of the file name
    model_frames.append(model_params)
# `pd.concat` raises on an empty list, so fall back to an empty frame when no file matched.
fitted_params = pd.concat(model_frames, sort=False) if model_frames else pd.DataFrame()
fitted_params

## Plot

In [None]:
# Work on a copy so that the loaded parameters themselves stay untouched.
fitted_params_ = fitted_params.copy()
# beta is divided by 20 before plotting (presumably so it fits on the same axis as the other
# parameters -- TODO confirm).
fitted_params_['beta'] = fitted_params_['beta'] / 20
# Long format: one row per (identifier columns, parameter name, parameter value).
fitted_params_long = pd.melt(
    fitted_params_,
    id_vars=[
        'sID', 'slope_variable', 'fullID', 'animal', 'PreciseYrs', 'Gender',
        'treatment', 'session', 'age_z', 'T1', 'PDS', 'model',
    ],
    var_name='param',
    value_name='param_value',
)

In [None]:
# One panel per model: a bar from stat_summary per parameter, with the individual values jittered on top.
g = gg.ggplot(fitted_params_long, gg.aes(x='param', y='param_value'))
g = g + gg.stat_summary(geom='bar')
g = g + gg.geom_point(position='jitter', alpha=0.2)
g = g + gg.facet_wrap('~ model')
g.draw()
g.save(os.path.join(plot_dir, 'fitted_params.png'))

# Simulate data

In [None]:
# Hand-picked example parameter set: a single row, one column per parameter.
params = pd.DataFrame([
    dict(alpha=0.8, nalpha=0.1, calpha=0.9, cnalpha=0.1, beta=4, persev=0.2, bias=0),
])
init_Q = 0.5  # initial Q-value of both actions
n_agents = 10  # number of agents simulated in parallel

In [None]:
class PSAgent():
    """
    A population of independent two-action Q-learning agents that are simulated in parallel.

    `params` must offer the columns 'alpha', 'nalpha', 'calpha', 'cnalpha', 'beta' and 'persev'; each one is
    read through `.values` (e.g. a one-row pandas DataFrame, whose single value broadcasts over all agents).
    A 'bias' column, if present, is not read by this class.
    """

    def __init__(self, n_agents, params, init_Q, eps=1e-5):
        self.n_agents = n_agents
        self.params = params
        self.init_Q = init_Q
        self.eps = eps  # stored, but not used by any method of this class
        self.avail_actions = (0, 1)
        # One row of Q-values per agent, one column per available action
        self.Q = init_Q * np.ones((n_agents, len(self.avail_actions)))
        # NaN marks "no previous action yet" (i.e., before the first trial)
        self.prev_action = np.full(n_agents, np.nan)

    def take_action(self):
        """
        Sample one action per agent from a softmax over the current Q-values.

        From the second trial on, `persev` is added to the value of the action each agent chose on the
        previous trial. The bonus is applied to a copy; the stored Q-values are left unchanged.

        Returns (lik, action): the action probabilities (one row per agent, one column per action; returned
        without any clipping by `eps`) and the array of sampled actions.
        """

        rows = np.arange(self.n_agents)
        # Fancy indexing copies, so `values` can be modified without touching self.Q
        values = np.column_stack((self.Q[rows, 0], self.Q[rows, 1]))

        # Perseveration (all agents share the same trial, so inspecting the first agent is sufficient)
        is_first_trial = np.isnan(self.prev_action[0])
        if not is_first_trial:
            bonus = self.params['persev'].values
            values[:, 0] += (1 - self.prev_action) * bonus  # bonus for agents that chose action 0 last trial
            values[:, 1] += self.prev_action * bonus  # bonus for agents that chose action 1 last trial

        # Action selection
        lik = scipy.special.softmax(self.params['beta'].values * values, axis=1)
        action = np.array([np.random.choice(self.avail_actions, p=probs) for probs in lik])
        self.prev_action = action.copy()

        return lik, action

    def update_Q(self, action, reward):
        """
        Update the Q-values of all agents in place, given each agent's action and binary (0/1) reward.

        Chosen action: moves toward 1 at rate `alpha` after reward, toward 0 at rate `nalpha` after no reward.
        Unchosen action: moves toward 0 at rate `calpha` after reward, toward 1 at rate `cnalpha` after
        no reward.
        """

        rows = np.arange(self.n_agents)
        unchosen = 1 - action
        no_reward = 1 - reward
        rates = self.params

        # Prediction errors of the chosen action (exactly one of the two is non-zero per agent)
        chosen_rewarded = (1 - self.Q[rows, action]) * reward
        chosen_unrewarded = (0 - self.Q[rows, action]) * no_reward

        # Counterfactual prediction errors of the unchosen action
        unchosen_rewarded = (0 - self.Q[rows, unchosen]) * reward
        unchosen_unrewarded = (1 - self.Q[rows, unchosen]) * no_reward

        self.Q[rows, action] += rates['alpha'].values * chosen_rewarded + rates['nalpha'].values * chosen_unrewarded
        self.Q[rows, unchosen] += (
            rates['calpha'].values * unchosen_rewarded + rates['cnalpha'].values * unchosen_unrewarded
        )


# Example use:
# NOTE: the reward is drawn at random here instead of coming from `task.present_reward(action, trial)`:
# `task` and `trial` are only created in a later cell, so using them here raised a NameError whenever
# the notebook was run from top to bottom.
agent = PSAgent(n_agents, params, init_Q)
lik, action = agent.take_action()
print("action", action)
reward = np.random.choice((0, 1), size=n_agents)  # stand-in 0/1 reward for each agent
print("reward", reward)
agent.update_Q(action, reward)
print("agent.Q", agent.Q)

In [None]:
n_trials = 200  # number of trials of a generated task
p_cor = 0.75  # probability with which a correct choice is rewarded
# Block lengths are drawn with np.random.randint(low=lower, high=upper); `high` is exclusive,
# so with 40 and 41 every block is exactly 40 trials long.
block_lengths_lower = 40
block_lengths_upper = 41

In [None]:
class PSTask():
    """
    Two-choice probabilistic switching task that is shared by all simulated agents.

    On every trial exactly one action (0 or 1) is correct. A correct choice is rewarded with probability
    `p_cor`; an incorrect choice is never rewarded.
    """

    def __init__(self, p_cor, correct_actions=None, n_trials=None,
                 block_lengths_lower=None, block_lengths_upper=None):
        """
        Must either provide block_lengths_lower and block_lengths_upper -> task will be created on the fly;
        or correct_actions -> provided task will be used.

        Args:
            p_cor: probability of reward after a correct choice.
            correct_actions: sequence holding the correct action of each trial. When non-empty it defines
                the task, and `n_trials` is set to its length (the `n_trials` argument is ignored).
            n_trials: number of trials to generate (only used when the task is created on the fly).
            block_lengths_lower: smallest possible block length (inclusive).
            block_lengths_upper: upper bound of the block length (exclusive, as in np.random.randint).

        Raises:
            ValueError: if neither a non-empty `correct_actions` nor `block_lengths_lower` is given, or if
                a task should be generated but `n_trials` or `block_lengths_upper` is missing.
        """

        self.p_cor = p_cor

        # `None` is accepted as "not provided", in addition to an empty sequence
        if correct_actions is not None and len(correct_actions) > 0:
            self.correct_actions = correct_actions
            self.n_trials = len(correct_actions)
        elif block_lengths_lower:
            if n_trials is None or block_lengths_upper is None:
                raise ValueError("n_trials and block_lengths_upper are required to create a task on the fly.")
            self.block_lengths_lower = block_lengths_lower
            self.block_lengths_upper = block_lengths_upper
            self.n_trials = n_trials
            self.correct_actions = self.make_task()
        else:
            raise ValueError("You must provide either correct_actions or block_lengths_lower.")

    def make_task(self):
        """
        Currently just produces the same sequence of correct and incorrect boxes for each animal.
        In future, will read in animal data.

        Returns a list of length `n_trials` with the correct action (0 or 1) of every trial. The correct
        action starts at 0 and flips after each block.
        """

        # One length per potential block. Drawing `n_trials` lengths is more than needed, but is kept so
        # that the number of random draws (and hence seeded results) stays the same as before.
        block_lengths = np.random.randint(
            low=self.block_lengths_lower, high=self.block_lengths_upper, size=self.n_trials)

        correct_actions = []
        for block, block_length in enumerate(block_lengths):
            if len(correct_actions) >= self.n_trials:
                break  # enough trials; further blocks would be cut off anyway
            # int(): make the list repetition independent of numpy scalar semantics
            correct_actions.extend([block % 2] * int(block_length))

        return correct_actions[:self.n_trials]

    def get_chance_rewards(self, n_correct_choices):
        """
        Translate accuracy into rewards:
        Return '1' with probability self.p_cor and '0' with probability 1-self.p_cor, for each agent.

        Returns an array with `n_correct_choices` entries, drawn in a single vectorized call.
        """

        return np.random.choice((0, 1), size=n_correct_choices, p=(1 - self.p_cor, self.p_cor))

    def present_reward(self, action, trial):
        """
        Present reward (0, 1) for each agent in this trial, based on choices,
        by consulting the correct_box on the current trial.

        Args:
            action: array with the chosen action of every agent.
            trial: index of the current trial.

        Returns (correct, reward): two integer arrays (one entry per agent). `correct` is 1 where the agent
        chose the correct action; `reward` is 0 for incorrect choices and a chance reward for correct ones.
        """

        correct = np.array(self.correct_actions[trial] == action).astype(int)
        reward = correct.copy()
        was_correct = correct == 1
        # Only correct choices can be rewarded, each with probability p_cor
        reward[was_correct] = self.get_chance_rewards(int(was_correct.sum()))

        return correct, reward
    
# Example use
# An empty `correct_actions` makes the constructor build the task from the block-length settings.
task = PSTask(p_cor, [], n_trials, block_lengths_lower, block_lengths_upper)
task.make_task()  # generates another sequence of correct actions (the result is not stored)

task.get_chance_rewards(n_correct_choices=100)  # 100 chance-reward draws (the result is not stored)

# Outcome of trial 0 when every agent chooses action 0
trial = 0
action = np.zeros(n_agents)
task.present_reward(action, trial)

In [None]:
def simulate_dataset(task_args, agent_args):
    """
    Let a population of agents play through one task and return the trial-by-trial data.

    Args:
        task_args: dict with the keys 'p_cor', 'n_trials', 'block_lengths_lower', 'block_lengths_upper'
            and 'correct_actions'; they are passed on to PSTask.
        agent_args: dict with the keys 'n_agents', 'params' and 'init_Q'; they are passed on to PSAgent.

    Returns:
        pandas DataFrame with one row per trial and the columns 'trial', 'action', 'lik', 'reward',
        'correct' (each cell of these four holds the array over all agents), 'correct_action',
        'mean_reward', 'mean_correct' (averages over agents) and 'block' (number of switches of the
        correct action so far).
    """

    actions = []
    liks = []
    rewards = []
    corrects = []

    # Get task and agent
    task = PSTask(task_args['p_cor'], n_trials=task_args['n_trials'],
                  block_lengths_lower=task_args['block_lengths_lower'],
                  block_lengths_upper=task_args['block_lengths_upper'],
                  correct_actions=task_args['correct_actions'])
    agent = PSAgent(agent_args['n_agents'], agent_args['params'], agent_args['init_Q'])

    # Play the game, save data.
    # The number of trials comes from the task itself rather than from the global `n_trials`: the two can
    # differ (e.g., when `correct_actions` is provided), which led to an IndexError or to columns of
    # unequal length further down.
    for trial in range(task.n_trials):

        lik, action = agent.take_action()
        correct, reward = task.present_reward(action, trial)
        agent.update_Q(action, reward)

        actions.append(action)
        liks.append(lik)
        rewards.append(reward)
        corrects.append(correct)

    # Format data
    data = pd.DataFrame(
            {'action': actions, 'lik': liks,
             'reward': rewards, 'correct': corrects, 'correct_action': task.correct_actions})
    data = data.reset_index()
    data = data.rename(columns={'index': 'trial'})

    # Averages over agents (axis 1) for every trial
    data['mean_reward'] = np.mean(np.array(list(data.reward)), axis=1)
    data['mean_correct'] = np.mean(np.array(list(data.correct)), axis=1)
    # The block counter increases by one whenever the correct action changes from one trial to the next
    data['block'] = np.append([0], np.cumsum(np.abs(np.diff(task.correct_actions))))

    return data

# Example use
# Empty `correct_actions` -> the task is generated from the block-length settings.
task_args = dict(
    p_cor=0.75,
    correct_actions=[],
    n_trials=n_trials,
    block_lengths_lower=block_lengths_lower,
    block_lengths_upper=block_lengths_upper,
)
agent_args = dict(n_agents=n_agents, params=params, init_Q=init_Q)

data = simulate_dataset(task_args, agent_args)
data

In [None]:
# Across-agent accuracy on every trial, colored by block number
(
    gg.ggplot(data, mapping=gg.aes(x='trial', y='mean_correct', color='block'))
    + gg.geom_point()
    + gg.geom_line()
)

# Simulated mouse data from fitted params

In [None]:
# Columns of `fitted_params` that are handed to the agents as parameters
param_names = ['alpha', 'nalpha', 'calpha', 'cnalpha', 'beta', 'persev', 'bias']
model_name = 'RLab'  # NOTE: re-bound by the `for model_name in model_names` loop further down
n_agents = 10  # number of agents simulated per mouse session

In [None]:
# Display the real mouse data.
# NOTE(review): `true_dat` is not defined in the visible part of this notebook -- presumably it is loaded
# elsewhere; verify before running the notebook from top to bottom.
true_dat

In [None]:
# Display the fitted parameters that were loaded from the CSV files
fitted_params

In [None]:
# For every fitted model: simulate `n_agents` agents per mouse session on the task that mouse experienced,
# using that mouse's fitted parameters, and save all simulated sessions of the model to one csv file.
# NOTE(review): `true_dat`, `animals` and `ages` are not defined in the visible part of this notebook --
# presumably they are created elsewhere; verify before running.
model_names = [modelname.split('_')[1] for modelname in modelnames]
for model_name in model_names:
    print(model_name, model_names)
    sub_frames = []  # simulated data of each mouse session; concatenated once after the inner loop

    for animal, age in zip(animals, ages):
        # Error in the Juvi_AnimalID.csv - rerunning fitted_params with fixed one
        # NOTE(review): this condition skips *every* session of animal 23 and *every* session at age 43,
        # not just the combination (animal 23, age 43). Behavior kept unchanged; confirm what is intended.
        if (animal == 23) or (age == 43):
            continue

        # Get task for this mouse
        true_sub = true_dat.loc[(true_dat.age == age) & (true_dat.animal == animal)]
        n_trials = len(np.unique(true_sub.trial))
        correct_actions = true_sub.correct_action.values

        # Get params
        params = fitted_params.loc[
            (fitted_params.PreciseYrs == age) & (fitted_params.animal == animal) & (fitted_params.model == model_name),
            param_names]
        task_args = {
            'p_cor': 0.75, 'correct_actions': correct_actions, 'n_trials': 0,
            'block_lengths_lower': False, 'block_lengths_upper': False
        }
        agent_args = {
            'n_agents': n_agents, 'params': params, 'init_Q': init_Q
        }

        sub_data = simulate_dataset(task_args, agent_args)
        # Positional access: the filtered subset does not necessarily contain the index label 0
        sub_data['session'] = true_sub.session.iloc[0]
        sub_data['animal'] = animal
        sub_data['age'] = age
        sub_data['model'] = model_name

        sub_frames.append(sub_data)

    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` raises on an empty list, hence the fallback
    sim_data = pd.concat(sub_frames) if sub_frames else pd.DataFrame()

    save_dir = os.path.join(fitted_param_dir, 'simulations/simulated_mice_{}_nagents{}.csv'.format(model_name, n_agents))
    os.makedirs(os.path.dirname(save_dir), exist_ok=True)  # `to_csv` does not create missing folders
    print("Saving sim_data ({}) to {}...".format(sim_data.shape, save_dir))
    sim_data.to_csv(save_dir)

In [None]:
# Super basic check
# Larger figures from here on, then the summary of mean accuracy per trial, colored by block
gg.options.figure_size = (10, 10)
gg.ggplot(sim_data, mapping=gg.aes(x='trial', y='mean_correct', color='block')) + gg.stat_summary()

# Next steps
* Simulate all mice, all sessions, based on fitted parameters
* Save as csvs
* Option A) Read into R and analyze in the same way as humans
* Option B) Reimplement the analyses in python and analyze there
* Bring actual mouse data into the same shape and analyze in the same way
* Compare models
* Average over mice