# FOV model
## renew on 2024-4-1

April Fool's Day Celebration Codes

## Setup

In [1]:
import arviz as az
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np
import pandas as pd
import seaborn as sns
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
import random
az.style.use("arviz-grayscale")

### Load data

Allfishtrial = pd.read_csv('tips.csv')
Allfishtrial.tail()

BRWR_fishid = pd.Categorical(Allfishtrial['fish_id'],
                     categories=['fish3', 'fish7', 'fish9', 'fish10']).codes
BRWR_sessiontype = pd.Categorical(Allfishtrial['session_type'],
                     categories=['Baseliine', 'Rotation', 'Washout', 'ReRotation']).codes


groups_fishid = len(np.unique(BRWR_fishid))


Priors calculation， should get these from raw data

In [2]:
first_day_value = 1
first_trial_value = 1

In [30]:
def gamma_shra_from_modesd(m, s):
    v = s**2
    ra = (m + np.sqrt(m**2 + 4*v)) / (2*v)
    sh = 1+m*ra
    return sh,ra

In [9]:
stage_list = ['baseline', 'rotation', 'washout', 'savings', 'washout 2']

sim = {
    # The amount of learning across fish (for each stage)
    "μ_Aμμ": [-1, 5, -1, 5, -1],
    "σ_Aμμ": [5, 5, 5, 5, 5],

    # Time constant of learning across fish (for each stage)
    "mode_τμ": [4, 4, 4, 4, 4],
    "σ_τμ": [1.5, 1.5, 1.5, 1.5, 1.5],

    # End point of learning across fish (for each stage) 
    "μ_μμ_inf": [0, 0, 0, 0, 0],
    "σ_μμ_inf": [1, 1, 1, 1, 1],

    # Change in std of days around mean value across fish (for each stage)
    "m_Aσμ": [3, 3, 3, 3, 3],
    "σ_Aσμ": [1, 1, 1, 1, 1],

    # Final std of days around mean value across fish (for each stage) 
    "m_σμ_inf": [0.5, 0.5, 0.5, 0.5, 0.5],
    "σ_σμ_inf": [0.5, 0.5, 0.5, 0.5, 0.5],
    
    # Mode of change in std of trials around day across fish (for each stage)
    "m_Amσ": [20, 10, 10, 10, 5],
    "m_Amσ": [10, 10, 10, 10, 10],

    # Time constant of variance reduction across fish (for each stage)
    "mode_τσ": [3, 3, 3, 3, 3],
    "σ_τσ": [1.5, 1.5, 1.5, 1.5, 1.5],
    
    # Mode of final std of trials around day across fish (for each stage)
    "μ_mσ_inf": [5, 5, 5, 5, 5],
    "σ_mσ_inf": [3, 3, 3, 3, 3],

    # Std of final std of trials around day across fish (for each stage)
    "m_Aσσ": [2, 2, 2, 2, 2],
    "σ_Aσσ": [1, 1, 1, 1, 1],
}

sim_df = pd.DataFrame(index=stage_list, columns=sim.keys())
for k,v in sim.items():
    sim_df[k] = v

In [10]:
sim_df

Unnamed: 0,μ_Aμμ,σ_Aμμ,mode_τμ,σ_τμ,μ_μμ_inf,σ_μμ_inf,m_Aσμ,σ_Aσμ,m_σμ_inf,σ_σμ_inf,m_Amσ,mode_τσ,σ_τσ,μ_mσ_inf,σ_mσ_inf,m_Aσσ,σ_Aσσ
baseline,-1,5,4,1.5,0,1,3,1,0.5,0.5,10,3,1.5,5,3,2,1
rotation,5,5,4,1.5,0,1,3,1,0.5,0.5,10,3,1.5,5,3,2,1
washout,-1,5,4,1.5,0,1,3,1,0.5,0.5,10,3,1.5,5,3,2,1
savings,5,5,4,1.5,0,1,3,1,0.5,0.5,10,3,1.5,5,3,2,1
washout 2,-1,5,4,1.5,0,1,3,1,0.5,0.5,10,3,1.5,5,3,2,1


In [27]:
fish_list = [4, 6, 7, 8, 10]
days = {'baseline': 27, 'rotation': 16, 'washout': 10, 'savings': 16, 'washout 2': 10}

