In [34]:
from lib.DataHandler import DataAcquisitionHandler
from abc import abstractmethod
import pickle
import torch

In [2]:
filename = "C:/Users/c25th/code/P300_BCI_Speller/data/box_data/handler_box_data_full_Oct_30_2023.pkl"
with open(filename, 'rb') as f:
            handler = pickle.load(f)
print(handler.get_data())

{'box_data': [{'metadata': {'start_time': 1698727569.2188206, 'length': 254.8039038181305, 'trials': [{'timestamp': (2.417940378189087, 3.4181697368621826), 'label': False}, {'timestamp': (5.535764455795288, 6.565990447998047), 'label': True}, {'timestamp': (8.382052183151245, 9.412099123001099), 'label': True}, {'timestamp': (11.931045770645142, 12.931427955627441), 'label': False}, {'timestamp': (14.546231746673584, 15.574142932891846), 'label': True}, {'timestamp': (17.18848180770874, 18.232096910476685), 'label': True}, {'timestamp': (19.948883056640625, 20.978745698928833), 'label': True}, {'timestamp': (22.693714141845703, 23.725579977035522), 'label': True}, {'timestamp': (25.43831443786621, 26.475011348724365), 'label': True}, {'timestamp': (28.794172048568726, 29.79514789581299), 'label': False}, {'timestamp': (31.410541534423828, 32.44038248062134), 'label': True}, {'timestamp': (34.357948541641235, 35.358147621154785), 'label': False}, {'timestamp': (36.96936321258545, 37.97

In [35]:
class DataObject:


    def __init__(self, data_dict):

        self.keyboard_sessions = []

        if 'keyboard_data' in data_dict.keys():
            keyboard_session_list = data_dict['keyboard_data']

            for session_dict in keyboard_session_list:
                self.keyboard_sessions.append(SessionData(session_dict=session_dict, type='keyboard'))

        self.box_sessions = []

        if 'box_data' in data_dict.keys():
            box_session_list = data_dict['box_data']

            for session_dict in box_session_list:
                self.box_sessions.append(SessionData(session_dict=session_dict, type='box'))


    def get_data(self, visitor=None, type='box'):
        if visitor is None:
            keyboard_data = self.keyboard_sessions
        else:
            keyboard_data = visitor.visit_data_object(object=self, type=type)
        return keyboard_data




class SessionData:
    
    def __init__(self, session_dict, type):
        metadata = session_dict['metadata']
        self.data = session_dict['data']
        self.start_time = metadata['start_time']
        self.length = metadata['length']
        self.flash_time_range = metadata['flash_time_range']
        self.sample_time = metadata['sample_time']
        self.description = metadata['description']
        self.type = type
        # TODO: Test that the data loads the description properly

        self.trials = []
        for trial in metadata['trials']:
            if self.type == 'keyboard':
                self.trials.append(KeyboardTrialData(trial_dict=trial, parent_session=self))
            elif self.type == 'box':
                self.trials.append(BoxTrialData(trial_dict=trial, parent_session=self))
            else:
                raise ValueError('Session type must be either keyboard or box')

    def get_data(self, visitor=None):
        if visitor is None:
            trials = self.trials
        else:
            trials = visitor.visit_session_data(session=self)
        return trials



class TrialData:

    def __init__(self, trial_dict, parent_session):
        self.timestamp = trial_dict['timestamp']
        self.label = trial_dict['label']
        self.parent_session = parent_session

    def get_data(self, visitor=None):
        if visitor is None:
            data = self.data
        else:
            data = visitor.visit_trial_data(trial=self)
        return data
    


class BoxTrialData(TrialData):

    def __init__(self, trial_dict, parent_session):
        super().__init__(trial_dict, parent_session)



class KeyboardTrialData(TrialData):

    def __init__(self, trial_dict, parent_session):
        super().__init__(trial_dict, parent_session)
        self.pattern = trial_dict['pattern']
        self.letter = trial_dict['letter']

In [46]:
class DataDecorator(object):

    @staticmethod
    def visit_data_object(object):
        raise NotImplementedError
    
    @staticmethod
    def visit_session_data(session):
        raise NotImplementedError
    
    @staticmethod
    def visit_trial_data(trial, raw_data):
        raise NotImplementedError


##########################################################################################


class MakeWindowsDataDecorator(DataDecorator):
    """
    Returns a list of tuples ( data_window, label )
    """

    def __init__(self):
        pass


    def visit_data_object(self, object, type):
        """
        Returns a list of tuples ( data_windows, label )
        """
        sessions = object.keyboard_sessions if type == 'keyboard' else object.box_sessions

        data = []
        for session in sessions:
            data.extend(self.visit_session_data(session=session))
        return data
    

    def visit_session_data(self, session):
        """
        Returns a list of tuples ( data_windows, label )
        """
        data = []
        for trial in session.trials:
            data.append(self.visit_trial_data(trial=trial))
        return data
    

    def visit_trial_data(self, trial):
        """
        Returns a tuple ( data_windows, label )
        """
        start = trial.timestamp[0]
        end = start + trial.parent_session.sample_time

        # Make it an index for the data rather than a time
        data_len = len(trial.parent_session.data[0])
        start = (int) ( (start / trial.parent_session.length) * data_len )
        end = (int) ( (end / trial.parent_session.length) * data_len )

        window = self.transform_window(trial.parent_session.data[:, start:end])
        label = self.transform_label(trial.label)

        return ( window, label )
    
    def transform_window(window):
        return window
    
    def transform_label(label):
        return label
    

class MakeTensorWindowsDataDecorator(MakeWindowsDataDecorator):
    """
    Returns a list of tuples ( data_window, label ) for each window in the format of a pytorch tensor
    """

    @staticmethod
    def transform_label(label):
        if not label:
            label = 0
        elif label:
            label = 1
        else:
            raise ValueError('MakeTensorWindowsDataVisitor object/ transform_label method: Label must be a boolean')
        
        return torch.tensor(label)
    
    @staticmethod
    def transform_window(window):
        return torch.from_numpy(window).float()

In [47]:
# TODO: Make visitor template and a filtering visitor object

# Test the MakeDataWindowsVisitor
data = DataObject(handler.get_data())

formatted_data = data.get_data(visitor=MakeTensorWindowsDataDecorator(), type='box')

print(len(formatted_data))
print(len(formatted_data[0]))
print(len(formatted_data[0][0]))
print(len(formatted_data[0][0][0]))
print(len(formatted_data[1][0][0]))

83
2
24
248
248


In [48]:
print(formatted_data[0][1])

tensor(0)
