In [1]:
import torch as tr
import numpy as np

import itertools

from PM_models import *
from PM_tasks import *
from help_amtask import *

import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline

%load_ext autoreload
%reload_ext autoreload
%autoreload 2

# goal
- investigate canonical representations (CRs) suited for organizing retrieval in the switching arbitrary maps task

In [65]:
class AnalyticPM():
  ''' 
  Analytic model holding a list-based episodic memory (EM) of key/value pairs.

  s/abags: all possible stimuli and actions
  mapbag: maps active within a block
  context_emat / stim_emat: identity (one-hot) embedding matrices with one
    row per map in mapbag / per stimulus in sbag
  EMK / EMV: parallel lists of memory keys and values filled by `encode`
  '''
  def __init__(self,nmaps_per_block=3):
    ''' build bags, candidate maps and one-hot embeddings; start with empty EM '''
    # initialize EM
    self.flushEM()
    # bags
    self.sbag = [0,1,2]
    self.abag = [10,11,12]
    self.fix_mapbag(nmaps_per_block)
    # embedding matrices
    # (sized from the instance attributes, not from notebook-level globals,
    #  so construction works on a fresh kernel)
    self.context_emat = np.eye(len(self.mapbag))
    self.stim_emat = np.eye(len(self.sbag))
    return None
  
  def fix_mapbag(self,nmaps_per_block):
    ''' 
    mapbag: list of maps 
    Each map is a dict {stimulus: action}. One map is built for every pairing of
    a stimulus combination with an action permutation, both of size
    nmaps_per_block. Stores the list on self.mapbag and returns it.
    '''
    asets = list(itertools.permutations(self.abag,nmaps_per_block))
    ssets = list(itertools.combinations(self.sbag,nmaps_per_block))
    self.mapbag = [dict(zip(sset,aset)) for sset,aset in itertools.product(ssets,asets)]
    return self.mapbag
  
  def encode(self,map_idx):
    ''' 
    encode single map 
    Appends one (key, value) entry per stimulus->action pair of
    mapbag[map_idx] to EMK / EMV. Existing entries are kept; call flushEM
    first to start from an empty memory.
    '''
    # strict bound: map_idx == len(self.mapbag) is already out of range
    assert map_idx < len(self.mapbag), 'map_idx out of range'
    map_b = self.mapbag[map_idx]
    context_b = self.context_emat[map_idx]
    # EM = [c_embed,s_embed : act_int]
    for stim_idx,act in map_b.items():
      # stimuli in sbag are used directly as row indices into stim_emat
      stim_embed = self.stim_emat[stim_idx]
      emk = [context_b,stim_embed]
      self.EMK.append(emk)
      self.EMV.append(act)  
    return None
  
  def flushEM(self):
    ''' reset episodic memory: empty the key list (EMK) and value list (EMV) '''
    self.EMK = []
    self.EMV = []
    return None

In [66]:
# instantiate the model with the default nmaps_per_block=3
pm = AnalyticPM()

In [68]:
# inspect the candidate maps built by fix_mapbag (each one is a {stimulus: action} dict)
pm.mapbag

[{0: 10, 1: 11, 2: 12},
 {0: 10, 1: 12, 2: 11},
 {0: 11, 1: 10, 2: 12},
 {0: 11, 1: 12, 2: 10},
 {0: 12, 1: 10, 2: 11},
 {0: 12, 1: 11, 2: 10}]

In [71]:
# encode map 2 into episodic memory, then show the stored values.
# NOTE: encode() appends to pm.EMK / pm.EMV without clearing them, so entries
# accumulate across calls / cell re-runs until pm.flushEM() is called
pm.encode(2)
pm.EMV

[10, 12, 11, 10, 12, 11, 11, 10, 12]

In [41]:
''' generate data '''
# For each block: sample one map from pm.mapbag together with its one-hot
# context vector, then draw block_len random stimuli from that map and print
# their one-hot embeddings.
# The maps and embedding matrices are read from the `pm` instance; bare
# `mapbag` / `context_emat` / `stim_emat` globals are not defined anywhere in
# this notebook, so relying on them breaks Restart & Run All.
block_len = 5
nblocks = 2

