# FOV model
## Revised on 2024-04-01

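April Fool's Day celebration code. This notebook builds a hierarchical Bayesian (PyMC) model in which the mean and spread of the per-trial `yds` measurement relax exponentially over training days, across fish and experimental stages.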

## Setup

In [6]:
import random
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as stats
import seaborn as sns
# Apply ArviZ's grayscale plotting style to all figures drawn after this cell.
az.style.use("arviz-grayscale")

### Load data

In [None]:
# Load the trial table. The path is resolved relative to the kernel's working
# directory. Presumably one row per trial; the columns used below are
# `fish_id` and `session_type`.
csv_path = 'tips.csv'
Allfishtrial = pd.read_csv(csv_path)

# Show the last rows as a quick sanity check of the parsed file.
Allfishtrial.tail()

In [None]:
# Integer-encode fish and session labels from the raw trial table.
# pd.Categorical maps any label NOT listed in `categories` (including NaN) to -1,
# so these lists must match the CSV values exactly.
fish_categories = ['fish3', 'fish7', 'fish9', 'fish10']
session_categories = ['Baseline', 'Rotation', 'Washout', 'ReRotation']

BRWR_fishid = pd.Categorical(Allfishtrial['fish_id'],
                             categories=fish_categories).codes
BRWR_sessiontype = pd.Categorical(Allfishtrial['session_type'],
                                  categories=session_categories).codes

# Surface unmatched labels instead of letting a silent -1 leak into model indexing.
for column, codes in (('fish_id', BRWR_fishid), ('session_type', BRWR_sessiontype)):
    n_unmatched = int((codes == -1).sum())
    if n_unmatched:
        warnings.warn(
            f"{n_unmatched} rows have a {column} value outside the expected categories (coded -1)"
        )

# Number of distinct fish actually present; the -1 "unmatched" code is not a fish.
groups_fishid = len(np.unique(BRWR_fishid[BRWR_fishid >= 0]))


Prior calculation. **TODO:** derive these values from the raw data; the cell below currently generates synthetic `yds` values as a placeholder.

In [3]:
# Synthetic placeholder data: one row per trial with a random `yds` value.
# A fixed seed makes Restart & Run All reproduce the same frame.
SEED = 20240401
rng = np.random.default_rng(SEED)

fish_list = [4, 6, 7, 8, 10]
stage_list = ['baseline', 'rotation', 'washout', 'savings', 'washout 2']
# Number of simulated days per stage; days are numbered from 1.
days = {'baseline': 27, 'rotation': 16, 'washout': 10, 'savings': 16, 'washout 2': 10}

# Collect rows in a list and build the DataFrame once (no row-by-row growth).
data = []

# Generate data for each fish, stage, day, and trial
for f in fish_list:
    for stage in stage_list:
        for day in range(1, days[stage] + 1):
            # 4-7 trials per day; the upper bound of integers() is exclusive.
            num_trials = rng.integers(4, 8)
            for trial in range(1, num_trials + 1):
                yds = rng.normal(loc=0, scale=30)
                data.append([f, stage, day, trial, yds])

df = pd.DataFrame(data, columns=['fish', 'stage', 'day', 'trial', 'yds'])

# Total number of trials; sizes the "data" coordinate of the model.
num_data = df.shape[0]

In [4]:
# Preview the first five simulated trials.
df.head(n=5)

Unnamed: 0,fish,stage,day,trial,yds
0,4,baseline,1,1,-13.1831
1,4,baseline,1,2,-16.475159
2,4,baseline,1,3,4.236955
3,4,baseline,1,4,-22.692728
4,4,baseline,1,5,-1.430846


In [5]:
# Total number of simulated trials (rows of `df`); this also sizes the model's "data" coordinate.
num_data

2166

## Define model

## Priors

### μ_μ hyper priors

In [None]:
# Priors for Aμμ: amplitude of the exponential curve for the mean,
# μ_μ(t) = A·exp(-t/τ) + μ_μ∞, estimated per (fish, stage) as
# first-day mean minus last-day mean.
df_avg = df.groupby(['fish', 'stage', 'day']).agg({'yds': 'mean'}).reset_index()

# Days are numbered from 1 (there is no day 0), so locate each (fish, stage)'s
# first and last day explicitly.
first_day_indices = df_avg.groupby(['fish', 'stage'])['day'].idxmin()
mean_first_day = df_avg.loc[first_day_indices]
max_day_indices = df_avg.groupby(['fish', 'stage'])['day'].idxmax()
mean_last_day = df_avg.loc[max_day_indices]

merged_data = pd.merge(mean_first_day, mean_last_day, on=['fish', 'stage'], suffixes=('_first', '_last'))
merged_data['difference'] = merged_data['yds_first'] - merged_data['yds_last']

μ_Aμμ = merged_data['difference'].mean()
σ_Aμμ = merged_data['difference'].std()


# Priors for τ_μμ: first day on which the daily mean has passed the halfway point
# between its first- and last-day values, per (fish, stage).
merged_data['halfway_yds'] = (merged_data['yds_first'] + merged_data['yds_last']) / 2
# Attach each group's halfway value to its daily rows; comparing the two frames
# directly would misalign on their unrelated row indices.
df_avg_halfway = df_avg.merge(
    merged_data[['fish', 'stage', 'difference', 'halfway_yds']], on=['fish', 'stage']
)
# "Passed" means being on the asymptote's side of halfway: below it for a
# decreasing curve (difference > 0), above it for an increasing one.
passed_halfway = (
    np.sign(df_avg_halfway['difference'])
    * (df_avg_halfway['yds'] - df_avg_halfway['halfway_yds'])
) < 0
df_avg_below_halfway = (
    df_avg_halfway[passed_halfway]
    .groupby(['fish', 'stage'])
    .agg({'day': 'min'})
    .reset_index()
)

μ_τμμ = df_avg_below_halfway['day'].mean()
σ_τμμ = df_avg_below_halfway['day'].std()


# Priors for μ_μ∞: asymptote of the mean curve, taken from the last-day means.
# (`∞` is not a valid Python identifier character, hence the `_inf` suffix.)
μ_μμ_inf = mean_last_day['yds'].mean()
σ_μμ_inf = mean_last_day['yds'].std()


### σ_μ hyper priors

In [None]:
# Priors for Aσμ: scale for the amplitude of the σ_μ curve, taken as the spread
# of the per-(fish, stage) first-minus-last mean differences.
σ_Aσμ = merged_data['difference'].std()


# Priors for τ_σμ
# Assume the timing matches the changes in the mean, so reuse the same priors.
μ_τσμ = μ_τμμ
σ_τσμ = σ_τμμ


# Priors for σ_μ∞: spread of the last-day means across (fish, stage).
# (`∞` is not a valid Python identifier character, hence the `_inf` suffix.)
σ_σμ_inf = mean_last_day['yds'].std()


### μ_σ hyper priors

In [None]:
# Priors for Aμσ: amplitude of the exponential curve for the within-day SD,
# μ_σ(t) = A·exp(-t/τ) + μ_σ∞, estimated per (fish, stage) as first-day SD
# minus last-day SD (the same first-minus-last convention as Aμμ).
df_std = (
    df.groupby(['fish', 'stage', 'day'])
    .agg({'yds': 'std'})
    .reset_index()
    .rename(columns={'yds': 'std_yds'})
)

# Days are numbered from 1 (there is no day 0), so locate each (fish, stage)'s
# first and last day explicitly.
first_day_indices_std = df_std.groupby(['fish', 'stage'])['day'].idxmin()
std_first_day = df_std.loc[first_day_indices_std]
last_day_indices = df_std.groupby(['fish', 'stage'])['day'].idxmax()
std_last_day = df_std.loc[last_day_indices]

merged_std = pd.merge(std_first_day, std_last_day, on=['fish', 'stage'], suffixes=('_first', '_last'))
merged_std['std_diff'] = merged_std['std_yds_first'] - merged_std['std_yds_last']

μ_Aμσ = merged_std['std_diff'].mean()
σ_Aμσ = merged_std['std_diff'].std()


# Priors for τ_μσ: first day on which the daily SD has passed the halfway point
# between its first- and last-day values, per (fish, stage).
merged_std['halfway_std'] = (merged_std['std_yds_first'] + merged_std['std_yds_last']) / 2
# The halfway value lives on merged_std, so attach it to the daily rows first.
df_std_halfway = df_std.merge(
    merged_std[['fish', 'stage', 'std_diff', 'halfway_std']], on=['fish', 'stage']
)
# "Passed" means being on the last-day side of halfway, whichever direction the SD moves.
passed_halfway_std = (
    np.sign(df_std_halfway['std_diff'])
    * (df_std_halfway['std_yds'] - df_std_halfway['halfway_std'])
) < 0
std_less_than_halfway = (
    df_std_halfway[passed_halfway_std]
    .groupby(['fish', 'stage'])
    .agg({'day': 'min'})
    .reset_index()
)

μ_τμσ = std_less_than_halfway['day'].mean()
σ_τμσ = std_less_than_halfway['day'].std()


# Priors for μ_σ∞: asymptote of the SD curve, from the last-day SDs only
# (not the fish/day columns of the frame).
# (`∞` is not a valid Python identifier character, hence the `_inf` suffix.)
μ_μσ_inf = std_last_day['std_yds'].mean()
σ_μσ_inf = std_last_day['std_yds'].std()

### σ_σ hyper priors

In [None]:
# Priors for Aσσ: scale for the amplitude of the σ_σ curve, taken as the spread
# of the per-(fish, stage) first-minus-last SD differences.
σ_Aσσ = merged_std['std_diff'].std()


# Priors for τ_σσ
# Assume the timing matches the changes in the SD, so reuse the τ_μσ priors.
# These get their own names; the τ_σμ priors defined earlier are left untouched.
μ_τσσ = μ_τμσ
σ_τσσ = σ_τμσ


# Priors for σ_σ∞: spread of the last-day SDs across (fish, stage).
# (`∞` is not a valid Python identifier character, hence the `_inf` suffix.)
σ_σσ_inf = std_last_day['std_yds'].std()

In [None]:
# NOTE(review): the "fish" and "stages" coords are not used yet. All hyperpriors
# below are shared scalars; per-(fish, stage) curves would need dims on level 4.
coords = {
    "fish": fish_list,
    "stages": stage_list,
    "data": np.arange(num_data),
}

# Per-observation inputs. The hyperpriors were derived from the simulated `df`,
# so the likelihood is evaluated on the same frame, row for row.
# TODO: switch to Allfishtrial once it is mapped onto the (fish, stage, day, yds) schema.
day_ = df['day'].to_numpy(dtype=float)
y_obs = df['yds'].to_numpy()

# Prior mean of the Student-t degrees of freedom ν.
# NOTE(review): 30 is a placeholder weakly-informative value; confirm.
μ_νy = 30.0

with pm.Model(coords=coords) as FovMV:
    day_t = pt.as_tensor_variable(day_)

### level 4: hyperpriors for the four exponential curves
    # Every curve has the form A * exp(-day / τ) + asymptote. Python identifiers
    # cannot contain `∞`, so asymptote variables use `_inf`; the PyMC variable
    # names keep the `∞` symbol. Every PyMC name must be unique within the model.
    A_μμ = pm.Normal('A_μμ', mu=μ_Aμμ, sigma=σ_Aμμ)
    τ_μμ = pm.Gamma('τ_μμ', mu=μ_τμμ, sigma=σ_τμμ)
    μ_μ_inf = pm.Normal('μ_μ∞', mu=μ_μμ_inf, sigma=σ_μμ_inf)

    A_σμ = pm.HalfNormal('A_σμ', sigma=σ_Aσμ)
    τ_σμ = pm.Gamma('τ_σμ', mu=μ_τσμ, sigma=σ_τσμ)
    σ_μ_inf = pm.HalfNormal('σ_μ∞', sigma=σ_σμ_inf)

    A_μσ = pm.Normal('A_μσ', mu=μ_Aμσ, sigma=σ_Aμσ)
    τ_μσ = pm.Gamma('τ_μσ', mu=μ_τμσ, sigma=σ_τμσ)
    μ_σ_inf = pm.Normal('μ_σ∞', mu=μ_μσ_inf, sigma=σ_μσ_inf)

    A_σσ = pm.HalfNormal('A_σσ', sigma=σ_Aσσ)
    τ_σσ = pm.Gamma('τ_σσ', mu=μ_τσσ, sigma=σ_τσσ)
    σ_σ_inf = pm.HalfNormal('σ_σ∞', sigma=σ_σσ_inf)


### level 3: regression equations for μ and σ, evaluated at each trial's day
    μ_μ = pm.Deterministic('μ_μ', A_μμ * pm.math.exp(-day_t / τ_μμ) + μ_μ_inf, dims='data')
    σ_μ = pm.Deterministic('σ_μ', A_σμ * pm.math.exp(-day_t / τ_σμ) + σ_μ_inf, dims='data')
    μ_σ = pm.Deterministic('μ_σ', A_μσ * pm.math.exp(-day_t / τ_μσ) + μ_σ_inf, dims='data')
    σ_σ = pm.Deterministic('σ_σ', A_σσ * pm.math.exp(-day_t / τ_σσ) + σ_σ_inf, dims='data')

### level 2: priors for the likelihood parameters
    μ = pm.Normal('μ', mu=μ_μ, sigma=σ_μ, dims='data')
    # NOTE(review): μ_σ is built from Normal priors and is not constrained to be
    # positive, but Gamma needs mu > 0. Consider a positive transform or a
    # HalfNormal asymptote if sampling hits invalid regions.
    σ = pm.Gamma('σ', mu=μ_σ, sigma=σ_σ, dims='data')
    ν = pm.Exponential('ν', lam=1 / μ_νy)

### data likelihood level 1
    y = pm.StudentT('y', mu=μ, sigma=σ, nu=ν, observed=y_obs, dims='data')


## Set up all the conditions

## Generate samples

In [None]:
# Record the environment for reproducibility: -v prints the Python/IPython versions,
# -iv prints the versions of imported packages, and -w prints the watermark extension's own version.
%load_ext watermark
%watermark -v -iv -w