# Initialize an empty list to store the data
data = []

# Generate data for each fish, epoch, day, and trial
for f in fish_list:
    for stage in stage_list:
        sim_s = sim_df[stage]
        
        # Parameters for mean of days
        A_mumu = np.random.normal(loc=sim_s["μ_Aμμ"], scale=sim_s["σ_Aμμ"])
        tau_sh, tau_ra = gamma_shra_from_modesd(sim_s["mode_τμ"], sim_s["σ_τμ"])
        tau_mu = np.random.gamma(shape=tau_sh, scale=1/tau_ra)
        mumu_inf = np.random.normal(loc=sum_s["μ_μμ_inf"], scale=sim_s["σ_μμ_inf"])        
        
        # Parameters for std of days
        A_sh, A_ra = gamma_shra_from_modesd(sim_s["m_Aσμ"], sim_s["σ_Aσμ"])
        A_sdmu = np.random.gamma(shape=A_sh, scale=1/A_ra)
        sdmu_inf_sh, sdmu_inf_ra = gamma_shra_from_modesd(sim_s["m_σμ_inf"], sim_s["σ_σμ_inf"])
        sdmu_inf = np.random.gamma(shape=sdmu_inf_sh, sdmu_inf_ra)
        
        # Parameters for std of trials within day
        A_


        for day in range(first_day_value, days[stage] + first_day_value):
            num_trials = np.random.randint(4, 8)  # Randomly choose number of trials per day
            day_mu_mu = A_mumu*np.exp(-day/tau_mu) + mumu_inf
            day_mu_sd = A_sdmu*np.exp(-day/tau_mu) + sdmu_inf
            
            day_sd_mu = A_sdmu*np.exp(-day/tau_mu) + sdmu_inf
            
            day_mu = np.random.normal(loc=day_mu_mu, scale=day_mu_sd)
            day_sd = np.random.gamma()
            for trial in range(first_trial_value, num_trials + first_trial_value):
                yds = np.random.normal(loc=day_mu, scale=30)  # Generate random yds
                data.append([f, stage, day, trial, yds])

# Create a pandas DataFrame
df = pd.DataFrame(data, columns=['fish', 'stage', 'day', 'trial', 'yds'])

num_data = df.shape[0]

In [28]:
df.head()

Unnamed: 0,fish,stage,day,trial,yds
0,4,baseline,1,1,24.598339
1,4,baseline,1,2,11.392468
2,4,baseline,1,3,-12.660974
3,4,baseline,1,4,-23.122066
4,4,baseline,1,5,-11.186687


In [29]:
num_data

2176

## Define model

### Priors

#### μ_μ hyper priors

In [31]:
# Priors for Aμμ
df_avg = df.groupby(['fish', 'stage', 'day']).agg({'yds': 'mean'}).reset_index()

mean_first_day = df_avg[df_avg['day'] == first_day_value]
max_day_indices = df_avg.groupby(['fish', 'stage'])['day'].idxmax()
mean_last_day = df_avg.loc[max_day_indices]

merged_data = pd.merge(mean_first_day, mean_last_day, on=['fish', 'stage'], suffixes=('_first', '_last'))
merged_data['difference'] = merged_data['yds_first'] - merged_data['yds_last']

μ_Aμμ = merged_data['difference'].mean()
σ_Aμμ = merged_data['difference'].std()

# Priors for τ_μμ
merged_data['halfway_yds'] = (merged_data['yds_first'] + merged_data['yds_last']) / 2
df_avg = pd.merge(df_avg, merged_data[['fish', 'stage', 'halfway_yds']], on=['fish', 'stage'])

df_avg_below_halfway = df_avg[df_avg['yds'] < df_avg['halfway_yds']].groupby(['fish', 'stage']).agg({'day': 'first'}).reset_index()

mode_τμ = df_avg_below_halfway['day'].mean()
σ_τμ = df_avg_below_halfway['day'].std()

sh_τμ,ra_τμ = gamma_shra_from_modesd(mode_τμ, σ_τμ)

# Priors for μ_μ∞
μ_μμ_inf = mean_last_day['yds'].mean()
σ_μμ_inf = mean_last_day['yds'].std()