# loop over blocks
for blocknum in range(nblocks):
  # sample map of block and corresponding context vector
  mapidx = np.random.randint(len(pm.mapbag))
  map_b = pm.mapbag[mapidx]
  context_b = pm.context_emat[mapidx]
  print(map_b,context_b)
  # stim and action sets
  sset_b = list(map_b.keys())
  aset_b = list(map_b.values())
  ## instruction / encoding phase
  
  ## test / retrieval phase
  for trial in range(block_len):
    # stimuli double as row indices into the one-hot stimulus embedding matrix
    idx_stim_t = np.random.choice(sset_b)
    stim_t = pm.stim_emat[idx_stim_t]
    print('s_t',stim_t)

{0: 10, 1: 11, 2: 12} [1. 0. 0. 0. 0. 0.]
s_t [0. 0. 1.]
s_t [0. 1. 0.]
s_t [1. 0. 0.]
s_t [0. 0. 1.]
s_t [0. 0. 1.]
{0: 10, 1: 11, 2: 12} [1. 0. 0. 0. 0. 0.]
s_t [1. 0. 0.]
s_t [0. 1. 0.]
s_t [0. 0. 1.]
s_t [0. 0. 1.]
s_t [0. 0. 1.]


### canonical representations: 
- low dimensional
- structure mirrors environment structure
    - gibsonian? 

### task
- percept stimulus generated by environment object has different dimensions

### paradigm:

- pomdp vs mdp: 
    - mdp: environment feature that differentiates buckets persists
    - pomdp: environment surface level differentiates 

### EM system
- encoding method
    - 
- retrieval function
    - retrieved memory = f (stimulus history) 
        - stimulus history is list
    - uses distance between current and encoded states
        - iteratively compute distance between current state and items in stimhistL

- recency effect strongly suggests smoothly decaying 
- probabilistic? 

# environment

- instruction phase: 
    - #nmaps stimulus, action pairs presented 
    - model encodes
    
- how will I encode map?
    - 
    - specified by environment as dict
    - model has encode map method
        - map_i_representation = encode_map(stimulus_i, action_i)

- test phase:
    - sequentially process #block_len stimuli (no action) 
    - model retrieves, responds

- percept is vector 
    - subfields:
        - e: external (map @instruction, stimulus @test)
            - provided by environment
        - i: internal (system state)
            - integral of e
        - c_e: context_e
            - environment auto-correlation 
                - recency effect
        - c_i: context_i      
            - change point detection 
                - shifting property
                - "bucket"

- blocks within experiment can have repeated, overlapping or independent mapsets
    - curriculum defines mapset sequence used within an experiment
    - mapset is dict {stim:action}
    - curriculum is 


In [None]:
# gen experiment (nblocks,curriculum)
# gen block 
# gen encoding (nmaps)
# gen test (num blocks)

# model
- methods
    - map_representation = encode_map(stimulus,action)
        - iteratively compute individual encoding for each map in mapset
        - or compute a single conjunctive representation of mapset
    - retrieved_memory = retrieve(percept(stimulus,context))
    
- keeps track of internal context
    - or should this be handled by environment?
- 
   

In [12]:
# TaskArbitraryMaps is not defined in this notebook; presumably it is provided by
# the `from PM_tasks import *` star import — TODO confirm.
# The meaning of the positional argument 2 is not visible here.
task = TaskArbitraryMaps(2,switchmaps=True)

In [27]:
# unpack the four values returned by gen_ep_data; the semantics of the (3,2)
# arguments and of i, x, y are not visible here — TODO confirm against PM_tasks
t,i,x,y = task.gen_ep_data(3,2,return_trial_flag=True)
# display t (presumably the per-step trial flag requested via return_trial_flag)
t

tensor([[0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2]])