# Import and preparation

In [7]:
import requests
import pandas as pd
from io import StringIO
import gdown

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [11]:
import numpy as np
import pandas as pd
import plotnine as gg
gg.theme_set(gg.theme_classic)  # for nicer-looking plots
import jax.numpy as jnp
import jax
import optax
import scipy

In [12]:
# !pip install -U dm-haiku
import haiku as hk
rng_seq = hk.PRNGSequence(np.random.randint(2**32))

In [None]:
#@title Install required packages.
try:
    from google.colab import files
    _ON_COLAB = True
except:
    _ON_COLAB = False

if _ON_COLAB:
  !rm -rf CogModelingRNNsTutorial
  !git clone https://github.com/YifeiCAO/CogModelingRNNsTutorial
  !pip install -e CogModelingRNNsTutorial/CogModelingRNNsTutorial
  !cp CogModelingRNNsTutorial/CogModelingRNNsTutorial/*py CogModelingRNNsTutorial
else:
  !pip install CogModelingRNNsTutorial/requirements.txt

Cloning into 'CogModelingRNNsTutorial'...
remote: Enumerating objects: 1237, done.[K
remote: Counting objects: 100% (313/313), done.[K
remote: Compressing objects: 100% (128/128), done.[K
remote: Total 1237 (delta 252), reused 229 (delta 185), pack-reused 924 (from 2)[K
Receiving objects: 100% (1237/1237), 6.67 MiB | 10.31 MiB/s, done.
Resolving deltas: 100% (786/786), done.
Obtaining file:///content/CogModelingRNNsTutorial/CogModelingRNNsTutorial
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: CogModelingRNNsTutorial
  Running setup.py develop for CogModelingRNNsTutorial
Successfully installed CogModelingRNNsTutorial-0.0.0


In [None]:
#@title Imports + defaults settings.
%load_ext autoreload
%autoreload 2
# for reload
# %reload_ext autoreload

# import haiku as hk
# import jax
# import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax
import os
import warnings

# warnings.filterwarnings("ignore")

# try:
#     from google.colab import files
#     _ON_COLAB = True
# except:
#     _ON_COLAB = False

from CogModelingRNNsTutorial import bandits
from CogModelingRNNsTutorial import disrnn
from CogModelingRNNsTutorial import hybrnn
from CogModelingRNNsTutorial import hybconrnn
from CogModelingRNNsTutorial import hybrnn_direct_con
from CogModelingRNNsTutorial import plotting
from CogModelingRNNsTutorial import rat_data
from CogModelingRNNsTutorial import rnn_utils

In [14]:
osf_url = 'https://osf.io/download/xe6yu/'
response = requests.get(osf_url)

# Check the request
if response.status_code == 200:
    # Read as pandas dataframe
    qasim_data = pd.read_csv(StringIO(response.text))
    print('File downloaded and read successfully!')
else:
    print('Failed to download file. Status code:', response.status_code)

qasim_data.head()

# # read data
selected_columns = ['participant', 'trials_gamble', 'gamble', 'prob', 'reward']

qasim = qasim_data[selected_columns]
qasim_filtered = qasim[qasim['trials_gamble'].notna()]
qasim_sorted = qasim_filtered.groupby('participant', group_keys=False).apply(lambda x: x.sort_values('trials_gamble'))
qasim_sorted = qasim_sorted.reset_index(drop=True)
qasim_sorted['participant'] = qasim_sorted.groupby(['participant']).ngroup() + 1
qasim_sorted['action'] = qasim_sorted['gamble']
qasim_sorted

# —— 2) 把缺失的 action 填成 -1 —— #
qasim_sorted['action'] = qasim_sorted['action'].fillna(-1).astype(int)

# —— 3) 如果需要，把 reward 映射到 0/1 —— #
# （如果已经是 0/1 可跳过；否则取消下面注释并调整映射字典）
# qasim_sorted['reward'] = (
#     qasim_sorted['reward']
#     .map({-1: 0, 1: 1})
#     .fillna(-1)
#     .astype(int)
# )

# —— 4) 排序，确保 trial 顺序 —— #
qasim_sorted = qasim_sorted.sort_values(
    ['participant', 'trials_gamble']
).reset_index(drop=True)

# —— 5) 生成下一步动作 action_n —— #
qasim_sorted['action_n'] = (
    qasim_sorted
    .groupby('participant')['action']
    .shift(-1)
)
# 每个 participant 最后一 trial 的 action_n 设为 -1
last_idxs = qasim_sorted.groupby('participant').tail(1).index
qasim_sorted.loc[last_idxs, 'action_n'] = -1

# —— 6) 按 participant 构造 xs_list, ys_list —— #
xs_list, ys_list = [], []
for pid, grp in qasim_sorted.groupby('participant'):
    grp = grp.sort_values('trials_gamble')
    x = grp[['prob', 'reward']].to_numpy().astype(float)    # 输入特征
    y = grp[['action_n']].to_numpy().astype(int)             # 下一步动作
    xs_list.append(x)
    ys_list.append(y)

# —— 7) 堆成三维数组 —— #
xs_qa = np.stack(xs_list, axis=1)  # (n_sessions, n_trials, 2)
ys_qa = np.stack(ys_list, axis=1)  # (n_sessions, n_trials, 1)

print("xs.shape =", xs_qa.shape)
print("ys.shape =", ys_qa.shape)


File downloaded and read successfully!
xs.shape = (60, 206, 2)
ys.shape = (60, 206, 1)


### Sidarus dataset

In [7]:
# 修改后的 file_id
file_id = '1TSV6CdyClKz831qD2ln4z3WjLLcOgG1n'
download_url = f'https://drive.google.com/uc?id={file_id}'

# 下载并保存为 'downloaded_file.csv'
output_file = 'downloaded_file.csv'
gdown.download(download_url, output_file, quiet=False)

# 读取 CSV 数据
sida_data = pd.read_csv(output_file)
sida_data

# 1) action 从 {1,2} → {0,1}，并把所有缺失值填成 -1
sida_data['action'] = (
    sida_data['action']
    .map({1: 0, 2: 1})    # 映射
    .fillna(-1)           # 缺失值设为 -1
    .astype(int)
)

# 2) outcome 从 {-1,1} → {0,1}，缺失也填成 -1
sida_data['reward'] = (
    sida_data['outcome']
    .map({-1: 0, 1: 1})
    .fillna(-1)
    .astype(int)
)

# 3) 给每个被试×session 分配一个连续的 session_id
sida_data['session_id'] = (
    sida_data
    .groupby(['subj', 'session'], sort=False)
    .ngroup()
    + 1
)

# 4) 排序，保证 trial 顺序
sida_data = sida_data.sort_values(
    ['session_id', 'epN', 'epTrialN']
).reset_index(drop=True)

# 5) 生成下一步动作 action_n；每个 session 末尾设为 -1
sida_data['action_n'] = (
    sida_data
    .groupby('session_id')['action']
    .shift(-1)
)
last_idx = sida_data.groupby('session_id').tail(1).index
sida_data.loc[last_idx, 'action_n'] = -1

# 6) 按 session_id 抽取序列，并堆成 xs, ys
xs_list, ys_list = [], []
for sid, grp in sida_data.groupby('session_id'):
    grp = grp.sort_values(['epN', 'epTrialN'])
    x = grp[['hiRewAct', 'reward']].to_numpy().astype(float)  # 特征
    y = grp[['action_n']].to_numpy().astype(int)              # Label
    xs_list.append(x)
    ys_list.append(y)

# 最终结果：xs.shape == (n_sessions, n_trials, 2)，ys.shape == (n_sessions, n_trials, 1)
xs_sida = np.stack(xs_list, axis=1)
ys_sida = np.stack(ys_list, axis=1)

print("xs.shape =", xs_sida.shape)
print("ys.shape =", ys_sida.shape)


Downloading...
From: https://drive.google.com/uc?id=1TSV6CdyClKz831qD2ln4z3WjLLcOgG1n
To: /content/downloaded_file.csv
100%|██████████| 653k/653k [00:00<00:00, 8.87MB/s]


xs.shape = (800, 40, 2)
ys.shape = (800, 40, 1)


### Schaaf dataset

In [8]:
# 修改后的 file_id
file_id = '1rJFmDhCE3fdXtSHSSvtre_7V49gGsg7P'
download_url = f'https://drive.google.com/uc?id={file_id}'

# 下载并保存为 'downloaded_file.csv'
output_file = 'downloaded_file.csv'
gdown.download(download_url, output_file, quiet=False)

# 读取 CSV 数据
schaaf_data = pd.read_csv(output_file)
schaaf_data

df1 = schaaf_data[['pp', 'trial1', 'response1', 'outcome1']].copy()
df1.columns = ['pp', 'trial_in_session', 'response', 'reward']
df1['session'] = 1

df2 = schaaf_data[['pp', 'trial2', 'response2', 'outcome2']].copy()
df2.columns = ['pp', 'trial_in_session', 'response', 'reward']
df2['session'] = 2

# —— 上面拆分、concat、sort 的部分保持不变 —— #

# 合并成长表
df_long = pd.concat([df1, df2], ignore_index=True)
df_long = df_long.sort_values(['pp', 'session', 'trial_in_session']).reset_index(drop=True)

# —— 映射 outcome，再填 missing response —— #
# 把原来的 response1/response2 改名后是 `response`
# 把 outcome1/outcome2 改名后是 `reward`
# 把 outcome 映射成 1/0，然后把原本缺失的 reward 补成 -1
df_long['reward'] = (
    df_long['reward']
    .map({1: 1, -1: 0})
    .fillna(-1)
    .astype(int)
)

# 把缺失的 response 一样补成 -1
df_long['response'] = (
    df_long['response']
    .fillna(-1)
    .astype(int)
)


# 接下来再做 session_id、shift action_n、以及 stack xs/ys 的流程……
df_long['session_id'] = (df_long['pp'] - 1) * 2 + df_long['session']
df_long['action_n'] = df_long.groupby('session_id')['response'].shift(-1)
last_idx = df_long.groupby('session_id').tail(1).index
df_long.loc[last_idx, 'action_n'] = -1

# 生成 xs, ys
session_ids = df_long['session_id'].unique()
xs_list, ys_list = [], []
for sid in session_ids:
    sd = df_long[df_long['session_id'] == sid].sort_values('trial_in_session')
    x = sd[['response', 'reward']].to_numpy().astype(float)
    y = sd[['action_n']].to_numpy().astype(int)
    xs_list.append(x)
    ys_list.append(y)

xs = np.stack(xs_list, axis=0)   # (n_sessions, n_trials, 2)
ys = np.stack(ys_list, axis=0)   # (n_sessions, n_trials, 1)


# 或者第 0 维是 trials，第 1 维是 sessions：
xs_sch = np.stack(xs_list, axis=1)  # (n_trials, n_sessions, 2)
ys_sch = np.stack(ys_list, axis=1)  # (n_trials, n_sessions, 1)

print("xs.shape =", xs_sch.shape)
print("ys.shape =", ys_sch.shape)


Downloading...
From: https://drive.google.com/uc?id=1rJFmDhCE3fdXtSHSSvtre_7V49gGsg7P
To: /content/downloaded_file.csv
100%|██████████| 374k/374k [00:00<00:00, 6.17MB/s]


xs.shape = (249, 94, 2)
ys.shape = (249, 94, 1)


### Maria Dataset

In [9]:
import pandas as pd
import numpy as np
import gdown

# —— 1) 下载并读入原始数据 —— #
file_id = '1N_zAy-qrbfjvF8Kbb504IH2JNhR5KI-P'
url     = f'https://drive.google.com/uc?id={file_id}'
gdown.download(url, 'downloaded_file.csv', quiet=False)
eck_data = pd.read_csv('downloaded_file.csv')

# —— 2) 筛选＆重命名列 —— #
sel = ['sID','TrialID','selected_box','reward']
eck_sorted = eck_data[sel].copy()
eck_sorted['participant'] = eck_sorted.groupby('sID').ngroup()+1
eck_sorted['action']      = eck_sorted['selected_box']

# 只保留至少做够 120 试次的被试
max_trial = eck_sorted.groupby('participant')['TrialID'].transform('max')
eck_sorted = eck_sorted[max_trial>=120]

# 只用前 120 试次
eck_sorted = eck_sorted[eck_sorted['TrialID']<=120].reset_index(drop=True)

# —— 3) 生成下一步动作 action_n —— #
def generate_action_n(group):
    group = group.sort_values('TrialID')
    group['action_n'] = group['action'].shift(-1).fillna(-1).astype(int)
    return group

eck_sorted = (
    eck_sorted
    .groupby('participant', group_keys=False)
    .apply(generate_action_n)
    .reset_index(drop=True)
)

# —— 4) 按 participant 构造 xs_list/ys_list —— #
xs_list, ys_list = [], []
for pid, grp in eck_sorted.groupby('participant'):
    grp = grp.sort_values('TrialID').iloc[:120]
    x = grp[['action','reward']].to_numpy().astype(float)  # (120,2)
    y = grp[['action_n']].to_numpy().astype(int)           # (120,1)
    xs_list.append(x)
    ys_list.append(y)

# —— 5) stack 成 (n_sessions, n_trials, feat_dim) —— #
xs_ma = np.stack(xs_list, axis=1)  # (305,120,2)
ys_ma = np.stack(ys_list, axis=1)  # (305,120,1)

print("xs_ma.shape =", xs_ma.shape)
print("ys_ma.shape =", ys_ma.shape)


Downloading...
From: https://drive.google.com/uc?id=1N_zAy-qrbfjvF8Kbb504IH2JNhR5KI-P
To: /content/downloaded_file.csv
100%|██████████| 2.57M/2.57M [00:00<00:00, 20.2MB/s]


xs_ma.shape = (120, 305, 2)
ys_ma.shape = (120, 305, 1)


### Generate a big human dataset with all experiments, session length is 130 trial per session

In [10]:
import numpy as np

def segment_and_pad(x: np.ndarray,
                    y: np.ndarray,
                    seg_len: int = 130,
                    pad_x: float = 0.,
                    pad_y: int = -1):
    T, D = x.shape
    n_segs = int(np.ceil(T / seg_len))
    x_segs, y_segs = [], []
    for i in range(n_segs):
        start = i * seg_len
        end = start + seg_len
        x_part = x[start : min(end, T)]
        y_part = y[start : min(end, T)]
        pad = end - min(end, T)
        if pad > 0:
            x_part = np.pad(x_part,
                            pad_width=((0, pad), (0, 0)),
                            constant_values=pad_x)
            y_part = np.pad(y_part,
                            pad_width=((0, pad), (0, 0)),
                            constant_values=pad_y)
        x_segs.append(x_part)
        y_segs.append(y_part)
    return x_segs, y_segs

# —— 假设你已经有这几组 (xs, ys) —— #
# xs_qa   (60,  206, 2), ys_qa  (60,206, 1)
# xs_sida (800, 40,  2), ys_sida(800,40, 1)
# xs_sch  (249, 44,  2), ys_sch (249,44, 1)
# xs_ma   (120,305, 2), ys_ma  (120,305,1)

all_xs, all_ys = [], []

for xs, ys in [(xs_qa, ys_qa),
               (xs_sida, ys_sida),
               (xs_sch, ys_sch),
               (xs_ma, ys_ma)]:

    # 如果是 (n_trials, n_sessions, feat) 维度，就直接：
    # T, N, D = xs.shape
    # 否则若是 (n_sessions, n_trials, feat)，先转：
    # xs = xs.transpose(1,0,2)
    # ys = ys.transpose(1,0,2)

    T, N, D = xs.shape
    for sess in range(N):
        x_seq = xs[:, sess, :]  # (T, D)
        y_seq = ys[:, sess, :]  # (T, 1)
        x_segs, y_segs = segment_and_pad(
            x_seq, y_seq,
            seg_len=130,
            pad_x=0., pad_y=-1
        )
        all_xs.extend(x_segs)
        all_ys.extend(y_segs)

# —— 修改在这里：不要 np.stack，直接输出列表 —— #
# all_xs 是一个 Python list，长度 = 总片段数，每个元素 shape=(130, D)
# all_ys 是一个 Python list，长度 = 总片段数，每个元素 shape=(130, 1)

print(f"Total segments: {len(all_xs)}")
print(f"First xs segment shape: {all_xs[0].shape}")
print(f"First ys segment shape: {all_ys[0].shape}")

# 如果你需要把它们返回成变量：
xs_segment_list = all_xs
ys_segment_list = all_ys


Total segments: 979
First xs segment shape: (130, 2)
First ys segment shape: (130, 1)


In [11]:
import numpy as np

def format_into_datasets_multi_source(
    xs_list: list[np.ndarray],
    ys_list: list[np.ndarray],
    dataset_constructor,
    n_train_sessions: int,
    n_test_sessions: int,
    n_validate_sessions: int,
    batch_size: int = None,
    random_seed: int = None,
):
    """
    按照 QA、SIDA、SCH、MA 这 4 个来源的数据源比例，
    在它们各自内部抽取 train/test/val session，
    最后拼成全局的 DatasetRNN。

    xs_list, ys_list:
      长度 4 的 list，每个元素形状是 (timesteps, n_sessions_i, feat)
    n_*_sessions:
      全局希望 train/test/val 一共要多少 session
    """
    if random_seed is not None:
        rng = np.random.RandomState(random_seed)
    else:
        rng = np.random

    # 1) 计算每个来源各有多少 session
    sess_counts = np.array([xs.shape[1] for xs in xs_list])  # e.g. [206, 40, 44, 305]
    total_sessions = sess_counts.sum()

    # 2) 按比例分配到每个来源的 train/test/val 数目
    def proportional_alloc(total, counts):
        floats = counts / counts.sum() * total
        floors = np.floor(floats).astype(int)
        rem = total - floors.sum()
        # 剩余的按余数最大的那些来源补齐
        remainders = floats - floors
        for idx in np.argsort(remainders)[-rem:]:
            floors[idx] += 1
        return floors

    n_train_per = proportional_alloc(n_train_sessions,    sess_counts)
    n_test_per  = proportional_alloc(n_test_sessions,     sess_counts)
    n_val_per   = proportional_alloc(n_validate_sessions, sess_counts)

    # 3) 在每个来源内部随机打乱并切分
    train_idx_list, test_idx_list, val_idx_list = [], [], []
    for cnt, n_tr, n_te, n_va in zip(sess_counts,
                                     n_train_per,
                                     n_test_per,
                                     n_val_per):
        all_idx = np.arange(cnt)
        rng.shuffle(all_idx)
        train_idx_list.append(all_idx[:n_tr])
        test_idx_list.append( all_idx[n_tr:n_tr+n_te] )
        val_idx_list.append(  all_idx[n_tr+n_te:n_tr+n_te+n_va] )

    # 4) 汇总抽到的 sessions：concat 出全局 xs/ys
    def gather(xs_list, ys_list, idx_lists):
        parts_x, parts_y = [], []
        for xs, ys, idx in zip(xs_list, ys_list, idx_lists):
            # xs: (timesteps, n_sessions_i, feat)
            parts_x.append(xs[:, idx, :])
            parts_y.append(ys[:, idx, :])
        return np.concatenate(parts_x, axis=1), np.concatenate(parts_y, axis=1)

    xs_train, ys_train = gather(xs_list, ys_list, train_idx_list)
    xs_test,  ys_test  = gather(xs_list, ys_list, test_idx_list)
    xs_val,   ys_val   = gather(xs_list, ys_list, val_idx_list)

    # 5) 构造 DatasetRNN
    ds_train = dataset_constructor(xs_train, ys_train, batch_size=batch_size)
    ds_test  = dataset_constructor(xs_test,  ys_test,  batch_size=batch_size)
    ds_val   = dataset_constructor(xs_val,   ys_val,   batch_size=batch_size)

    return ds_train, ds_test, ds_val


In [12]:
# 假设你已经有原始的：
#   xs_qa   (T_qa,   N_qa,   D),   ys_qa   (T_qa,   N_qa,   1)
#   xs_sida (T_sida, N_sida, D),   ys_sida (T_sida, N_sida, 1)
#   xs_sch  (T_sch,  N_sch,  D),   ys_sch  (T_sch,  N_sch,  1)
#   xs_ma   (T_ma,   N_ma,   D),   ys_ma   (T_ma,   N_ma,   1)

def make_segmented_array(xs, ys, seg_len=130, pad_x=0., pad_y=-1):
    all_xs, all_ys = [], []
    T, N, D = xs.shape
    for sess in range(N):
        x_seq = xs[:, sess, :]    # (T, D)
        y_seq = ys[:, sess, :]    # (T, 1)
        x_segs, y_segs = segment_and_pad(x_seq, y_seq, seg_len, pad_x, pad_y)
        all_xs.extend(x_segs)     # list of (130, D)
        all_ys.extend(y_segs)     # list of (130, 1)
    # 把 list 再拼成一个三维 array (130, n_segments, D)
    xs_seg = np.stack(all_xs, axis=1)
    ys_seg = np.stack(all_ys, axis=1)
    return xs_seg, ys_seg

# 针对四个源分别做一次
xs_qa_seg,   ys_qa_seg   = make_segmented_array(xs_qa,   ys_qa)
xs_sida_seg,ys_sida_seg = make_segmented_array(xs_sida, ys_sida)
xs_sch_seg, ys_sch_seg  = make_segmented_array(xs_sch,  ys_sch)
xs_ma_seg,  ys_ma_seg   = make_segmented_array(xs_ma,   ys_ma)

# 然后再把它们送入多源拼分函数
dataset_train, dataset_test, dataset_validate = format_into_datasets_multi_source(
    xs_list   = [xs_qa_seg,   xs_sida_seg,   xs_sch_seg,   xs_ma_seg],
    ys_list   = [ys_qa_seg,   ys_sida_seg,   ys_sch_seg,   ys_ma_seg],
    dataset_constructor = rnn_utils.DatasetRNN,
    n_train_sessions   = 783,
    n_test_sessions    = 98,
    n_validate_sessions= 98,
    batch_size=64,
    random_seed=42,
)


In [13]:
def compute_log_likelihood(dataset, model_fun, params):
    xs, actual_choices = next(dataset)
    n_trials_per_session, n_sessions = actual_choices.shape[:2]
    model_outputs, model_states = rnn_utils.eval_model(model_fun, params, xs)

    # predicted log-probs for the first two actions
    predicted_log_choice_probabilities = np.array(
        jax.nn.log_softmax(model_outputs[:, :, :2], axis=-1)
    )

    n_actions = predicted_log_choice_probabilities.shape[2]
    log_likelihoods = []

    for sess_i in range(n_sessions):
        log_likelihood = 0.0
        n = 0
        for trial_i in range(n_trials_per_session):
            actual_choice = int(actual_choices[trial_i, sess_i])
            # ignore invalid trials (<0 or ≥n_actions)
            if 0 <= actual_choice < n_actions:
                log_likelihood += predicted_log_choice_probabilities[
                    trial_i, sess_i, actual_choice
                ]
                n += 1

        if n > 0:
            normalized_likelihood = np.exp(log_likelihood / n)
            log_likelihoods.append(normalized_likelihood)

    mean_likelihood = np.mean(log_likelihoods)
    std_likelihood  = np.std(log_likelihoods)

    print(f'Average Normalized Likelihood: {100 * mean_likelihood:.1f}%')
    return mean_likelihood, std_likelihood


In [14]:
from google.colab import drive
import pickle

# 挂载 Google Drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Fitting different models

# Fit Vanilla RNN

In [15]:
# 这个相当于是最好的认知模型，在rescolar里面加了一些参数
rl_pc_params, _ = rnn_utils.fit_model(
    model_fun=bandits.Hk_PreserveConAgentQ,
    dataset_train = dataset_train,
    dataset_test = dataset_validate,
    loss_fun='categorical',
    optimizer= optax.chain(
        optax.add_decayed_weights(1e-5),  # L2 正则化
        optax.adam(learning_rate=1e-4)  # Adam 优化器
    ),
    n_steps_per_call=200,
    n_steps_max=100000,
    #return_all_losses=True,
    early_stop_step=200,
    if_early_stop=True)

Step 200 of 200; Loss: 3.0358e+03; Test Loss: 1.9863e+03. (Time: 3.8s)
first loss: inf test loss new: 3113.832800292969
updating best model ..
Step 200 of 100000; Loss: 3.1138e+03. (Time: 0.0s)
Step 200 of 200; Loss: 9.3901e+02; Test Loss: 1.9848e+03. (Time: 2.3s)
first loss: 3113.832800292969 test loss new: 3110.567392578125
updating best model ..
Step 400 of 100000; Loss: 3.1106e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.2713e+03; Test Loss: 1.9834e+03. (Time: 2.3s)
first loss: 3110.567392578125 test loss new: 3107.3826098632812
updating best model ..
Step 600 of 100000; Loss: 3.1074e+03. (Time: 0.0s)
Step 200 of 200; Loss: 3.9798e+03; Test Loss: 1.9821e+03. (Time: 3.3s)
first loss: 3107.3826098632812 test loss new: 3104.287646484375
updating best model ..
Step 800 of 100000; Loss: 3.1043e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.5663e+03; Test Loss: 1.9809e+03. (Time: 2.5s)
first loss: 3104.287646484375 test loss new: 3101.2500610351562
updating best model ..
Step 1000 of 100000; L

KeyboardInterrupt: 

In [16]:
mean,std = compute_log_likelihood(dataset_test, bandits.Hk_PreserveConAgentQ, rl_pc_params)

Average Normalized Likelihood: 48.5%


In [None]:
#@title Set up the RNN (Vanilla RNN) Model
n_hidden = 8
def make_vanilla_rnn():
    model = hk.DeepRNN(
        [hk.VanillaRNN(n_hidden), hk.Linear(output_size=2)]
    )
    return model

#@title Set up the RNN (Vanilla RNN) Model
n_hidden = 8
def make_lstm():
    model = hk.DeepRNN(
        [hk.LSTM(n_hidden), hk.Linear(output_size=2)]
    )
    return model

In [None]:
#@title Fit the RNN (GRU) model
#n_steps_max = 1000000 #@param

gru_params, _, all_losses = rnn_utils.fit_model(
    model_fun=make_lstm,
    dataset_train = dataset_train,
    dataset_test = dataset_validate,
    loss_fun='categorical',
    optimizer= optax.chain(
        optax.add_decayed_weights(1e-5),  # L2 正则化
        optax.adam(learning_rate=1e-4)  # Adam 优化器
    ),
    n_steps_per_call=200,
    n_steps_max=100000,
    return_all_losses=True,
    early_stop_step=200,
    if_early_stop=True)

Step 200 of 200; Loss: 2.6275e+03; Test Loss: 2.5221e+03. (Time: 4.7s)updating best model ..
Step 200 of 100000; Loss: 3.0262e+03. (Time: 0.0s)
Step 200 of 200; Loss: 1.1956e+03; Test Loss: 2.5076e+03. (Time: 3.9s)updating best model ..
Step 400 of 100000; Loss: 3.0114e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.0060e+03; Test Loss: 2.4942e+03. (Time: 4.3s)updating best model ..
Step 600 of 100000; Loss: 2.9995e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.0375e+03; Test Loss: 2.4818e+03. (Time: 4.8s)updating best model ..
Step 800 of 100000; Loss: 2.9901e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.2896e+03; Test Loss: 2.4699e+03. (Time: 3.8s)updating best model ..
Step 1000 of 100000; Loss: 2.9821e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.0236e+03; Test Loss: 2.4578e+03. (Time: 3.9s)updating best model ..
Step 1200 of 100000; Loss: 2.9748e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.1469e+03; Test Loss: 2.4449e+03. (Time: 4.8s)updating best model ..
Step 1400 of 100000; Loss: 2.9673e+03. (T

Exception ignored in: <function _xla_gc_callback at 0x7d103f43cfe0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


Step 200 of 200; Loss: 3.3487e+03; Test Loss: 2.2146e+03. (Time: 4.0s)updating best model ..
Step 7200 of 100000; Loss: 2.8375e+03. (Time: 0.0s)
Step 200 of 200; Loss: 2.6303e+03; Test Loss: 2.2115e+03. (Time: 3.7s)updating best model ..
Step 7400 of 100000; Loss: 2.8362e+03. (Time: 0.0s)
Step 200 of 200; Loss: 1.7915e+03; Test Loss: 2.2091e+03. (Time: 4.1s)updating best model ..
Step 7600 of 100000; Loss: 2.8350e+03. (Time: 0.0s)
Step 200 of 200; Loss: 2.6021e+03; Test Loss: 2.2060e+03. (Time: 3.7s)updating best model ..
Step 7800 of 100000; Loss: 2.8340e+03. (Time: 0.0s)
Step 200 of 200; Loss: 2.1777e+03; Test Loss: 2.2033e+03. (Time: 3.7s)updating best model ..
Step 8000 of 100000; Loss: 2.8327e+03. (Time: 0.0s)
Step 200 of 200; Loss: 1.0539e+03; Test Loss: 2.1991e+03. (Time: 3.9s)updating best model ..
Step 8200 of 100000; Loss: 2.8297e+03. (Time: 0.0s)
Step 200 of 200; Loss: 5.1424e+03; Test Loss: 2.1942e+03. (Time: 3.6s)updating best model ..
Step 8400 of 100000; Loss: 2.8250e+03

KeyboardInterrupt: 

In [None]:
# #@title Compute log-likelihood
# def compute_log_likelihood(dataset, model_fun, params):

#   xs, actual_choices = next(dataset)
#   n_trials_per_session, n_sessions = actual_choices.shape[:2]
#   model_outputs, model_states = rnn_utils.eval_model(model_fun, params, xs)

#   predicted_log_choice_probabilities = np.array(jax.nn.log_softmax(model_outputs[:, :, :2]))

#   log_likelihood = 0
#   n = 0  # Total number of trials across sessions.
#   for sess_i in range(n_sessions):
#     for trial_i in range(n_trials_per_session):
#       actual_choice = int(actual_choices[trial_i, sess_i])
#       if actual_choice >= 0:  # values < 0 are invalid trials which we ignore.
#         log_likelihood += predicted_log_choice_probabilities[trial_i, sess_i, actual_choice]
#         n += 1

#   normalized_likelihood = np.exp(log_likelihood / n)

#   print(f'Normalized Likelihood: {100 * normalized_likelihood:.1f}%')

#   return normalized_likelihood

In [None]:
#@title Compute quality-of-fit: Held-out Normalized Likelihood
# Compute log-likelihood
print('Normalized Likelihoods for GRU')
print('Training Dataset')
training_likelihood = compute_log_likelihood(dataset_train, make_lstm, gru_params)
print('Held-Out Dataset')
testing_likelihood = compute_log_likelihood(dataset_validate, make_lstm, gru_params)

Normalized Likelihoods for GRU
Training Dataset
Average Normalized Likelihood: 61.3%
Held-Out Dataset
Average Normalized Likelihood: 62.0%