In [32]:
print(f'{μ_Aμμ=}')
print(f'{σ_Aμμ=}')

print(f'{mode_τμ=}')
print(f'{σ_τμ=}')

print(f'{μ_μμ_inf=}')
print(f'{σ_μμ_inf=}')


μ_Aμμ=-6.501537612418634
σ_Aμμ=24.403456999725552
mode_τμ=1.84
σ_τμ=2.1150256105620406
μ_μμ_inf=3.451418788450218
σ_μμ_inf=17.506794941415126


#### σ_μ hyper priors

In [33]:
# Priors for Aσμ
σ_Aσμ = merged_data['difference'].std()

# Priors for σ_μ∞
σ_σμ_inf = mean_last_day['yds'].std()


In [34]:
print(f'{σ_Aσμ=}')

print(f'{μ_τμ=}')
print(f'{σ_τμ=}')

print(f'{σ_σμ_inf=}')

σ_Aσμ=24.403456999725552
μ_τμ=2.36
σ_τμ=2.1150256105620406
σ_σμ_inf=17.506794941415126


#### μ_σ hyper priors

In [35]:
# Priors for Aμσ
df_std = df.groupby(['fish', 'stage', 'day']).agg({'yds': 'std'}).reset_index()
df_std.rename(columns={'yds': 'std_yds'}, inplace=True)

std_first_day = df_std[df_std['day'] == first_day_value]
last_day_indices = df_std.groupby(['fish', 'stage'])['day'].idxmax()
std_last_day = df_std.loc[last_day_indices]

merged_std = pd.merge(std_first_day, std_last_day, on=['fish', 'stage'], suffixes=('_first', '_last'))
merged_std['std_diff'] = merged_std['std_yds_last'] - merged_std['std_yds_first']


σ_Aμσ = merged_std['std_diff'].mean()

# Priors for τ_μσ
merged_std['halfway_std'] = (merged_std['std_yds_first'] + merged_std['std_yds_last']) / 2
df_std = pd.merge(df_std, merged_std[['fish', 'stage', 'halfway_std']], on=['fish', 'stage'])
std_less_than_halfway = df_std[df_std['std_yds'] < df_std['halfway_std']].groupby(['fish', 'stage']).agg({'day': 'min'}).reset_index()

mode_τσ = std_less_than_halfway['day'].mean()
σ_τσ = std_less_than_halfway['day'].std()

sh_τσ,ra_τσ = gamma_shra_from_modesd(mode_τσ, σ_τσ)

# Priors for μ_σ∞
σ_μσ_inf = std_last_day['std_yds'].mean()


In [36]:
print(f'{σ_Aμσ=}')

print(f'{mode_τσ=}')
print(f'{σ_τσ=}')

print(f'{σ_μσ_inf=}')

μ_Aμσ=5.476512591557492
σ_Aμσ=14.083089834923953
mode_τσ=2.92
σ_τσ=5.275098735252893
μ_μσ_inf=30.609773171397105
σ_μσ_inf=10.617711207970974


#### σ_σ hyper priors

In [37]:
# Priors for Aσσ
σ_Aσσ = merged_std['std_diff'].std()

# Priors for σ_μ∞
σ_σσ_inf = std_last_day['std_yds'].std()

In [38]:
print(f'{σ_Aσσ=}')

print(f'{σ_σσ_inf=}')

σ_Aσσ=14.083089834923953
σ_σσ_inf=10.617711207970974


#### Prior for student's t distribution degrees of freedom

In [39]:
μ_ν = 10

### PyMC model code

#### Make codings for fish and stage

In [40]:
fish_index_map = {fish: index for index, fish in enumerate(fish_list)}
stage_index_map = {stage: index for index, stage in enumerate(stage_list)}

# Replace fish numbers with their corresponding indices
df['fish_index'] = df['fish'].map(fish_index_map)
df['stage_index'] = df['stage'].map(stage_index_map)
df['day_index'] = df['day'] - first_day_value

#### Model code

In [41]:
coords = {
    "fish": fish_list,
    "stages": stage_list,
    "data": np.arange(num_data),
}

