In [38]:
from scipy.io import loadmat
import numpy as np
import os
from tqdm import tqdm
import itertools
import matplotlib.pyplot as plt

In [39]:
# Indexer
class Subjects:
    """Lazy, index-based access to the dataset: ``Subjects(path)[sub][sess][run]`` -> dict.

    Layout assumed by the path handling below (confirm against the dataset):
    ``<path>/<subject dir>/runs/<run .mat files>``.

    * Subject directories are ordered by the integer after the ``'j'`` in their name
      (e.g. ``subj10`` -> 10).
    * Run files are sorted by a session digit parsed from the file name, then chunked into
      sessions of ``runs_per_session`` consecutive files.
    * Hidden entries (names starting with ``'.'``, e.g. ``.DS_Store``) are ignored; they
      would otherwise break the filename parsing.

    Args:
        path: dataset root directory.
        runs_per_session: number of run files per session (default 6).
    """

    class Sessions:
        """All sessions of one subject; ``sessions[i]`` returns a ``Runs`` for session ``i``."""

        class Runs:
            """The run files of one session; ``runs[i]`` loads run ``i`` from disk."""

            def __init__(self, paths) -> None:
                # Order by the character at index 3 of the last '_' token (e.g. 'run3.mat' -> 3).
                # Only a single character is read, so run numbers >= 10 would not sort correctly.
                self.paths = sorted(paths, key=lambda x: int(x.split('_')[-1][3]))

            def __len__(self):
                return len(self.paths)

            def __getitem__(self, run_idx):
                return self.load_data(self.paths[run_idx])

            def load_data(self, path):
                """Load a run ``.mat`` file and return the attributes of its ``p`` struct as a dict."""
                mat_contents = loadmat(path, struct_as_record=False, squeeze_me=True)
                return mat_contents["p"].__dict__

        def __init__(self, path, runs_per_session=6) -> None:
            self.base_path = os.path.join(path, "runs")
            names = [name for name in os.listdir(self.base_path) if not name.startswith('.')]
            # Sort by the session digit: the character at index 1 of the third '_' token
            # (e.g. 's2' -> 2). sorted() is stable and Runs re-sorts each chunk by run number.
            names = sorted(names, key=lambda x: int(x.split('_')[2][1]))
            full_paths = [os.path.join(self.base_path, name) for name in names]
            # Chunk consecutive files into sessions. This assumes every session contributes
            # exactly runs_per_session files; a missing run would shift all later sessions.
            self.sess_paths = [full_paths[i:i + runs_per_session]
                               for i in range(0, len(full_paths), runs_per_session)]

        def __len__(self):
            return len(self.sess_paths)

        def __getitem__(self, sess_idx):
            return self.Runs(self.sess_paths[sess_idx])

    def __init__(self, path, runs_per_session=6) -> None:
        self.path = path
        self.runs_per_session = runs_per_session
        names = [name for name in os.listdir(path) if not name.startswith('.')]
        # Numeric order of the text after 'j' (e.g. 'subj2' before 'subj10').
        self.subj_paths = sorted(names, key=lambda x: int(x.split('j')[1]))

    def __len__(self):
        return len(self.subj_paths)

    def __getitem__(self, sub_idx):
        return self.Sessions(os.path.join(self.path, self.subj_paths[sub_idx]),
                             self.runs_per_session)

# Refresher
def np_refresh(np_array):
    """Rebuild ``np_array`` from its nested-list form as a brand-new array.

    For an object array whose cells hold equally-shaped arrays (or scalars), the cells are
    stacked, so the result is a regular array with the cell dimensions appended.
    """
    as_nested = np_array.tolist()
    rebuilt = np.array(as_nested)
    return rebuilt

In [40]:
# Fixed dataset constants
PATH_DS = '../axej_eeg'            # dataset root, one directory per subject
EXP_ORI = [159, 123, 87, 51, 15]   # orientation values, looked up with the 1-based p.expOri
SUBS, SESS, RUNS = 13, 4, 6        # subjects, sessions per subject, runs per session

subj = Subjects(PATH_DS)

# Every (subject, session, run) index triple, in the same order itertools.product yields
data_idx = [(sub, sess, run)
            for sub in range(SUBS)
            for sess in range(SESS)
            for run in range(RUNS)]

In [41]:
# Expectation
# pcat.tr_exp = pcat.trlabel(pcat.prior == 1);
# pcat.tr_un = pcat.trlabel(pcat.prior == 3);
# pcat.tr_neu = pcat.trlabel(pcat.prior == 2);

# Attention = attCue
# pcat.tr_foc = pcat.trlabel(pcat.attcue == 1); = attCue
# pcat.tr_div = pcat.trlabel(pcat.attcue == 2); = attCue
# div -> -1, foc -> 1

# Coherence = tgCoh
# pcat.tr_lo = pcat.trlabel(pcat.moco == 1); = tgCoh
# pcat.tr_hi = pcat.trlabel(pcat.moco == 2); = tgCoh
# low -> -1, high -> 1

# Observation = response_angle

