In [156]:
import itertools
import copy
import collections
import numpy as np
from numpy.random import default_rng
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import pickle
import random
from sklearn.model_selection import train_test_split
import sys
import os
current = os.path.dirname(os.path.realpath('plotting.py'))
parent = os.path.dirname(current)
sys.path.append(parent)
import functions.plotting as NNplt
from functions.rnn_cryptic import generate_sequences, convert_seq2inputs, calculate_output, pad_seqs
random.seed(5)

 # Functions for test sets

In [157]:
# Scalar value associated with each cue symbol.
default_cues = {
    'A': 2,
    'B': 3,
    'C': 5,
    'D': 7,
    'E': 1,
    'F': 4,
    'G': 9,
    'H': 11,
}
ops = ['+', '*', '-']
inputs = ['A', 'B', 'C', 'D']

# Every (operation, input) step, operation-major: all inputs for '+',
# then all inputs for '*', then all inputs for '-'.
all_combos = [(op, inp) for op in ops for inp in inputs]

# test set for input-operation recombination
# test on unseen input-operation combinations

def test_combos(trainset):
    """Build test trials from (operation, input) combos absent from ``trainset``.

    Parameters
    ----------
    trainset : list
        Training trials. Positions 1 and 2 of each trial are read as its two
        (operation, input) steps.

    Returns
    -------
    list
        Trials of the form ``[initial_input, step1, step2, output]``. The
        unseen combos are shuffled and consumed two at a time, so an odd
        leftover combo is dropped. ``output`` comes from ``calculate_output``
        with ``cue_dict=default_cues`` and ``bidmas=False``.
    """
    # Both steps of every training trial count as "seen".
    train_combos = []
    for t in trainset:
        train_combos.append(t[1])
        train_combos.append(t[2])

    # Filter rather than list.remove(): a combo that occurs in several
    # training trials must not raise ValueError on its second removal.
    combos_test = [combo for combo in all_combos if combo not in train_combos]
    random.shuffle(combos_test)
    numpairs = len(combos_test) // 2

    # Repeat the inputs often enough to cover every pair (ceiling division),
    # then shuffle so the initial input of each test trial is randomised.
    repeats = -(-numpairs // len(inputs))
    test_inputs = repeats * inputs
    random.shuffle(test_inputs)

    testset = []
    for i in range(numpairs):
        # Draw from the shuffled pool; indexing `inputs` directly ignored the
        # shuffle and overflowed once numpairs exceeded len(inputs).
        trial = [test_inputs[i], combos_test[2 * i], combos_test[2 * i + 1]]
        trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
        testset.append(trial)

    return testset

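# test set for order recombination
# test on trained input-operation combinations paired in orders not seen in training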
def test_order(trainset):
    """Build test trials that pair trained (operation, input) combos in unseen orders.

    Parameters
    ----------
    trainset : list
        Training trials. Positions 1 and 2 are read as the two steps, and
        ``t[1:-1]`` as the ordered step pair that was trained (i.e. the last
        element is treated as the already-appended output).

    Returns
    -------
    list
        Trials of the form ``[initial_input, step1, step2, output]`` with a
        randomly chosen initial input. There are ``len(trainset)`` of them, or
        fewer when not enough untrained orderings exist.
    """
    # Ordered step pairs that were actually trained.
    trained_pairs = [tuple(t[1:-1]) for t in trainset]

    # All (operation, input) combos used in training, first-occurrence order.
    train_combos = []
    for t in trainset:
        train_combos.extend((t[1], t[2]))
    unique_combos = list(dict.fromkeys(train_combos))

    # Every ordered pair of trained combos, minus the pairs seen in training.
    # Filtering (instead of list.remove) tolerates a pair that appears in
    # more than one training trial.
    candidate_pairs = [
        pair
        for pair in itertools.product(unique_combos, repeat=2)
        if pair not in trained_pairs
    ]

    # Match the training-set size where possible; random.sample raises if
    # asked for more items than are available.
    n_select = min(len(trainset), len(candidate_pairs))
    select_pairs = random.sample(candidate_pairs, n_select)

    testset = []
    for first_step, second_step in select_pairs:
        trial = [random.choice(inputs), first_step, second_step]
        trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
        testset.append(trial)

    return testset

# test set for initial values
# test on trained operation pairs starting from a different initial input
def test_init(trainset):
    """Copy each training trial's two steps but start from a different initial input.

    For every trial in ``trainset`` the initial input (position 0) is replaced
    by a random choice among the remaining entries of ``inputs``; the steps at
    positions 1 and 2 are kept and the output is recomputed with
    ``calculate_output``.

    Returns a list of ``[initial_input, step1, step2, output]`` trials, one
    per training trial and in the same order.
    """
    testset = []
    for source in trainset:
        # Candidate initial inputs: everything except the trained one.
        candidates = list(inputs)
        candidates.remove(source[0])

        new_trial = [random.choice(candidates), source[1], source[2]]
        new_trial.append(
            calculate_output(new_trial, cue_dict=default_cues, bidmas=False)
        )
        testset.append(new_trial)

    return testset

def test_init_small(trainset):
    """Alias of ``test_init``, kept so existing callers of this name still work.

    The body used to be a line-for-line copy of ``test_init``; delegating
    keeps the two from drifting apart. Returns one
    ``[initial_input, step1, step2, output]`` trial per training trial, with
    the initial input swapped for a different entry of ``inputs``.
    """
    return test_init(trainset)


# Set M

In [125]:
# Training set M: a minimal set of two trials. Each trial is
# [initial input, (operation, input), (operation, input)]; the computed
# output is appended below as the final element.
Mset = [
    ['D', ('+', 'B'), ('-', 'C')],
    ['C', ('*', 'A'), ('+', 'D')],
]

for trial in Mset:
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))

# Test sets derived from M.
Mcombos = test_combos(Mset)  # input-operation combos not present in M
Morder = test_order(Mset)    # trained combos paired in untrained orders
Minit = test_init(Mset)      # trained step pairs with a different initial input

In [130]:
# Pickle the M training set and its test sets, one file per list.
savepath = '../sequences/training/'

for fname, trials in (
    ('Mset', Mset),
    ('Mcombos', Mcombos),
    ('Morder', Morder),
    ('Minit', Minit),
):
    with open(savepath + fname, 'wb') as f:
        pickle.dump(trials, f)

# Set MC

In [131]:
# Training set MC: shuffle the (operation, input) combos and pair them off,
# so every combo is used in exactly one trial.
num_trials = len(all_combos) // 2

# NOTE: shuffles the shared `all_combos` list in place, so every later cell
# sees the shuffled order.
random.shuffle(all_combos)

MCset = []
for first_step, second_step in zip(all_combos[0::2], all_combos[1::2]):
    trial = [random.choice(inputs), first_step, second_step]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    MCset.append(trial)

# Test sets derived from MC.
MCorder = test_order(MCset)
MCinit = test_init(MCset)

In [132]:
# Pickle the MC training set and its test sets.
# The training list is `MCset`; the previous name `MCtrain` is never defined
# in this notebook and raised NameError.
with open(savepath + 'MCset', 'wb') as f:
    pickle.dump(MCset, f)
with open(savepath + 'MCorder', 'wb') as f:
    pickle.dump(MCorder, f)
with open(savepath + 'MCinit', 'wb') as f:
    pickle.dump(MCinit, f)

# Set DC

In [133]:
# Training set DC: every MC trial plus, for each one, a counterpart with the
# two steps swapped and a freshly drawn initial input.
DCset = list(MCset)
for mc_trial in MCset:
    trial = [random.choice(inputs), mc_trial[2], mc_trial[1]]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    DCset.append(trial)

# Test sets derived from DC.
DCorder = test_order(DCset)
DCinit = test_init(DCset)

In [134]:
# Pickle the DC training set and its test sets.
# The training list is `DCset`; the previous name `DCtrain` is never defined
# in this notebook and raised NameError.

with open(savepath + 'DCset', 'wb') as f:
    pickle.dump(DCset, f)
with open(savepath + 'DCorder', 'wb') as f:
    pickle.dump(DCorder, f)
with open(savepath + 'DCinit', 'wb') as f:
    pickle.dump(DCinit, f)

# Set F

In [163]:
# Training set F: one trial for every ordered pair of (operation, input)
# combos, a combo paired with itself included.
all_combo_pairs = list(itertools.product(all_combos, repeat=2))
# Second name bound to the same list, kept because the original cell binds it;
# `input_combos` is reassigned in the pretraining cell.
input_combos = all_combo_pairs

Fset = []
for first_step, second_step in all_combo_pairs:
    trial = [random.choice(inputs), first_step, second_step]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    Fset.append(trial)

Finit = test_init(Fset)

In [164]:
# Pickle the F training set and its initial-input test set.
# NOTE(review): relies on `savepath` from whichever cell assigned it last;
# run cells in document order so this targets the training directory.
for fname, trials in (('Fset', Fset), ('Finit', Finit)):
    with open(savepath + fname, 'wb') as f:
        pickle.dump(trials, f)

# Pretraining

In [154]:
# Type 1: map each input to its scalar value.
# A trial is the input alone, followed by two ('X', 'X') placeholder steps,
# with the cue value from `default_cues` as the output.

type1 = [
    [inp] + [('X', 'X')] * 2 + [default_cues[inp]]
    for inp in inputs
]

# Type 2: one-step operations.
# Type2_3: each operation appears once, with a random initial input and a
# random operand.

type2_3 = []
for op in ops:
    trial = [random.choice(inputs), (op, random.choice(inputs))]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type2_3.append(trial)

# Type2_12: each (operation, input) combo appears once, preceded by a random
# initial input.
type2_12 = []
for combo in all_combos:
    trial = [random.choice(inputs), combo]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type2_12.append(trial)
    
# Type2_48: every initial input crossed with every (operation, input) combo.
# Combos vary slowest, initial inputs fastest.

type2_48 = []
for combo, inp in itertools.product(all_combos, inputs):
    trial = [inp, combo]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type2_48.append(trial)

# Type3: single input, single operation.
# One generate_sequences() call per (input, operation), with that input used
# as both the only operand and the only initial value.
# NOTE(review): whatever generate_sequences returns is appended as one
# element; if it returns a list of sequences, type3 is nested. Confirm
# against functions.rnn_cryptic.
len_seq = 2
type3 = []
for inp, op in itertools.product(inputs, ops):
    type3.append(
        generate_sequences(
            operators=[op],
            input_ids=[inp],
            len_seq=len_seq,
            cue_dict=default_cues,
            init_values=[inp],
        )
    )

# Type4: fixed input, all operation combos.
# For every ordered pair of operations, one random input serves as the
# initial value and as the operand of both steps.

op_combos = list(itertools.product(ops, repeat=2))

type4 = []
for first_op, second_op in op_combos:
    inp = random.choice(inputs)
    trial = [inp, (first_op, inp), (second_op, inp)]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type4.append(trial)
    
# Type4 (exhaustive): every ordered pair of operations crossed with every
# input, the input serving as initial value and as operand of both steps.
# NOTE(review): type4_all is neither padded nor pickled by the save cell
# below; confirm whether that omission is intended.
op_combos = list(itertools.product(ops, repeat=2))

type4_all = []
for (first_op, second_op), inp in itertools.product(op_combos, inputs):
    trial = [inp, (first_op, inp), (second_op, inp)]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type4_all.append(trial)

    
# Type5: fixed operation, all input combos.
# For every ordered triple of inputs (initial value, operand 1, operand 2),
# one random operation is applied in both steps.

input_combos = list(itertools.product(inputs, repeat=3))

type5 = []
for init, first_inp, second_inp in input_combos:
    op = random.choice(ops)
    trial = [init, (op, first_inp), (op, second_inp)]
    trial.append(calculate_output(trial, cue_dict=default_cues, bidmas=False))
    type5.append(trial)


In [155]:
## Save the pretraining sets.
savepath = '../sequences/pretraining/'

# Run every set except type1 through pad_seqs before saving (type1 already
# carries two ('X', 'X') steps from its construction).
# NOTE(review): the names are rebound to the padded lists, so re-running this
# cell pads already-padded data; confirm that pad_seqs is idempotent.
type2_3, type2_12, type2_48, type3, type4, type5 = (
    pad_seqs(seqs)
    for seqs in (type2_3, type2_12, type2_48, type3, type4, type5)
)

# One pickle file per set, named after the variable.
for fname, seqs in (
    ('type1', type1),
    ('type2_3', type2_3),
    ('type2_12', type2_12),
    ('type2_48', type2_48),
    ('type3', type3),
    ('type4', type4),
    ('type5', type5),
):
    with open(savepath + fname, 'wb') as f:
        pickle.dump(seqs, f)