In [1]:
import sys
from pathlib import Path
import os
sys.path.append(str(Path(os.getcwd()).parent))

from typing import NewType, List, Dict, Tuple, Optional, Generic, TypeVar, Union, Any
import numpy as np
import pandas as pd
import xarray as xr
from active_inference.types.parameter import Likelihood, Transition
from active_inference.types.variable import Observation, Action, State
from active_inference.interface_adapters.repository import create_dataset
from simulation.init_parameters import InitA, InitB_f, InitB_nf, InitC, InitD, InitG, InitS, InitO, InitU, Initu, Inite
from simulation.configs import TimeSteps

# Parameters
#================================================
# Every quantity below is built the same way: an `Init*` object supplies
# `data`, `coords` (and usually `dims`), a domain type's `.create(...)` wraps
# them, and `.to_xarray()` converts the result to an `xr.DataArray`.
# The arrays are collected into `total_dataset` at the end of this cell.

# Likelihood (A)
#-------------------------------------------------
init_A: InitA = InitA()
likelihood: Likelihood = Likelihood.create(init_A.data, init_A.coords, init_A.dims)
A: xr.DataArray = likelihood.to_xarray()

# Transition probability (B)
#-------------------------------------------------
# - B_freezing: transition used under the 'freezing' action
init_B_f: InitB_f = InitB_f()
transition_f: Transition = Transition.create(init_B_f.data, init_B_f.coords, init_B_f.dims)
B_f: xr.DataArray = transition_f.to_xarray()

# - B_non_freezing: transition used under the 'non_freezing' action
init_B_nf: InitB_nf = InitB_nf()
transition_nf: Transition = Transition.create(init_B_nf.data, init_B_nf.coords, init_B_nf.dims)
B_nf: xr.DataArray = transition_nf.to_xarray()

# - B: both transitions stacked along a new 'action' dimension
#   (index 0 = 'freezing', index 1 = 'non_freezing').
#   TODO: create a dedicated class for the action-indexed transition.
B: xr.DataArray = xr.concat(
    [B_f, B_nf],
    dim = pd.Index(['freezing', 'non_freezing'], name='action')
)

# Prior for Observations (C)
#-------------------------------------------------
init_C: InitC = InitC()
prior_observation: Observation = Observation.create(init_C.data, init_C.coords, init_C.dims)
C: xr.DataArray = prior_observation.to_xarray()

# Initial state probability (D)
#-------------------------------------------------
# Note: State.create is called with (data, coords) only, without dims.
init_D: InitD = InitD()
initial_state: State = State.create(init_D.data, init_D.coords)
D: xr.DataArray = initial_state.to_xarray()

# Expected Free Energy (G)
#-------------------------------------------------
# Stored with the Action type (presumably one value per action — confirm).
init_G: InitG = InitG()
expected_free_energy: Action = Action.create(init_G.data, init_G.coords, init_G.dims)
G: xr.DataArray = expected_free_energy.to_xarray()

# Variables
#================================================
# State (S)
#-------------------------------------------------
init_S: InitS = InitS()
state: State = State.create(init_S.data, init_S.coords)
S: xr.DataArray = state.to_xarray()

# Observation (O)
#-------------------------------------------------
init_O: InitO = InitO()
observation: Observation = Observation.create(init_O.data, init_O.coords, init_O.dims)
O: xr.DataArray = observation.to_xarray()

# Action (U)
#-------------------------------------------------
# - Probability of action (U)
init_U: InitU = InitU()
action: Action = Action.create(init_U.data, init_U.coords, init_U.dims)
U: xr.DataArray = action.to_xarray()

# - Index vector of action (u)
init_u: Initu = Initu()
action_index: Action = Action.create(init_u.data, init_u.coords, init_u.dims)  # TODO: create a dedicated class for vectors
u: xr.DataArray = action_index.to_xarray()


# State prediction error (e), stored with the State type
#-------------------------------------------------
init_e: Inite = Inite()
state_prediction_error: State = State.create(init_e.data, init_e.coords)
e: xr.DataArray = state_prediction_error.to_xarray()

# Collect every parameter (A, B, C, D, G) and variable (S, O, U, u, e)
# into a single Dataset spanning all trials and blocks
#================================================
data_vars: Dict[str, xr.DataArray] = dict(
    A=A, B=B, C=C, D=D, G=G,  # parameters
    S=S, O=O, U=U, u=u, e=e,  # variables
)

total_dataset: xr.Dataset = create_dataset(
    data_vars,
    TimeSteps.trials,
    TimeSteps.blocks,
)
total_dataset

In [2]:
from scipy.stats import dirichlet

def optimal_bayes(
    block_index: int,
    practice_blocks: int = 10,
) -> float:
    """Return the LSS (Learning Stability Score) under optimal Bayesian counting.

    With P(sound|shock) = P(sound) = 1 the posterior reduces to the fraction
    of shock-paired blocks among all blocks seen so far::

        P(shock|sound) = P(shock) * P(sound|shock) / P(sound)
                       = Nl/NT * 1 / 1 = Nl/NT

    where Nl = ``practice_blocks`` and NT = ``practice_blocks + block_index``.

    Args:
        block_index: Index of the current block.
        practice_blocks: Number of learning (practice) blocks, all shock-paired.

    Returns:
        float: The updated LSS value.
    """
    return practice_blocks / (practice_blocks + block_index)

def variational_bayes(
    block_index: int,
    base_lss: np.ndarray = np.array([0.9, 0.1]),
    memory: int = 1,
    sensitivity: float = 0.15,
    random_state: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
) -> float:
    """Sample an LSS (Learning Stability Score) value by variational Bayes.

    The Dirichlet concentration grows linearly with the number of blocks seen
    so far, in its second component only::

        alpha = base_lss + sensitivity * [0, 1] * (block_index + 1)

    One sample is drawn from Dir(alpha), and its first component is returned
    as P(shock|sound). Because only the second component grows, the expected
    value base_lss[0] / sum(alpha) decreases as ``block_index`` increases.

    Args:
        block_index: Index of the current block (0-based).
        base_lss: Base Dirichlet concentration; all entries must be > 0.
            Any array-like is accepted; it is not modified.
        memory: Number of past observations to consider. Currently unused;
            kept for interface compatibility.
        sensitivity: Learning sensitivity, i.e. the pseudo-count added to the
            second concentration component per elapsed block.
        random_state: Seed or NumPy random generator forwarded to
            ``scipy.stats.dirichlet.rvs`` for reproducible sampling. ``None``
            (default) draws from NumPy's global RNG, as before.

    Returns:
        float: The sampled LSS value in [0, 1].
    """
    # The evidence term Σ(Ot) is represented by the fixed direction [0, 1];
    # no observations are read here, only the number of elapsed blocks.
    sum_obs = np.array([0, 1])
    param = np.asarray(base_lss, dtype=float) + sensitivity * sum_obs * (block_index + 1)
    return float(dirichlet(param).rvs(random_state=random_state)[0, 0])

In [3]:
#from tqdm.notebook import trange
from simulation.generative_process import create_observations, generative_process
from active_inference.functions.learning import learning, learning_with_updates
from active_inference.functions.inference import perceptual_inference, active_inference, perceptual_inference_with_optgenetics, perceptual_inference_with_updates, active_inference_with_updates
from active_inference.functions.operators import ln
from simulation.configs import Optgenetics, TimeSteps

# Workflow
#-------------------------------------------------
def run_simulation_step(ds: xr.Dataset, time_index: int) -> Tuple[xr.Dataset, xr.Dataset, xr.Dataset]:
    """Simulate one trial of the generative model.

    Three stages run in order: perceptual inference, learning and planning
    (active inference). Each stage works on a fresh deep copy of the dataset
    produced by the previous one, so the caller's ``ds`` is never modified.

    Args:
        ds: Dataset for the current block (variables indexed by ``trial``).
        time_index: Trial being simulated.

    Returns:
        ``(ds_after_planning, step_variables, learned_params)``, where
        ``step_variables`` holds this trial's S, e, U, u and G, and
        ``learned_params`` is the Dataset returned by ``learning``.
    """
    at_trial = {"trial": time_index}

    # Stage 1: perceptual inference (state posterior and prediction error).
    perceived = ds.copy(deep=True)
    posterior_state, pred_error = perceptual_inference_with_optgenetics(perceived, time_index)
    perceived["S"].loc[at_trial] = posterior_state
    perceived["e"].loc[at_trial] = pred_error

    # Stage 2: learning. `learning` is asked for A and C, but only C is
    # written forward (into the next trial); carrying A forward is disabled.
    learned = perceived.copy(deep=True)
    learned_params: xr.Dataset = learning(learned, time_index, params=["A", "C"])
    is_last_trial = time_index >= TimeSteps.trials - 1
    if not is_last_trial:
        learned["C"].loc[{"trial": time_index + 1}] = learned_params.C

    # Stage 3: planning (policy, selected action and expected free energy).
    planned = learned.copy(deep=True)
    policy, chosen_action, efe = active_inference(planned, time_index)
    for var_name, value in (("U", policy), ("u", chosen_action), ("G", efe)):
        planned[var_name].loc[at_trial] = value

    step_variables: xr.Dataset = xr.Dataset(
        data_vars={
            "S": posterior_state,
            "e": pred_error,
            "U": policy,
            "u": chosen_action,
            "G": efe,
        }
    )
    return planned, step_variables, learned_params


def run_simulation_step2(ds: xr.Dataset, time_index: int) -> xr.Dataset:
    """Simulate one trial with the ``*_with_updates`` pipeline helpers.

    This is an alternative to ``run_simulation_step``: each helper returns the
    updated dataset directly, and only C is passed to learning (not A and C).

    Args:
        ds: Dataset for the current block.
        time_index: Trial being simulated.

    Returns:
        The dataset after perceptual inference, learning and planning.
    """
    print(f"\n=== Running simulation step for trial {time_index} ===")

    after_perception = perceptual_inference_with_updates(
        ds, time_index, method=perceptual_inference_with_optgenetics
    )
    after_learning = learning_with_updates(after_perception, time_index, params=["C"])
    return active_inference_with_updates(after_learning, time_index)


# Simulation
#================================================
# Dataset
#-------------------------------------------------
# Seed NumPy's global RNG so that Restart & Run All reproduces the same run.
# `variational_bayes` samples through scipy's default (global NumPy) RNG.
# NOTE(review): presumably `create_observations` / `generative_process` also
# draw from NumPy's global RNG — confirm, otherwise seed them explicitly too.
SEED = 0
np.random.seed(SEED)

result_ds: xr.Dataset = total_dataset.copy(deep=True)

# Simulation: Continuous blocks case
#-------------------------------------------------
for block in range(0, TimeSteps.blocks):
    print(f"\n=== Running simulation for block {block} ===")

    # filter dataset
    #-------------------------------------------------
    # Select the block first, then deep-copy only that slice. This yields the
    # same independent per-block dataset without copying every block of
    # result_ds on every iteration.
    ds: xr.Dataset = result_ds.sel(block=block).copy(deep=True)

    # create observation
    #-------------------------------------------------
    # LSS for this block. Alternatives: a fixed value (e.g. 0.9) or
    # `optimal_bayes(block)`.
    LSS = variational_bayes(block)
    observations = create_observations(LSS=LSS)
    ds.O.loc[dict(trial=slice(0, TimeSteps.trials))] = observations

    # load initial parameters
    #-------------------------------------------------
    # From the second block on, start from the C and D carried over from the
    # end of the previous block (carrying A over is currently disabled).
    if block > 0:
        ds.C.loc[dict(trial=0)] = next_c0
        ds.D.loc[dict(trial=0)] = next_d0

    # run simulation
    #-------------------------------------------------
    for trial in range(0, TimeSteps.trials):
        # generative process: produce the observation for this trial
        updated_observation = generative_process(ds, trial)
        ds.O.loc[dict(trial=trial)] = updated_observation

        # generative model: perception, learning and planning for this trial.
        # The returned dataset (new_ds) is not used; the per-trial results are
        # copied into `ds` below.
        new_ds, updated_variables, updated_params = run_simulation_step(ds, trial)

        # update variables
        ds.S.loc[dict(trial=trial)] = updated_variables.S
        ds.e.loc[dict(trial=trial)] = updated_variables.e
        ds.U.loc[dict(trial=trial)] = updated_variables.U
        ds.u.loc[dict(trial=trial)] = updated_variables.u
        ds.G.loc[dict(trial=trial)] = updated_variables.G
        # update parameters: C learned on this trial becomes the next trial's C
        if trial < TimeSteps.trials-1:
            ds.C.loc[dict(trial=trial+1)] = updated_params.C

    # update result dataset
    #-------------------------------------------------
    result_ds.loc[dict(block=block)] = ds

    # Also copy this block's results into the next block as its starting
    # point; the next iteration re-selects that block from result_ds.
    if block < TimeSteps.blocks-1:
        result_ds.loc[dict(block=block+1)] = ds

    # store last parameters
    #-------------------------------------------------
    # `trial`, `updated_variables` and `updated_params` still refer to the
    # last trial of this block (left over from the inner loop).
    next_c0: xr.DataArray = updated_params.C
    # Transition under the action chosen on the last trial, applied to the
    # last state posterior to form the next block's D.
    last_action_index: int = updated_variables.u.argmax().item()
    current_transition: xr.DataArray = ds.B.isel(trial = trial).isel(action = last_action_index)
    current_state: xr.DataArray = updated_variables.S
    # NOTE(review): D is described as an initial state *probability*, but this
    # is an unnormalised log-space message ln(B) · s. `np.dot` also uses the
    # DataArrays' positional axis order and returns a plain ndarray. Confirm
    # that this is what perceptual inference expects for D.
    next_d0: np.ndarray = np.dot(ln(current_transition), current_state)

result_ds






=== Running simulation for block 0 ===
activaton =  <xarray.DataArray 'S' (state: 2)>
array([1.0182297, 2.0182297])
Coordinates:
  * state    (state) <U6 'fear' 'relief'
    trial    int64 0
    block    int64 0
    time     int64 0
activaton =  <xarray.DataArray 'S' (state: 2)>
array([-1.1646331 ,  0.77736601])
Coordinates:
  * state    (state) <U6 'fear' 'relief'
    trial    int64 1
    block    int64 0
    time     int64 1
activaton =  <xarray.DataArray 'S' (state: 2)>
array([-8.5952051 ,  0.63725117])
Coordinates:
  * state    (state) <U6 'fear' 'relief'
    trial    int64 2
    block    int64 0
    time     int64 2
activaton =  <xarray.DataArray 'S' (state: 2)>
array([-14.64122604,  -0.25272495])
Coordinates:
  * state    (state) <U6 'fear' 'relief'
    trial    int64 3
    block    int64 0
    time     int64 3
activaton =  <xarray.DataArray 'S' (state: 2)>
array([-14.12116063,  -0.4520842 ])
Coordinates:
  * state    (state) <U6 'fear' 'relief'
    trial    int64 4
    block   

# plot

In [4]:
import hvplot.xarray
import holoviews as hv
from holoviews import opts
from bokeh.io import output_notebook
output_notebook()

from active_inference.interface_adapters.plotter import probability_plot
from active_inference.interface_adapters.stastics import trial_mean, trial_stderr, block_mean

# dataset
#-------------------------------------------------
# Analysis windows (tune them here rather than in the individual plot cells).
N_SUMMARY_BLOCKS = 4              # blocks shown in the "Initial" / "Last" panels
CS_TRIALS = slice(80, 120)        # trial window analysed as "CS"
POST_CS_TRIALS = slice(120, 160)  # trial window analysed as "post-CS"

plot_ds = result_ds
init_ds = result_ds.isel(block=slice(0, N_SUMMARY_BLOCKS))
# Last N blocks, including the final simulated block. The end of the slice
# must be exclusive of nothing, hence `None` rather than `blocks - 1`.
last_ds = result_ds.isel(block=slice(-N_SUMMARY_BLOCKS, None))
cs_ds = result_ds.isel(trial=CS_TRIALS)
post_cs_ds = result_ds.isel(trial=POST_CS_TRIALS)

# Create the observation plots
#-------------------------------------------------
plot_O_init = probability_plot(init_ds.O).opts(title="Observation (Initial blocks)")
plot_O_last = probability_plot(last_ds.O).opts(title="Observation (Last blocks)")

# Arrange them in a single-column layout
hv.Layout([plot_O_init, plot_O_last]).cols(1)

In [5]:
# State posterior S over trials: first vs. last blocks (titled like the other panels).
plot_S_init = probability_plot(init_ds.S).opts(title="State (Initial blocks)")
plot_S_last = probability_plot(last_ds.S).opts(title="State (Last blocks)")
hv.Layout([plot_S_init, plot_S_last]).cols(1)


In [6]:
# State prediction error e over trials: first vs. last blocks.
plot_e_init, plot_e_last = (
    probability_plot(subset.e).opts(title=title)
    for subset, title in (
        (init_ds, "Prediction Error (Initial blocks)"),
        (last_ds, "Prediction Error (Last blocks)"),
    )
)
hv.Layout([plot_e_init, plot_e_last]).cols(1)

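# TODO: Is the time-series view needed at all?
# Possibly show only the freezing percentage.
# Add a reward-prediction-error model for comparison (≠, −, +).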

In [7]:
def mean_probability_plot(block_means: xr.DataArray, label: Optional[str] = None) -> hv.Overlay:
    """Plot block-averaged values as line(s).

    1-D input is drawn as a single line. For input with more dimensions, one
    line is drawn per value of the last dimension (``by=<last dim>``).

    Args:
        block_means: Block-averaged data to plot.
        label: Legend label. Defaults to ``block_means.name`` when it is set,
            otherwise hvplot's default. (Stringifying the last dimension's
            coordinate values, e.g. "[0 1 2 ...]", is not a usable label, and
            fails when that dimension has no coordinate.)

    Returns:
        The hvplot line plot (an overlay when split by the last dimension).
    """
    if label is None and block_means.name is not None:
        label = str(block_means.name)
    # Forward `label` only when there is one, so hvplot keeps its own default.
    line_kwargs: Dict[str, Any] = {} if label is None else {"label": label}

    if len(block_means.shape) > 1:
        # One line per value of the last dimension.
        return block_means.hvplot.line(by=block_means.dims[-1], **line_kwargs)
    return block_means.hvplot.line(**line_kwargs)


# Block-averaged prediction error summed over states ("marginal"): first vs. last blocks.
plot_marginal_e_init = mean_probability_plot(block_mean(init_ds.e.sum(dim='state'))).opts(title="Marginal Prediction Error (Initial blocks)")
plot_marginal_e_last = mean_probability_plot(block_mean(last_ds.e.sum(dim='state'))).opts(title="Marginal Prediction Error (Last blocks)")
hv.Layout([plot_marginal_e_init, plot_marginal_e_last]).cols(1)



In [8]:
# Sanity check: propagate a one-hot state through a 2x2 transition matrix.
# With state index 0 active, `transition @ state` returns column 0 of the matrix.
# Dedicated names keep the notebook-level `S` and `B` DataArrays (built in the
# first cell) from being overwritten by these scratch numpy arrays.
example_state = np.array([1.0, 0.0])
example_transition = np.array([[0.9, 0.9], [0.1, 0.1]])
predicted_state = example_transition @ example_state
print(predicted_state)



[0.9 0.1]


In [9]:
# Difference of the prediction error between the two states, block-averaged.
# `diff(dim='state')` gives second state minus first (relief - fear for the
# state order used in this notebook), so the sign is flipped to plot fear - relief.
plot_difference_e_init = mean_probability_plot(-block_mean(init_ds.e.diff(dim='state'))).opts(title="Difference Prediction Error (Initial blocks)")
plot_difference_e_last = mean_probability_plot(-block_mean(last_ds.e.diff(dim='state'))).opts(title="Difference Prediction Error (Last blocks)")
hv.Layout([plot_difference_e_init, plot_difference_e_last]).cols(1)

In [10]:

# Selected action u, smoothed with a 5-trial moving average: first vs. last blocks.
plot_u_init, plot_u_last = (
    probability_plot(subset.u.rolling(trial=5).mean()).opts(title=title)
    for subset, title in (
        (init_ds, "Freezing behavior (Initial blocks)"),
        (last_ds, "Freezing behavior (Last blocks)"),
    )
)
hv.Layout([plot_u_init, plot_u_last]).cols(1)

# TODO: Check the laser-stimulation conditions
#-------------------------------------------------

In [11]:
# Observation prior C (learned trial by trial): first vs. last blocks.
plot_C_init, plot_C_last = (
    probability_plot(subset.C).opts(title=title)
    for subset, title in (
        (init_ds, "Sensitivity (Initial blocks)"),
        (last_ds, "Sensitivity (Last blocks)"),
    )
)
hv.Layout([plot_C_init, plot_C_last]).cols(1)


In [12]:
# Expected free energy G: first vs. last blocks.
plot_G_init, plot_G_last = (
    probability_plot(subset.G).opts(title=title)
    for subset, title in (
        (init_ds, "Expected Free Energy (Initial blocks)"),
        (last_ds, "Expected Free Energy (Last blocks)"),
    )
)
hv.Layout([plot_G_init, plot_G_last]).cols(1)

In [13]:
# Action probabilities U (policy): first vs. last blocks.
plot_U_init, plot_U_last = (
    probability_plot(subset.U).opts(title=title)
    for subset, title in (
        (init_ds, "Freezing policy (Initial blocks)"),
        (last_ds, "Freezing policy (Last blocks)"),
    )
)
hv.Layout([plot_U_init, plot_U_last]).cols(1)

In [14]:
# Block averages of the freezing rate in the CS window
#-------------------------------------------------
# Per block: trial-wise mean (and standard error) of the freezing action.
# These are then averaged over consecutive block pairs. `rolling(block=2)` at
# index i averages blocks i-1 and i, and keeping odd indices gives the
# non-overlapping pairs (0, 1), (2, 3), ...
block_averaged = (cs_ds.u.sel(action='freezing')  # select the freezing action
                  .pipe(trial_mean)  # mean over trials
                  .rolling(block=2).mean()  # 2-block moving average
                  .isel(block=slice(1, None, 2))  # keep odd-indexed results only
)

block_stderrs = (cs_ds.u.sel(action='freezing')  # select the freezing action
                  .pipe(trial_stderr)  # standard error over trials
                  .rolling(block=2).mean()  # 2-block moving average
                  .isel(block=slice(1, None, 2))  # keep odd-indexed results only
)
# NOTE(review): averaging two per-block standard errors only approximates the
# standard error of the pooled block pair.

# Line plot of the pair means
plot = block_averaged.hvplot.line(
    xticks=np.arange(1, len(block_averaged.block)*2, 2),  # one tick per pair: 1, 3, 5, ...
    title="Mean of freezing rate (CS blocks)"
)

# ±1 standard-error band
time_steps = block_averaged.block.values
lower_bound = block_averaged - block_stderrs
upper_bound = block_averaged + block_stderrs

stderr_area = hv.Area(
    (time_steps, lower_bound, upper_bound),
    vdims=['lower_bound', 'upper_bound']
).opts(alpha=0.2, line_alpha=0)

# Overlay the mean line and the standard-error band
final_plot = plot * stderr_area

final_plot



In [15]:
# Prediction error conditioned on the action taken on the previous trial.
# Within a trial, perceptual inference (which produces e) runs before the
# action is chosen, so e at trial t is paired with u at trial t-1.
# The shift is applied on the full trial axis before restricting to the
# post-CS window. This way the first post-CS trial keeps the action from the
# trial just before it; shifting inside the window would leave it NaN.
shifted_ds = post_cs_ds.copy(deep=True)
shifted_u = result_ds.u.shift(trial=+1).sel(trial=post_cs_ds["trial"].values)
shifted_ds['u'] = shifted_u

# Per-trial prediction error summed over states. Sum *before* masking:
# `.where(...)` followed by `.sum(dim='state')` turns masked trials into 0
# (xarray's sum skips NaN), which biases the conditional means toward zero.
marginal_e = shifted_ds.e.sum(dim='state')

# Mean prediction error on trials that followed a freezing action
e_by_freezing = (
    marginal_e
    .where(shifted_ds.u.sel(action='freezing') == 1)
    .mean(dim=['trial'])
    .rolling(block=5).mean()  # 5-block moving average
)

# Mean prediction error on trials that followed a non-freezing action
marginal_e_by_not_freezing = (
    marginal_e
    .where(shifted_ds.u.sel(action='non_freezing') == 1)
    .mean(dim=['trial'])
    .rolling(block=5).mean()  # 5-block moving average
)

# One tick every 5 blocks, starting at block 4 (the first block with a full
# 5-block window), limited to the blocks that exist.
block_ticks = np.arange(4, e_by_freezing.sizes['block'], 5)

plot_e_by_freezing = e_by_freezing.hvplot.line(
    xticks=block_ticks,
    label="Freezing"
)
plot_e_by_not_freezing = marginal_e_by_not_freezing.hvplot.line(
    xticks=block_ticks,
    label="Not Freezing"
)

(plot_e_by_freezing * plot_e_by_not_freezing).opts(
    title="Marginal Prediction Error (post-CS blocks)",
    legend_position="right"
)

In [16]:
from active_inference.functions.operators import to_one_hot_vector

# Prediction error conditioned on the dominant inferred state of the same trial.
# `argmax(dim='state')` gives, per (block, trial), the index of the most
# probable state in the posterior S. Index 0 is the first 'state' coordinate
# ('fear' in the simulation output) and index 1 the second ('relief').
# Masking by the named 'state' dimension keeps each block's own state sequence
# (indexing the one-hot array positionally selected block 0 for all blocks).
dominant_state = post_cs_ds.S.argmax(dim='state')

# Sum over states *before* masking: `.where(...).sum(dim='state')` turns
# masked trials into 0 (skipna sum) and biases the conditional mean to zero.
marginal_e_post_cs = post_cs_ds.e.sum(dim='state')

# Mean prediction error on trials whose dominant state is the first state
e_by_fear_state = (
    marginal_e_post_cs
    .where(dominant_state == 0)
    .mean(dim=['trial'])
    .rolling(block=5).mean()  # 5-block moving average
)

# Mean prediction error on trials whose dominant state is the second state
e_by_relief_state = (
    marginal_e_post_cs
    .where(dominant_state == 1)
    .mean(dim=['trial'])
    .rolling(block=5).mean()  # 5-block moving average
)

# One tick every 5 blocks, starting at block 4 (the first full 5-block window).
state_block_ticks = np.arange(4, e_by_fear_state.sizes['block'], 5)

plot_e_by_fear_state = e_by_fear_state.hvplot.line(
    xticks=state_block_ticks,
    label="Fear state"
)
plot_e_by_relief_state = e_by_relief_state.hvplot.line(
    xticks=state_block_ticks,
    label="Relief state"
)

(plot_e_by_fear_state * plot_e_by_relief_state).opts(
    title="Marginal Prediction Error by inferred state (post-CS blocks)",
    legend_position="right"
)

In [17]:
from active_inference.interface_adapters.stastics import block_stderr


# NOTE(review): this redefines `mean_probability_plot` from an earlier cell;
# keep the two definitions identical, or define the function only once.
def mean_probability_plot(block_means: xr.DataArray, label: Optional[str] = None) -> hv.Overlay:
    """Plot block-averaged values as line(s).

    1-D input is drawn as a single line. For input with more dimensions, one
    line is drawn per value of the last dimension (``by=<last dim>``).

    Args:
        block_means: Block-averaged data to plot.
        label: Legend label. Defaults to ``block_means.name`` when it is set,
            otherwise hvplot's default. (Stringifying the last dimension's
            coordinate values, e.g. "[0 1 2 ...]", is not a usable label, and
            fails when that dimension has no coordinate.)

    Returns:
        The hvplot line plot (an overlay when split by the last dimension).
    """
    if label is None and block_means.name is not None:
        label = str(block_means.name)
    # Forward `label` only when there is one, so hvplot keeps its own default.
    line_kwargs: Dict[str, Any] = {} if label is None else {"label": label}

    if len(block_means.shape) > 1:
        # One line per value of the last dimension.
        return block_means.hvplot.line(by=block_means.dims[-1], **line_kwargs)
    return block_means.hvplot.line(**line_kwargs)

def mean_probability_plot_with_stderrs(block_means: xr.DataArray, block_stderrs: xr.DataArray) -> hv.Overlay:
    """Plot block means as line(s) with a shaded ±1 standard-error band.

    The x axis is the coordinate of the first dimension of ``block_means``.
    When the data has more than one dimension, one band is drawn per value of
    the last dimension, matching the lines from ``mean_probability_plot``.

    Args:
        block_means: Block-averaged values.
        block_stderrs: Standard errors with the same dimensions as ``block_means``.

    Returns:
        The line plot(s) overlaid with the error bands.
    """
    line_layer: hv.Overlay = mean_probability_plot(block_means)

    x_values: np.ndarray = block_means[block_means.dims[0]].values

    # Collect (mean, stderr) series: one per component, or the single series.
    if len(block_means.shape) > 1:
        component_dim = block_means.dims[-1]
        n_components = len(list(block_means.coords[component_dim].values))
        series = [
            (block_means.isel({component_dim: k}), block_stderrs.isel({component_dim: k}))
            for k in range(n_components)
        ]
    else:
        series = [(block_means, block_stderrs)]

    bands = [
        hv.Area(
            (x_values, mean - err, mean + err),  # x, lower_bound, upper_bound
            vdims=['lower_bound', 'upper_bound'],
        ).opts(alpha=0.2, line_alpha=0)
        for mean, err in series
    ]

    return line_layer * hv.Overlay(bands)


# Freezing during the CS window, smoothed with a 2-block moving average and
# then summarised across blocks with `block_mean` / `block_stderr`.
# NOTE(review): the first block is NaN after rolling(block=2); this relies on
# block_mean / block_stderr skipping NaN — confirm.
freezing_during_cs = cs_ds.u.rolling(block=2).mean()

block_mean_freezing = block_mean(freezing_during_cs.sel(action='freezing'))
block_stderr_freezing = block_stderr(freezing_during_cs.sel(action='freezing'))

# Mean line with a ±1 standard-error band (the cell's displayed output).
mean_probability_plot_with_stderrs(block_mean_freezing, block_stderr_freezing).opts(
    title="Freezing Action Over Trials"
)









In [18]:

def _fear_minus_relief_e(action_label: str) -> xr.DataArray:
    """Negated state difference of e on trials whose shifted action matches.

    ``shifted_ds.u`` holds the action shifted forward by one trial. Trials
    where that action equals ``action_label`` are kept. ``diff(dim='state')``
    gives second minus first state, and negating it yields first minus second
    (fear - relief for the state order used here). The result is averaged over
    trials, smoothed with a 5-block moving average, and keeps a length-1
    'state' dimension.
    """
    followed_action = shifted_ds.u.sel(action=action_label) == 1
    return -(shifted_ds.e
             .where(followed_action)
             .diff(dim='state')
             .mean(dim=['trial'])
             .rolling(block=5).mean())


diff_e_by_freezing = _fear_minus_relief_e('freezing')
# Same quantity on trials whose shifted action is non-freezing
diff_e_by_not_freezing = _fear_minus_relief_e('non_freezing')

diff_e_by_freezing