<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Chunk-Model" data-toc-modified-id="Chunk-Model-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Chunk Model</a></span></li></ul></div>

In [194]:
# first, test generative model on one dimensional graphs
import sys
sys.path.append('../HCM')
from main import *
import seaborn as sns
import pandas as pd
from functools import partial

import operator
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

## Chunk Model

In [2]:
# demonstrate the graph structure as the number of observation increases

plt.figure(figsize = (10,10))

def plot_chunk_graph(vertex_list, edges, location, ax=None, Print = False):
    chunks = []
    for ck in vertex_list:
        chunks.append(list(np.ravel(tuple_to_arr(ck)).astype(int)))
    if Print:
        print('chunks ', chunks)
        print('edges ', edges)
        print('location ', location)
    location[0] = [1,8]
    element_palette = {0:'#FFFFFF', 1:'#FFCAB1', 2:'#C1D7AE', 3:'#B10FBD',4:'#6C809A'}  # colorblind palette
    chunk_element_offset = {1: [0], 2: [-0.1, 0.1], 3: [-0.2, 0, 0.2], 4: [-0.3, -0.1, 0.1, 0.3],5:[-0.4, -0.2,0, 0.2, 0.4],6:[-0.6,-0.4, -0.2, 0, 0.2,0.4],7:[-0.6,-0.4, -0.2,0, 0.2, 0.4,0.6],8:[-0.8,-0.6,-0.4, -0.2,0,0.2, 0.4,0.6]}


    chunk_element_distance = 0.9
    if ax is None:
        f, ax = plt.subplots()

    chunks_by_length = defaultdict(list)
    for chunk in chunks:
        chunks_by_length[len(chunk)].append(chunk)
    max_width = max(len(v) for v in chunks_by_length.values())

    node_pos = {}
    for i in range(0,len(chunks)):
        chunk = chunks[i]
        chunk_id = chunks.index(chunk)
        chunk_len = len(chunks[chunk_id])
        n_on_level = len(chunks_by_length[chunk_len])
        id_on_level = chunks_by_length[chunk_len].index(chunk)
        xpos = location[i][0]
        ypos = location[i][1]
        node_pos[chunk_id] = (xpos, ypos)

    G = nx.MultiDiGraph()
    for node in range(len(chunks)):
        G.add_node(node, pos=node_pos[node])

    for edge in edges:
        G.add_edge(edge[0], edge[1], weight=0)

    edge_color = ['k' if weight else 'grey' for weight in nx.get_edge_attributes(G,'weight').values()]

    nx.draw_networkx(G, node_pos,
                        node_color='grey',
                        node_size=0.0001,
                        #connectionstyle='arc3, rad = 0.1',
                        arrowsize=15,
                        width=3,
                        with_labels=False,
                        edge_color=edge_color,
                        alpha = 0.3,
                        ax=ax)

    for chunk_id, chunk in enumerate(chunks):

        for element_id, element in enumerate(chunk):
            ax.scatter([node_pos[chunk_id][0] + (chunk_element_offset[len(chunk)][element_id] * chunk_element_distance)],
                        [node_pos[chunk_id][1]],
                        c=element_palette[element % 5],
                        edgecolors = 'lightgrey',
                        linewidth=3,
                        marker='s',
                        s=600
                        #zorder=5
                        )
    plt.axis("off")
    plt.show()

    return ax

<Figure size 1000x1000 with 0 Axes>

# Experimental Data

In [298]:
tbl_e2 = pd.read_csv("../InputData/human_data/gershman-2018-e2.csv")
tbl_rb = pd.read_csv("../InputData/human_data/speekenbrink-konstantinidis-2015.csv")
tbl_rb["subject"] = tbl_rb["id"]

In [299]:
tbl_e2_grouped = tbl_e2.groupby(tbl_e2["subject"])
e2_subjects = list(tbl_e2_grouped.groups.keys())

In [300]:
tbl_rb_grouped = tbl_rb.groupby("id")
rb_subjects = list(tbl_rb_grouped.groups.keys())

In [301]:
choices = tbl_e2_grouped.get_group(e2_subjects[0])["choice"]

Transitions between blocks have to be marked

Random samples of digits are used for that

In [302]:
tbl_design_e2 = tbl_e2[["subject", "block"]].value_counts().reset_index()[["subject", "block"]]
trials_unique = pd.Series(np.concatenate((tbl_e2["trial"].unique(), np.array([11]))), name = "trial")
tbl_design_e2 = tbl_design_e2.merge(trials_unique, how = "cross")