with pm.Model(coords=coords) as m_yds:
    ### Constants
    fish_ = pm.ConstantData('fish_', df['fish_index'], dims='data')
    stage_ = pm.ConstantData('stage_', df['stage_index'], dims='data')
    day_ = pm.ConstantData('day_', df['day_index'], dims='data')

    ### level 4: Priors for μ_μ, μ_σ, σ_μ, σ_σ
    A_μμ = pm.Normal('A_μμ', mu=μ_Aμμ, sigma=σ_Aμμ, dims=('fish', 'stages'))
    τ_μ = pm.Gamma('τ_μ', alpha=sh_τμ, beta=ra_τμ, dims=('fish', 'stages'))
    μ_μ_inf = pm.Normal('μ_μ_inf', mu=μ_μμ_inf, sigma=σ_μμ_inf, dims=('fish', 'stages'))

    A_σμ = pm.HalfNormal('A_σμ', sigma=σ_Aσμ, dims=('fish', 'stages'))
    σ_μ_inf = pm.HalfNormal('σ_μ_inf', sigma=σ_σμ_inf, dims=('fish', 'stages'))

    A_mσ = pm.HalfNormal('A_mσ', sigma=σ_Amσ, dims=('fish', 'stages'))
    τ_σ = pm.Gamma('τ_σ', alpha=sh_τσ, beta=ra_τσ, dims=('fish', 'stages'))
    m_σ_inf = pm.HalfNormal('m_σ_inf', sigma=σ_mσ_inf, dims=('fish', 'stages'))

    A_σσ = pm.HalfNormal('A_σσ', sigma=σ_Aμμ, dims=('fish', 'stages'))
    σ_σ_inf = pm.HalfNormal('σ_σ_inf', sigma=σ_σσ_inf, dims=('fish', 'stages'))

    ### level 3: Regression equations for μ and σ
    μ_μ = pm.Deterministic('μ_μ', A_μμ[fish_,stage_]*pm.math.exp(-day_ / τ_μ[fish_,stage_]) + μ_μ_inf[fish_,stage_], dims='data' )
    σ_μ = pm.Deterministic('σ_μ', A_σμ[fish_,stage_]*pm.math.exp(-day_ / τ_μ[fish_,stage_]) + σ_μ_inf[fish_,stage_], dims='data')
    mode_σ = pm.Deterministic('mode_σ', A_mσ[fish_,stage_]*pm.math.exp(-day_ / τ_σ[fish_,stage_]) + m_σ_inf[fish_,stage_], dims='data')
    σ_σ = pm.Deterministic('σ_σ', A_σσ[fish_,stage_]*pm.math.exp(-day_ / τ_σ[fish_,stage_]) + σ_σ_inf[fish_,stage_], dims='data')
    sh_σ,ra_σ = gamma_shra_from_modesd(mode_σ,σ_σ)

    ### level 2: Priors for the likelihood
    μ = pm.Normal('μ', mu=μ_μ, sigma=σ_μ, dims='data')    
    σ = pm.Gamma('σ', alpha=sh_σ, beta=ra_σ, dims='data')
    ν = pm.Exponential('ν', lam=1/μ_ν)
    
    ### data likelihood level 1
    y = pm.StudentT('y', mu=μ, sigma=σ, nu=ν, observed=df['yds'], dims='data')


## Sample prior predictive

In [42]:
id_yds = pm.sample_prior_predictive(samples=100, model=m_yds)

Sampling: [A_μμ, A_μσ, A_σμ, A_σσ, y, μ, μ_μ_inf, μ_σ_inf, ν, σ, σ_μ_inf, σ_σ_inf, τ_μ, τ_σ]


In [46]:
id_yds

In [43]:
_, ax = plt.subplots(len(fish_list), len(stage_list))

for f in fish_list:
    for s in stage_list:
        x = xr.DataArray(np.linspace(, 2, 50), dims=["plot_dim"])
        prior = idata.prior
        y = prior["a"] + prior["b"] * x

        ax.plot(x, y.stack(sample=("chain", "draw")), c="k", alpha=0.4)

        ax.set_xlabel("Predictor (stdz)")
        ax.set_ylabel("Mean Outcome (stdz)")
        ax.set_title("Prior predictive checks -- Flat priors");

In [None]:
%load_ext watermark
%watermark -v -iv -w