# Target = stimDirREAL
# Coherence orientation = EXP_ORI[p.expOri - 1]  (p.expOri is 1-based)
# Coherence strength = tgCoh
# Attention state = attCue
# bias = np.ones

In [42]:
# Load one subject's runs into (session, run, ...) arrays. Only the first RUNS - 1 runs of
# each session are kept. NOTE(review): the reason the last run is dropped is not recorded; document it.
shape = (SESS, RUNS - 1)

# Object arrays: each cell receives one run's data and is stacked by np_refresh below.
jx = np.empty(shape, dtype=object)
jy = np.empty(shape, dtype=object)
sti_dir = np.empty(shape, dtype=object)
ori_dir = np.empty(shape, dtype=object)
ori_st = np.empty(shape, dtype=object)
att_st = np.empty(shape, dtype=object)

subs = 5                   # subject index (position in the numerically sorted subject list)
sub_sessions = subj[subs]  # hoisted: indexing subj re-lists the subject's runs directory
for sess, runs in itertools.product(range(SESS), range(RUNS - 1)):
    data = sub_sessions[sess][runs]
    jx[sess, runs] = data['joyx']
    jy[sess, runs] = data['joyy']
    sti_dir[sess, runs] = data['stimDirREAL']
    # expOri is treated as 1-based (MATLAB-style), hence the -1
    ori_dir[sess, runs] = EXP_ORI[data['expOri'] - 1]
    ori_st[sess, runs] = data['tgCoh']
    att_st[sess, runs] = data['attCue']

# Stack the per-run object cells into regular arrays
jx = np_refresh(jx)
jy = np_refresh(jy)
sti_dir = np_refresh(sti_dir)
ori_dir = np_refresh(ori_dir)
ori_st = np_refresh(ori_st)
att_st = np_refresh(att_st)

# ori_dir has one value per run; repeat it for every trial (trial count taken from the data)
ori_dir = np.repeat(ori_dir[..., np.newaxis], sti_dir.shape[-1], axis=2)
assert ori_dir.shape == sti_dir.shape, (ori_dir.shape, sti_dir.shape)

jx.shape, jy.shape

((4, 5, 120, 500), (4, 5, 120, 500))

In [43]:
# The per-trial label arrays are expected to share one (session, run, trial) shape
tuple(arr.shape for arr in (sti_dir, ori_dir, ori_st, att_st))

((4, 5, 120), (4, 5, 120), (4, 5, 120), (4, 5, 120))

In [123]:
def make_responses(jx, jy):
    """Extract each trial's response from joystick traces.

    The response is read at the sample where the joystick is furthest from the centre
    (0, 0). Samples whose distance is NaN are treated as distance 0.

    Args:
        jx, jy: joystick x / y positions of equal shape, with samples on the last axis
            (in this notebook ``(session, run, trial, sample)``). Any leading shape works.

    Returns:
        resp_angle: ``arctan2(y, x)`` of the selected sample in degrees (range [-180, 180]),
            with the sample axis removed.
        max_idx: index along the sample axis of the selected sample.
    """
    jx = np.asarray(jx)
    jy = np.asarray(jy)

    dist_from_cent = np.sqrt(jx ** 2 + jy ** 2)
    # np.argmax would return the position of the first NaN, so zero those distances first
    dist_from_cent[np.isnan(dist_from_cent)] = 0
    max_idx = np.argmax(dist_from_cent, axis=-1)

    # Gather the x/y sample at max_idx for every leading index, whatever the leading shape
    pick = max_idx[..., np.newaxis]
    max_x = np.take_along_axis(jx, pick, axis=-1)[..., 0]
    max_y = np.take_along_axis(jy, pick, axis=-1)[..., 0]

    resp_angle = np.rad2deg(np.arctan2(max_y, max_x))

    return resp_angle, max_idx

In [124]:
# Per-trial response angle (degrees) and the sample index it was read from
resp_angle, resp_idx = make_responses(jx, jy)
tuple(arr.shape for arr in (resp_angle, resp_idx))

((4, 5, 120), (4, 5, 120))

In [153]:
# Spot-check: response angle (degrees, in [-180, 180]) next to the stimulus direction
# for the first 50 trials of session index 1, run index 0.
for ch in range(50):
    print(f"{resp_angle[1, 0, ch]:.0f} - ", sti_dir[1, 0, ch])

133 -  159
65 -  15
129 -  159
49 -  51
140 -  159
156 -  159
150 -  159
30 -  159
80 -  87
155 -  159
151 -  159
64 -  159
41 -  15
44 -  51
145 -  159
149 -  159
38 -  15
37 -  159
159 -  159
56 -  159
126 -  15
42 -  159
50 -  159
147 -  159
47 -  159
70 -  123
57 -  159
150 -  159
134 -  159
145 -  159
67 -  87
141 -  159
61 -  159
142 -  159
42 -  159
150 -  159
149 -  159
40 -  159
42 -  159
34 -  159
150 -  159
62 -  159
150 -  159
84 -  87
37 -  159
149 -  159
33 -  159
153 -  123
151 -  159
29 -  51