In [303]:
tbl_rb.head()

Unnamed: 0,cond,id,id2,seed,trial,deck,payoff,rt,block,age,gender,trend,volatility,previous_deck,repeat_deck,switch_deck,run_nr,run_length,subject
0,ntn,1,21,1,1,1,-72,1908.4545,1,19,male,Trend,Variance Stable,,0,1,1,1,1
1,ntn,1,21,1,2,2,0,3733.9498,1,19,male,Trend,Variance Stable,1.0,0,1,2,1,1
2,ntn,1,21,1,3,3,0,3162.7877,1,19,male,Trend,Variance Stable,2.0,0,1,3,1,1
3,ntn,1,21,1,4,4,50,3027.3895,1,19,male,Trend,Variance Stable,3.0,0,1,4,1,1
4,ntn,1,21,1,5,4,40,2418.0934,1,19,male,Trend,Variance Stable,4.0,1,0,4,2,1


In [292]:
tbl_e2_extend = tbl_design_e2.merge(tbl_e2[["subject", "block", "trial", "choice", "repeat_choice"]], how = "left")
tbl_e2_extend["repeat_choice"] = tbl_e2_extend["repeat_choice"] + 1
tbl_e2_extend.loc[tbl_e2_extend["trial"] == 1, "repeat_choice"] = 5
tbl_e2_extend.loc[tbl_e2_extend["trial"] == 1, "switch_choice"] = 5

In [304]:
tbl_rb.loc[tbl_rb["trial"] == 1, "repeat_deck"] = 5
tbl_rb.loc[tbl_rb["trial"] == 1, "switch_deck"] = 5

In [282]:
def get_chunks_one_participant(subj_idx, tbl, var):
    # fill block transitions with random integers
    if var in ["choice", "repeat_choice", "repeat_deck", "switch_deck"]:
        choices = tbl[var][tbl["subject"] == subj_idx]
        choices[pd.isna(choices)] = np.random.choice([3, 4], size = choices[pd.isna(choices)].shape[0], replace = True)
        n_token = 4
    elif var == "deck":
        choices = tbl[var][tbl["subject"] == subj_idx]
        n_token = 6
    np_choices = choices.to_numpy().reshape(choices.shape[0], 1, 1)
    # model
    cggt = generative_model_random_combination(D=5, n=n_token)
    cggt = to_chunking_graph(cggt)
    learned_M, _, _, _ = partition_seq_hastily(np_choices, list(cggt.M.keys()))
    cg = Chunking_Graph(DT = 0, theta=1)
    cg = rational_chunking_all_info(np_choices, cg)
    chunks = []
    for ck in cg.vertex_list:
        chunks.append(list(np.ravel(tuple_to_arr(ck)).astype(int)))
    return chunks

In [116]:
def chunk_switches(l_chunks):
    for pos, ch in enumerate(l_chunks):
        switches_pos = np.zeros(len(l_chunks))
        if (len(ch) > 2):
            n_switches = 0
            for idx, item in enumerate(ch[0:-1]):
                n_switches += item != ch[idx+1]
                switches_pos[pos] = n_switches
    return switches_pos

# Chunk Model on Experimental Data

## Bandit Choices in Gershman (2018)

In [196]:
f_partial = partial(get_chunks_one_participant, tbl = tbl_e2_extend, var = "choice")

In [197]:
iterable = map(f_partial, e2_subjects)
l_results_e2_choices = list(iterable)

In [198]:
iterable = map(chunk_switches, l_results_e2_choices)
l_chunk_results = list(iterable)

In [199]:
it = map(lambda x: x.sum(axis=0), l_chunk_results)
chunks = list(it)
chunks_per_participant = [ch > 1 for ch in chunks]
print(str(sum(chunks_per_participant)) + "/" + str(len(l_chunk_results)) + " participants with chunks")

4/44 participants with chunks


In [200]:
chunks_observed =  [(idx,l_results_e2_choices[idx]) for idx, ch in enumerate(chunks_per_participant) if ch == True]

In [201]:
l_switch_chunk = []
for ch_obs in chunks_observed:
    l_switch_chunk.append(
        [ch_obs[1][idx] for idx, ch in enumerate(l_chunk_results[ch_obs[0]] > 1) if ch]
    )

In [202]:
l_switch_chunk

[[[1, 2, 1, 2]], [[1, 2, 1]], [[1, 2, 1, 1, 1, 1]], [[1, 2, 1, 2]]]

