In [1]:
%load_ext autoreload
%autoreload 2

import os
import itertools 
from os.path import join
from pdb import set_trace

import numpy as np
import pandas as pd
import mne
from IPython.display import display, clear_output

import deepbci as dbci
from deepbci.utils import utils
from deepbci.data_utils.data import run_group_mutators

import vis_utils as vtils

In [2]:
# Name of the YAML config that describes which data to load.
data_cfg_name = 'data-load.yaml'

# Locate the config starting from the current working directory and parse it.
# NOTE(review): exact search semantics of `utils.path_to` are not visible here — confirm.
data_cfg_path = utils.path_to(os.getcwd(), data_cfg_name)
data_cfg = utils.load_yaml(data_cfg_path)

# Keep only the 'groups' and 'mutate' sections. The return value is not used;
# the displayed `data_cfg` below shows the dict is modified in place.
vtils.clean_config(data_cfg, keep_keys=['groups', 'mutate'])

data_cfg

{'groups': {'_target_': 'deepbci.data_utils.Groups',
  'data_groups': {'dbci': [{'_target_': 'deepbci.data_utils.data_loaders.load_data',
     'load_method': 'load_to_memory',
     'load_method_kwargs': {'subjects': [1],
      'trials': [1],
      'data_file': 'eeg.csv',
      'true_fs': False,
      'load_state_info': True,
      'load_state_info_kwargs': {'clean': True},
      'preload_epoch_indexes': {'generate_sync_epochs': None}},
     'data_loader': {'_target_': 'deepbci.data_utils.data_loaders.OAOutLoader'}}]}},
 'mutate': None}

In [3]:
# Only let MNE print ERROR-level log messages, keeping the loading output readable.
mne.set_log_level('ERROR')
# Build the data groups from `data_cfg` (its 'mutate' section is None here).
# NOTE(review): presumably instantiates `data_cfg['groups']` and then applies the
# 'mutate' steps — confirm against vis_utils.instantiate_and_mutate.
grps = vtils.instantiate_and_mutate(data_cfg)

Checking if .npy file exists...
Attempting to decompress zstd file...
Attempting to decompress /home/dev/mnt/deepbci/data/obstacle_avoidance/outcome/S1/trial-1/states/state-images.npy.zst...
Decompression was successful!
Decompressed zstd file
Loading .npy file
Removing /home/dev/mnt/deepbci/data/obstacle_avoidance/outcome/S1/trial-1/states/state-images.npy


In [4]:
# Show what was loaded: one DataGroup per (group, dataset, subject, trial) entry.
grps

group  dataset  subject  trial
dbci   OAOut    1        1        DataGroup(RawArray, ndarray)
Name: data, dtype: object

In [5]:
# Select with one list of labels per index level: group, dataset, subject, trial.
# NOTE(review): 'BGSInt' is not among the loaded datasets (only 'OAOut' appears in
# the `grps` output above), and this call emits a warning raised from
# `self.data_map.loc[index]`; presumably the missing label triggers it — confirm.
grps[['dbci'], ['OAOut', 'BGSInt'], ['1'], ['1']]

  selected = self.data_map.loc[index]


array([DataGroup(RawArray, ndarray)], dtype=object)

In [6]:
# Metadata of the single OAOut / subject 1 / trial 1 DataGroup. Because the config
# sets `load_state_info: True`, it contains a 'state_info' entry.
grps['dbci', ['OAOut'], ['1'], ['1']][0].metadata

defaultdict(list,
            {'state_info': [(      timestamps  actions  rewards
               0       0.018027        0        0
               1       0.111643        0        0
               2       0.210284        0        0
               3       0.303361        0        0
               4       0.400361        0        0
               ...          ...      ...      ...
               1795  173.606691        0        0
               1796  173.705417        0        0
               1797  173.804037        0        0
               1798  173.895832        1        0
               1799  173.992741        0        0
               
               [1800 rows x 3 columns],
               array([[[[255, 255, 255],
                        [255, 255, 255],
                        [255, 255, 255],
                        ...,
                        [255, 255, 255],
                        [255, 255, 255],
                        [255, 255, 255]],
               
                    

In [7]:
# `state_info` is a list; each element (only one here) is a tuple of
# (per-state info DataFrame, array of state images), as the output below shows.
state_info = grps['dbci', ['OAOut'], ['1'], ['1']][0].metadata['state_info']
state_info

[(      timestamps  actions  rewards
  0       0.018027        0        0
  1       0.111643        0        0
  2       0.210284        0        0
  3       0.303361        0        0
  4       0.400361        0        0
  ...          ...      ...      ...
  1795  173.606691        0        0
  1796  173.705417        0        0
  1797  173.804037        0        0
  1798  173.895832        1        0
  1799  173.992741        0        0
  
  [1800 rows x 3 columns],
  array([[[[255, 255, 255],
           [255, 255, 255],
           [255, 255, 255],
           ...,
           [255, 255, 255],
           [255, 255, 255],
           [255, 255, 255]],
  
          [[255, 255, 255],
           [255, 255, 255],
           [255, 255, 255],
           ...,
           [255, 255, 255],
           [255, 255, 255],
           [255, 255, 255]],
  
          [[255, 255, 255],
           [255, 255, 255],
           [255, 255, 255],
           ...,
           [255, 255, 255],
           [255, 255, 

The cell below shows all the state info for a given trial. The state info always contains `timestamps` and `actions` columns. Depending on whether the data comes from OA or BGS, it also has one additional column called `rewards` or `terminal`.

In [8]:
# Per-state info table (timestamps, actions, and a dataset-specific column).
state_info[0][0]

Unnamed: 0,timestamps,actions,rewards
0,0.018027,0,0
1,0.111643,0,0
2,0.210284,0,0
3,0.303361,0,0
4,0.400361,0,0
...,...,...,...
1795,173.606691,0,0
1796,173.705417,0,0
1797,173.804037,0,0
1798,173.895832,1,0


The states (images) for the trial are stored in a single array. The cell below shows its dimensions: the number of states, followed by the dimensions of each state image.

In [9]:
# Shape of the state-image array; its first dimension matches the 1800 rows above.
state_info[0][1].shape

(1800, 420, 420, 3)

Below, we print a single state and its corresponding info. Move the `State Number` slider to change which state is shown.

In [10]:
import ipywidgets as widgets
from PIL import Image


state_slider = widgets.IntSlider(
    value=0,
    min=0,
    # IntSlider's `max` is inclusive and states are 0-indexed, so the last valid
    # position is one less than the number of states. Bound by the shorter of the
    # info table and the image array so both lookups in the callback stay in range.
    max=min(len(state_info[0][0]), len(state_info[0][1])) - 1,
    step=1,
    description='State Number:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)


@widgets.interact(loc=state_slider)
def index_state_info(loc):
    """Print the info row for state `loc` and display that state's image.

    Args:
        loc: 0-based position of the state within the trial, driven by
            `state_slider`.
    """
    print(f"State info:\n{state_info[0][0].iloc[loc]}")
    img = Image.fromarray(state_info[0][1][loc])
    display(img)


interactive(children=(IntSlider(value=0, continuous_update=False, description='State Number:', max=1800), Outp…