## Repeat Choices in Gershman (2018)

In [263]:
f_partial = partial(get_chunks_one_participant, tbl = tbl_e2_extend, var = "repeat_choice")

In [264]:
iterable = map(f_partial, e2_subjects)
l_results_e2_repeat = list(iterable)

  terms = (f_obs_float - f_exp)**2 / f_exp


In [265]:
iterable = map(chunk_switches, l_results_e2_repeat)
l_chunk_results_repeat = list(iterable)

In [268]:
it = map(lambda x: x.sum(axis=0), l_chunk_results_repeat)
chunks = list(it)
chunks_per_participant = [ch > 1 for ch in chunks]
print(str(sum(chunks_per_participant)) + "/" + str(len(l_chunk_results)) + " participants with chunks")

8/44 participants with chunks


In [269]:
chunks_observed =  [(idx,l_results_e2_repeat[idx]) for idx, ch in enumerate(chunks_per_participant) if ch == True]

In [270]:
l_switch_chunk = []
for ch_obs in chunks_observed:
    l_switch_chunk.append(
        [ch_obs[1][idx] for idx, ch in enumerate(l_chunk_results_repeat[ch_obs[0]] > 1) if ch]
    )

In [271]:
l_switch_chunk

[[[5, 1, 2, 2, 2, 2]],
 [[5, 1, 1, 2, 2, 2, 2]],
 [[2, 1, 2, 2, 2, 2]],
 [[5, 1, 2, 2, 2, 2]],
 [[5, 1, 2, 2, 2, 2]],
 [[5, 1, 2, 2]],
 [[3, 5, 1, 1]],
 [[5, 1, 1, 2, 2, 2, 2]]]

## Decks in Speekenbrink & Konstantinidis (2015)

In [307]:
tbl_rb_sub = tbl_rb.loc[(tbl_rb["trend"] == "Trend") & (tbl_rb["volatility"] == "Variance Stable"),]

In [309]:
f_partial = partial(get_chunks_one_participant, tbl = tbl_rb_sub, var = "repeat_deck")

In [310]:
iterable = map(f_partial, rb_subjects)
l_results_rb_decks = list(iterable)

In [311]:
tbl_rb

Unnamed: 0,cond,id,id2,seed,trial,deck,payoff,rt,block,age,gender,trend,volatility,previous_deck,repeat_deck,switch_deck,run_nr,run_length,subject
0,ntn,1,21,1,1,1,-72,1908.4545,1,19,male,Trend,Variance Stable,,5,5,1,1,1
1,ntn,1,21,1,2,2,0,3733.9498,1,19,male,Trend,Variance Stable,1.0,0,1,2,1,1
2,ntn,1,21,1,3,3,0,3162.7877,1,19,male,Trend,Variance Stable,2.0,0,1,3,1,1
3,ntn,1,21,1,4,4,50,3027.3895,1,19,male,Trend,Variance Stable,3.0,0,1,4,1,1
4,ntn,1,21,1,5,4,40,2418.0934,1,19,male,Trend,Variance Stable,4.0,1,0,4,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15995,ts,20,60,4,196,2,31,1729.2862,8,20,female,No Trend,Variance Changes,2.0,1,0,3861,9,20
15996,ts,20,60,4,197,2,20,2006.1217,8,20,female,No Trend,Variance Changes,2.0,1,0,3861,10,20
15997,ts,20,60,4,198,1,-15,2412.5147,8,20,female,No Trend,Variance Changes,2.0,0,1,3862,1,20
15998,ts,20,60,4,199,2,16,1606.3952,8,20,female,No Trend,Variance Changes,1.0,0,1,3863,1,20


In [312]:
l_results_rb_decks

[[[5], [0], [1], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5],
  [1],
  [0],
  [1, 1],
  [1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
 [[5],
  [1],
  [0],
  [1, 1],
  [1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
 [[5],
  [0],
  [1],
  [1, 1],
  [1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
 [[5],
  [0],
  [1],
  [1, 1],
  [1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
 [[5], [0], [1], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5], [0], [1], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5], [0], [1], [1, 1]],
 [[5], [0], [1], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5], [0], [1], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5], [1], [0], [1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]],
 [[5],
  [0],
  [1],
  [1, 1],
  [1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1

In [265]:
iterable = map(chunk_switches, l_results_e2_repeat)
l_chunk_results_repeat = list(iterable)