In [8]:
import websocket
import json
import gym
from gym import spaces
import numpy as np
import numexpr as ne

In [2]:
# Open a raw WebSocket to the device with frame tracing enabled; the manual
# cells further down (ws.recv / ws.send) use this `ws` object directly.
websocket.enableTrace(True)
ws=websocket.WebSocket()
ws.connect("ws://10.42.0.231:80/")

--- request header ---
GET / HTTP/1.1
Upgrade: websocket
Host: 10.42.0.231
Origin: http://10.42.0.231
Sec-WebSocket-Key: PgnSi1ceFBvRzDuwSFnwQw==
Sec-WebSocket-Version: 13
Connection: Upgrade


-----------------------
--- response header ---


HTTP/1.1 101 Switching Protocols
Server: arduino-WebSocketsServer
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Version: 13
Sec-WebSocket-Accept: JdBbRrC9tQXyYe9p4BK7MGIuNZ8=
-----------------------


In [56]:
# Scratch: display a range object.
range(0, 100)

range(0, 100)

In [60]:
import numpy as np
# Scratch: a nested list of 2 rows x 3 columns has shape (rows, cols) == (2, 3).
np.array([[1,2,3],[1,2,3]]).shape

(2, 3)

In [12]:
import websocket
import json
import gym
from gym import spaces
import numpy as np
import re
import numexpr as ne
import time
# Output controls that make up the action space. Each group lists control
# names plus the inclusive integer range used for their Box spaces.
# The 'step' entries are not implemented for now.
out_dict={'leddelay':{'names':['leddelay'], 'value_range':{'min':10, 'max':10000, 'step':100}},
          # Eight RGB LEDs: lv1r, lv1g, lv1b, ..., lv8b.
          'ledcontrols':{'names':[f'lv{led}{color}' for led in range(1, 9) for color in 'rgb'],
                         'value_range':{'min':10, 'max':255, 'step':10}},
          'sound_wave_frequencies':{'names':['wave_1_freq','wave_2_freq'], 'value_range':{'min':1, 'max':30000, 'step':100}},
          # Names must match out_order (the device message layout), which uses *_max.
          'panner_phasor_frequencies':{'names':['panner_freq', 'phasor_1_freq', 'phasor_2_freq','phasor_1_min', 'phasor_2_min', 'phasor_1_max', 'phasor_2_max'],  'value_range':{'min':1, 'max':50, 'step':1}},
          'panner_div':{'names':['panner_div'], 'value_range':{'min':1, 'max':5, 'step':1}},
          'sound_wave_shapes':{'names':['wave_1_type', 'wave_2_type'], 'value_range':{'min':0, 'max':3, 'step':1}},
          'maxivolume':{'names':['maxivolume'], 'value_range':{'min':0, 'max':100, 'step':10}}
}
# Order in which control values are serialized into the comma-separated
# message sent to the device (see SFSystemCommunicator.update_audiovis_feedback).
out_order=[f'lv{led}{color}' for led in range(1, 9) for color in 'rgb'] + [
          'leddelay','wave_1_freq','wave_2_freq','panner_freq','panner_div','phasor_1_freq',
          'phasor_1_min','phasor_1_max','phasor_2_freq','phasor_2_min','phasor_2_max','maxivolume','wave_1_type','wave_2_type']
# Every control must appear exactly once in both structures; otherwise building a
# device message from an action_space sample fails with KeyError.
assert sorted(name for group in out_dict.values() for name in group['names']) == sorted(out_order)


class SFSystemCommunicator(gym.Env):
    """Gym environment wrapper around a WebSocket-connected "SF system" device.

    The device is driven with text commands over a WebSocket (see the ``send``
    calls below). It can stream input data ("start_data_transfer_from_ads")
    and accepts comma-separated output control values (the LED and sound
    controls described by ``out_dict`` / ``out_order``).

    The reward comes from a user-supplied formula string that refers to
    per-channel raw means, single FFT frequencies and frequency-bin magnitudes
    (see ``print_reward_formula_legend``). The formula is evaluated with numexpr.

    NOTE: ``step`` and ``reset`` are not implemented yet.
    """

    def __init__(self, out_dict=out_dict, out_order=out_order, n_input_channels=8,
                 channels_of_interest_inds=list(range(8)), n_timepoints_per_sample=100,
                 max_sfsystem_output=1023,
                 reward_formula_string='(fbin_05_4_ch0+freq_30_ch0)/fbin_12_30_ch0',
                 fbins=[(0.001, 0.5), (0.5, 4), (4, 8), (8, 12), (12, 30)], delay=10,
                 use_raw_in_os_def=False, use_freq_in_os_def=False, use_fbins_in_os_def=False,
                 device_address="ws://10.42.0.231:80/",
                 step_stim_length=10000):
        """Build the spaces, parse the reward formula and connect to the device.

        Args:
            out_dict: output-control groups, each ``{'names': [...],
                'value_range': {'min', 'max', 'step'}}``. Every name becomes
                one integer Box in the action space.
            out_order: order in which control values are serialized when sent
                to the device.
            n_input_channels: number of device input channels (stored only).
            channels_of_interest_inds: indices of the channels kept in
                observations. The list is copied, so the default is never shared.
            n_timepoints_per_sample: samples per observation window. It is also
                sent to the device as the data-transfer buffer size.
            max_sfsystem_output: upper bound of the raw-data observation Box.
            reward_formula_string: numexpr formula over tokens such as
                ``fbin_05_4_ch0``, ``freq_30_ch0`` or ``raw_ch0``.
            fbins: ``(low, high)`` frequency bins (low inclusive, high exclusive).
            delay: sampling interval. The sampling frequency is ``1000 / delay``
                (presumably milliseconds). It is also sent to the device.
            use_raw_in_os_def, use_freq_in_os_def, use_fbins_in_os_def: initial
                values of the record_*/do_* flags. Tokens in the reward formula
                can only switch these flags on.
            device_address: WebSocket URL of the device.
            step_stim_length: stored for the (not yet implemented) ``step``.
        """
        self.device_address = device_address
        self.step_stim_length = step_stim_length

        self.out_dict = out_dict
        self.out_order = out_order
        self.max_sfsystem_output = max_sfsystem_output
        self.n_timepoints_per_sample = n_timepoints_per_sample
        self.n_input_channels = n_input_channels
        self.reward_formula_string = reward_formula_string

        self.delay = delay

        self.record_raw = use_raw_in_os_def

        self.do_fft = use_freq_in_os_def
        self.record_fft = use_freq_in_os_def

        self.do_fbins = use_fbins_in_os_def
        self.record_fbins = use_fbins_in_os_def

        # Normalize bins to tuples so the bounds parsed from formula tokens
        # (always tuples) can be found with list.index even if lists were passed.
        self.fbins = [tuple(fbin) for fbin in fbins]
        self.n_fbins = len(self.fbins)

        self.channels_of_interest_inds = list(channels_of_interest_inds)
        self.n_channels_of_interest = len(self.channels_of_interest_inds)

        self.set_fft_params()

        self.timesleep_period = 0.1

        # Any data kind the reward formula refers to must be recorded.
        if 'raw' in self.reward_formula_string:
            self.record_raw = True
        if 'freq' in self.reward_formula_string:
            self.record_fft = True
        if 'fbin' in self.reward_formula_string:
            self.record_fbins = True

        # FFT-derived observations require the FFT to be computed.
        if self.record_fft:
            self.do_fft = True
        if self.record_fbins:
            self.do_fft = True
            self.do_fbins = True

        self.init_action_space()
        self.init_observation_space()
        self.set_value_dict_for_reward_function()
        self.connect()
        print(self.connection_status)
        self.set_delay_and_data_transfer_buffer_size()
        print('Delay and data transfer buffer size are set up.')

    def print_reward_formula_legend(self):
        """Print the syntax accepted by ``reward_formula_string``."""
        print('Reward formula can use the following operators: /, *, **, -, +')
        print('It can refer to channels using "ch" prefix followed by an index (starting with 0) e.g. ch0')
        print('It can refer to the mean raw signal of a channel e.g. raw_ch0')
        print('It can refer to values of frequency bins in specific channels e.g. fbin_10_50_ch0')
        print('The corresponding frequency bins must be present among fbins passed at the initialization step')
        print('fbins should be passed in the form of [(b1 min, b1max),...(bn min, bn max)]')
        print('It can refer to specific frequencies from fft e.g. freq_50_ch0')
        print('For fractional frequencies (e.g. 0.5) do not use points, just start with 0 (e.g. 05)')
        print('Some examples:')
        print('1. (freq_50_ch0+fbin_0_10_ch0)/(fbin_20_30_ch0)')
        print('2. freq_5_ch0/freq_10_ch0')
        print('3. fbin_05_5_ch0')

    @staticmethod
    def _parse_token_frequency(text):
        """Parse a frequency written in formula-token syntax.

        A leading '0' marks a fractional value ('05' -> 0.5, '0' -> 0.0).
        Anything else goes through float() ('30' -> 30.0).
        """
        if text.startswith('0'):
            return float('0.' + text[1:])
        return float(text)

    def set_value_dict_for_reward_function(self):
        """Parse ``reward_formula_string`` into per-token lookup metadata.

        Fills ``self.rewarddict`` (token -> None; values are filled in later by
        ``populate_rewarddict``) and ``self.tokendict`` (token -> dict with
        'datatype' and 'channelindex', plus 'freqdata' and 'closest_fft_ind'
        for freq tokens, or 'freqdata' and 'fbin_idx' for fbin tokens).

        Raises:
            ValueError: for an unknown token prefix, or for an fbin token whose
                bin is not among ``self.fbins``.
        """
        # Split on whitespace, parentheses and all arithmetic operators numexpr
        # accepts, so tokens separated by '-' or '%' are not glued together.
        ftokens = re.split(r'[\s()+\-*/%]+', self.reward_formula_string)
        self.rewarddict = {}
        self.tokendict = {}
        for token in ftokens:
            if 'ch' not in token:
                # Empty strings produced by the split, and numeric literals.
                continue
            subtokens = token.split('_')
            datatype = subtokens[0]
            if datatype not in ('raw', 'freq', 'fbin'):
                raise ValueError(f'Unknown reward formula token {token!r}: expected a raw_, freq_ or fbin_ prefix')
            tokeninfo = {'datatype': datatype}
            for subtoken in subtokens:
                if 'ch' in subtoken:
                    tokeninfo['channelindex'] = int(subtoken.split('h')[1])
            if datatype == 'freq':
                tfreq = self._parse_token_frequency(subtokens[1])
                closestind = np.argmin(np.abs(self.f_plot - tfreq))
                tokeninfo['freqdata'] = tfreq
                tokeninfo['closest_fft_ind'] = closestind
                print(f'Token {token}:')
                print(f'Closest fft frequency {self.f_plot[closestind]}')
            elif datatype == 'fbin':
                bounds = (self._parse_token_frequency(subtokens[1]),
                          self._parse_token_frequency(subtokens[2]))
                if bounds not in self.fbins:
                    raise ValueError(f'Frequency bin {bounds} from token {token!r} is not among fbins {self.fbins}')
                tokeninfo['freqdata'] = bounds
                tokeninfo['fbin_idx'] = self.fbins.index(bounds)
            self.rewarddict[token] = None
            self.tokendict[token] = tokeninfo

    def populate_rewarddict(self, observations):
        """Fill ``self.rewarddict`` with each token's value from ``observations``.

        ``observations`` is indexed like the observation space: 'raw_data' is
        (timepoints, channels), while 'fft' and 'fbins' are (channels, values).
        A raw token uses the mean of its channel column.
        """
        for token, tokendata in self.tokendict.items():
            datatype = tokendata['datatype']
            chidx = tokendata['channelindex']
            if datatype == 'raw':
                value = np.mean(observations['raw_data'][:, chidx])  # the mean is a design choice; may change
            elif datatype == 'freq':
                value = observations['fft'][chidx][tokendata['closest_fft_ind']]
            elif datatype == 'fbin':
                value = observations['fbins'][chidx][tokendata['fbin_idx']]
            else:
                # Never reuse the previous token's value for an unknown type.
                raise ValueError(f'Unsupported token datatype {datatype!r} for token {token!r}')
            self.rewarddict[token] = value

    def get_reward(self, observations=None, toreturn=False):
        """Evaluate the reward formula and store the result in ``self.reward``.

        Args:
            observations: observation mapping. If None, one is sampled from
                the observation space.
            toreturn: if true, also return the reward.
        """
        if observations is None:
            # No observations given externally: sample from the observation space.
            observations = self.observation_space.sample()
        self.populate_rewarddict(observations)
        self.reward = ne.evaluate(self.reward_formula_string, local_dict=self.rewarddict)
        if toreturn:
            return self.reward

    def init_action_space(self):
        """Create one integer Box of shape (1,) per control name in ``out_dict``."""
        self.action_space = spaces.Dict({})
        for group in self.out_dict.values():
            value_range = group['value_range']
            for name in group['names']:
                self.action_space.spaces[name] = spaces.Box(low=value_range['min'], high=value_range['max'], shape=(1,), dtype=int)

    def init_observation_space(self):
        """Create the observation Dict: always 'raw_data', plus 'fft'/'fbins' when recorded."""
        self.observation_space = spaces.Dict({})
        # n timepoints per sample rows, one column per channel of interest
        self.observation_space.spaces['raw_data'] = spaces.Box(low=0, high=self.max_sfsystem_output, shape=(self.n_timepoints_per_sample, self.n_channels_of_interest), dtype=int)
        if self.record_fft:
            self.observation_space.spaces['fft'] = spaces.Box(low=0.0, high=1.0, shape=(self.n_channels_of_interest, self.n_fft_values))
        if self.record_fbins:
            self.observation_space.spaces['fbins'] = spaces.Box(low=0.0, high=1.0, shape=(self.n_channels_of_interest, self.n_fbins))

    def set_fft_params(self):
        """Derive sampling and FFT-frequency parameters from ``delay`` and the window length."""
        self.sampling_frequency = int(1000/self.delay)
        self.max_possible_fft_frequency = self.sampling_frequency/2
        # Duration of one window, in the units of delay.
        self.sampling_period = int(self.delay*self.n_timepoints_per_sample)
        self.timesteps = np.linspace(0, (self.n_timepoints_per_sample-1)*self.delay, self.n_timepoints_per_sample)
        self.fstep = self.sampling_frequency/self.n_timepoints_per_sample
        self.f = np.linspace(0, (self.n_timepoints_per_sample-1)*self.fstep, self.n_timepoints_per_sample)
        # Non-negative half of the spectrum (up to Nyquist).
        self.f_plot = self.f[0:int(self.n_timepoints_per_sample/2 + 1)]
        self.n_fft_values = len(self.f_plot)

    def connect(self):
        """Open the WebSocket and store the device's first message as ``connection_status``."""
        websocket.enableTrace(True)
        self.ws = websocket.WebSocket()
        self.ws.connect(self.device_address)
        self.connection_status = self.ws.recv()

    def get_fft_fromsignal(self, raw_singlech):
        """Return the one-sided magnitude spectrum (2*|FFT|/N) of one channel.

        Only the first N//2+1 bins (matching ``f_plot``) are kept. Note that
        the DC bin is doubled as well.
        """
        X = np.fft.fft(raw_singlech)
        X_mag = np.abs(X)/self.n_timepoints_per_sample
        X_mag_plot = 2*X_mag[0:int(self.n_timepoints_per_sample/2 + 1)]
        return X_mag_plot

    def get_fft_allchannels(self, raw_data=None):
        """Return a (channels, n_fft_values) array of spectra for each column of ``raw_data``.

        If ``raw_data`` is None, it is sampled from the observation space.
        """
        # `is None`: `ndarray == None` is elementwise and cannot be used in `if`.
        if raw_data is None:
            raw_data = self.observation_space.sample()['raw_data']
        return np.array([self.get_fft_fromsignal(raw_data[:, chindex])
                         for chindex in range(raw_data.shape[1])])

    def get_bin_values_from_signal(self, fft_signlech):
        """Return the mean magnitude inside each ``(low, high)`` bin of ``self.fbins``.

        A bin that contains no ``f_plot`` frequency (narrower than the FFT
        resolution) yields NaN.
        """
        fpl = np.asarray(self.f_plot)
        xmp = np.asarray(fft_signlech)
        magnitudes = []
        for low, high in self.fbins:
            mask = (fpl >= low) & (fpl < high)
            # Explicit NaN avoids np.mean's empty-slice RuntimeWarning.
            magnitudes.append(np.abs(xmp[mask]).mean() if mask.any() else np.nan)  # mean could be another function
        return np.array(magnitudes)

    def get_bin_values_allchannels(self, fft=None):
        """Return a (channels, n_fbins) array of bin magnitudes for each row of ``fft``.

        If ``fft`` is None, it is sampled from the observation space.
        """
        if fft is None:
            fft = self.observation_space.sample()['fft']
        return np.array([self.get_bin_values_from_signal(fft[chindex, :])
                         for chindex in range(fft.shape[0])])

    def set_delay_and_data_transfer_buffer_size(self):
        """Run the handshake that sends ``delay`` and ``n_timepoints_per_sample`` to the device.

        The exchange is retried until the device confirms it. Transient errors
        are reported and retried. A closed connection is re-raised, because
        retrying it would loop forever.
        """
        self.ws.send("set_delay_and_data_transfer_buffer_size")
        time.sleep(self.timesleep_period)
        while True:
            try:
                device_msg = self.ws.recv()
                if device_msg == "Awaiting delay and data transfer buffer size in shape with space separator":
                    self.ws.send(f'{self.delay},{self.n_timepoints_per_sample}')
                time.sleep(self.timesleep_period)
                device_msg = self.ws.recv()
                if device_msg == "Delay and data transfer buffer size set up":
                    return
            except websocket.WebSocketConnectionClosedException:
                raise
            except Exception as exc:  # not bare: KeyboardInterrupt must still stop the loop
                print(f'Retrying delay/buffer setup after error: {exc!r}')

    def start_data_transfer_from_device(self):
        self.ws.send("start_data_transfer_from_ads")

    def stop_data_transfer_from_device(self):
        self.ws.send("stop_data_transfer_from_ads")

    def stop_audiovis_feedback(self):
        self.ws.send("stop_led_cycle")

    @staticmethod
    def _format_control_value(value):
        """Render one control value as text for the device message.

        Accepts plain numbers or size-1 arrays, such as the ``array([251])``
        values produced by ``action_space.sample()``.
        """
        arr = np.asarray(value)
        if arr.size != 1:
            raise ValueError(f'Expected a single control value, got {value!r}')
        return str(arr.item())

    def update_audiovis_feedback(self, update_dict):
        """Send a full set of output control values to the device.

        Args:
            update_dict: mapping from every name in ``self.out_order`` to its value.

        Raises:
            KeyError: if a control from ``out_order`` is missing.
            ValueError: if a value holds more than one element.
            Both checks run before anything is sent, so the device never gets
            the command without its data line.
        """
        missing = [controlnm for controlnm in self.out_order if controlnm not in update_dict]
        if missing:
            raise KeyError(f'update_dict is missing output controls: {missing}')
        # ','.join requires strings, so every value is rendered explicitly.
        msg = ','.join(self._format_control_value(update_dict[controlnm]) for controlnm in self.out_order)
        self.ws.send("receive_output_control_data")
        self.ws.send(msg)

    def step(self, action):
        """Apply ``action`` and return (observation, reward, done, info). Not implemented yet."""
        raise NotImplementedError('SFSystemCommunicator.step is not implemented yet')

    def reset(self):
        """Reset the environment and return the initial observation. Not implemented yet."""
        raise NotImplementedError('SFSystemCommunicator.reset is not implemented yet')

    def render(self, mode='human'):
        # Rendering is optional and not required for training the agent.
        pass

    def close(self):
        """Close the WebSocket connection to the device, if one was opened."""
        ws = getattr(self, 'ws', None)
        if ws is not None:
            ws.close()

In [13]:
# Constructing the environment connects to the device and runs the
# delay/buffer-size handshake, so the device must be reachable.
communicator=SFSystemCommunicator()

--- request header ---
GET / HTTP/1.1
Upgrade: websocket
Host: 10.42.0.231
Origin: http://10.42.0.231
Sec-WebSocket-Key: XiEGEltwAPJtwfwv6coiHQ==
Sec-WebSocket-Version: 13
Connection: Upgrade


-----------------------
--- response header ---
HTTP/1.1 101 Switching Protocols
Server: arduino-WebSocketsServer
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Version: 13
Sec-WebSocket-Accept: a3V+0982CG9bOcZa67m2QLdIhlQ=
-----------------------
++Rcv raw: b'\x89\x00'
++Rcv decoded: fin=1 opcode=9 data=b''
++Sent raw: b'\x8a\x80\xe9\x98\x19\xe4'
++Sent decoded: fin=1 opcode=10 data=b''
++Rcv raw: b'\x81\tConnected'
++Rcv decoded: fin=1 opcode=1 data=b'Connected'


Token freq_30_ch0:
Closest fft frequency 30.0
Connected


In [17]:
# Draw a random observation and check the number of raw-data columns
# (one per channel of interest).
sample=communicator.observation_space.sample()
sample['raw_data'].shape[1]

8

In [18]:
# Inspect the sampled raw-data array (timepoints x channels of interest).
sample['raw_data']

array([[ 633,  909,  253,  174,  711,  192,  812,  991],
       [ 355,  144,  985,   83,  882,  887,  214,   84],
       [ 614,  687,  346,  308,  653,  238,  810,  591],
       [ 883,  862,  456,   30,  968,    2,  301,  574],
       [ 616,  797,  605,  168,  826,   42,  486,  555],
       [ 254,  545,  324,  660,  796,  887,  792,  455],
       [ 949,  350,  269,  885,  498,  935,  178,  762],
       [ 376,  157,  471,  169,  565,  282,  429,  358],
       [ 942,  278,  955,  781,  112,  950,  288,  589],
       [ 126,  665,   77,  906,  229,  528,  848,   83],
       [ 275,  446,  179,  504,  738,  536,  249,  266],
       [ 798,  994,  226,  201,   13,  775,  836,  602],
       [ 670,  360,   60,  780,  707,  165,  380,  856],
       [ 243,  481,  710,   81,  894,  381,  544,  783],
       [ 462,  197,  161,  229,    1,  735,  381,  855],
       [ 808,  282,  185,  338,  630,  914,  589,  181],
       [ 640,   23,  489,  776,  799,  502,  452,  791],
       [  79,   86,  299,  779,

In [14]:
# Tokens parsed from the reward formula; values stay None until get_reward() fills them.
communicator.rewarddict

{'fbin_05_4_ch0': None, 'freq_30_ch0': None, 'fbin_12_30_ch0': None}

In [15]:
# Compute a reward from a randomly sampled observation (no observations passed).
# With toreturn=False the result is only stored in communicator.reward.
communicator.get_reward()

In [16]:
# Reward stored by the previous get_reward() call.
communicator.reward

array(4.1762214, dtype=float32)

In [51]:
# NOTE(review): the saved output below (execution count 51) lists tokens that differ
# from the current default formula's tokens; it is from an earlier session. Re-run to refresh.
communicator.rewarddict

{'fbin_0.5_5_ch0': None, 'freq_30_ch0': None, 'fbin_10_20_ch0': None}

In [12]:
# Scratch check: numexpr evaluates a formula string with variables looked up in
# local_dict, the same mechanism get_reward() uses.
s = "(freq_50+fbin_0_10)/(fbin_20_30)"

# Define the variables
variables = {
    'freq_50': 10,
    'fbin_0_10': 20,
    'fbin_20_30':40
}

# Evaluate the string as a mathematical expression
result = ne.evaluate(s, local_dict=variables)
result

array(0.75)

In [5]:
# Draw a random action: one shape-(1,) integer array per output control.
communicator.action_space.sample()

OrderedDict([('leddelay', array([2577])),
             ('lv1r', array([251])),
             ('lv1g', array([91])),
             ('lv1b', array([237])),
             ('lv2r', array([103])),
             ('lv2g', array([178])),
             ('lv2b', array([80])),
             ('lv3r', array([220])),
             ('lv3g', array([71])),
             ('lv3b', array([167])),
             ('lv4r', array([192])),
             ('lv4g', array([133])),
             ('lv4b', array([10])),
             ('lv5r', array([164])),
             ('lv5g', array([127])),
             ('lv5b', array([83])),
             ('lv6r', array([136])),
             ('lv6g', array([143])),
             ('lv6b', array([20])),
             ('lv7r', array([98])),
             ('lv7g', array([81])),
             ('lv7b', array([109])),
             ('lv8r', array([163])),
             ('lv8g', array([235])),
             ('lv8b', array([76])),
             ('wave_1_freq', array([22009])),
             ('wave_2_freq', arra

In [3]:
# Print the reward-formula syntax help.
# NOTE(review): the saved output below uses an older token syntax (e.g. fbin_10-50|ch0); re-run to refresh.
communicator.print_reward_formula_legend()

Reward formula can use the following operators: //, *, **, -, +
It can refer to channels using "ch" prefix followed by an index (starting with 0) e.g. ch0
It can refer to values of frequency bins in specific channels e.g. fbin_10-50|ch0
The corresponding freqency bins must be present among fbins passed at the initialization step
fbins should be passed in the form of [(b1 min, b1max),...(bn min, bn max)]
It can refer to specific frequencies from fft e.g. freq_50|ch0
Some examples:
1. (freq_50|ch0+fbin_0-10|ch0)/(fbin_20-30|ch0)
2. freq_5|ch0/freq_10|ch0
3. fbin_0.5-5|ch0


In [4]:
# First message received from the device right after connecting.
communicator.connection_status

'Connected'

In [None]:



# Manual output-control values for testing the device by hand
# (the same definitions are repeated in a later cell).
leddelay=100
ledvalues=[30, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 255, 0, 255, 255, 255, 0, 255, 255, 255, 0, 255, 0, 0]
sdict={'wave_1_freq':800, 'wave_2_freq':200, 'panner_freq':30, 'panner_div':2, 'phasor_1_freq':100, 'phasor_1_min':30, 'phasor_1_max':30, 
       'phasor_2_freq':30, 'phasor_2_min':30, 'phasor_2_max':30,
       'maxivolume':5,'wave_1_type':3,'wave_2_type':0}


In [5]:
# Read the next message from the raw socket opened in the connection cell at the top.
ws.recv()

++Rcv raw: b'\x89\x00'
++Rcv decoded: fin=1 opcode=9 data=b''
++Sent raw: b'\x8a\x80\x1bE\xf6f'
++Sent decoded: fin=1 opcode=10 data=b''
++Rcv raw: b'\x81\tConnected'
++Rcv decoded: fin=1 opcode=1 data=b'Connected'


'Connected'

In [48]:
# Same command that SFSystemCommunicator.update_audiovis_feedback sends before
# the comma-separated control values.
ws.send("receive_output_control_data")

++Sent raw: b'\x81\x9b]6\x89\xe2/S\xea\x874@\xec\xbd2C\xfd\x92(B\xd6\x812X\xfd\x902Z\xd6\x86<B\xe8'
++Sent decoded: fin=1 opcode=1 data=b'receive_output_control_data'


33

In [49]:
# Presumably the codes for wave_1_type / wave_2_type (range 0-3 in out_dict):
# 0. noise
# 1. sinewave
# 2. square
# 3. triangle


# Manual output-control values (same values as the earlier manual-controls cell).
leddelay=100
ledvalues=[30, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 255, 0, 255, 255, 255, 0, 255, 255, 255, 0, 255, 0, 0]
sdict={'wave_1_freq':800, 'wave_2_freq':200, 'panner_freq':30, 'panner_div':2, 'phasor_1_freq':100, 'phasor_1_min':30, 'phasor_1_max':30, 
       'phasor_2_freq':30, 'phasor_2_min':30, 'phasor_2_max':30,
       'maxivolume':5,'wave_1_type':3,'wave_2_type':0}

In [50]:
# Build the control message by hand: the LED values, the LED delay, then the
# sound controls in device order (the same order as the tail of out_order).
message = ','.join(str(value) for value in ledvalues) + ',' + ','.join(
    str(field)
    for field in [leddelay] + [
        sdict[key]
        for key in (
            'wave_1_freq', 'wave_2_freq', 'panner_freq', 'panner_div',
            'phasor_1_freq', 'phasor_1_min', 'phasor_1_max',
            'phasor_2_freq', 'phasor_2_min', 'phasor_2_max',
            'maxivolume', 'wave_1_type', 'wave_2_type',
        )
    ]
)

In [51]:
# Inspect the assembled control message.
message

'30,0,0,0,255,0,0,0,255,255,255,255,0,255,255,255,0,255,255,255,0,255,0,0,100,800,200,30,2,100,30,30,30,30,30,5,3,0'

In [52]:
# Field count; the saved output (38) equals len(out_order).
len(message.split(','))

38

In [53]:
# Send the manual control message over the raw socket.
# NOTE(review): nmessages (nsecs * 1000 / leddelay) is computed but unused; the loop
# sends the message only twice. TODO: confirm the intended number of sends.
nsecs=60
nmessages=int(nsecs*(1000/leddelay))
for m in range(2):
    ws.send(message)

++Sent raw: b'\x81\xf2~r\xdcqMB\xf0ARB\xf0AR@\xe9DRB\xf0ARB\xf0CKG\xf0CKG\xf0CKG\xf0CKG\xf0AR@\xe9DR@\xe9DR@\xe9DRB\xf0CKG\xf0CKG\xf0CKG\xf0AR@\xe9DRB\xf0ARC\xecARJ\xecAR@\xecARA\xec]L^\xedAN^\xefARA\xec]MB\xf0BN^\xefARG\xf0BRB'
++Sent decoded: fin=1 opcode=1 data=b'30,0,0,0,255,0,0,0,255,255,255,255,0,255,255,255,0,255,255,255,0,255,0,0,100,800,200,30,2,100,30,30,30,30,30,5,3,0'
++Sent raw: b"\x81\xf2'1U\x84\x14\x01y\xb4\x0b\x01y\xb4\x0b\x03`\xb1\x0b\x01y\xb4\x0b\x01y\xb6\x12\x04y\xb6\x12\x04y\xb6\x12\x04y\xb6\x12\x04y\xb4\x0b\x03`\xb1\x0b\x03`\xb1\x0b\x03`\xb1\x0b\x01y\xb6\x12\x04y\xb6\x12\x04y\xb6\x12\x04y\xb4\x0b\x03`\xb1\x0b\x01y\xb4\x0b\x00e\xb4\x0b\te\xb4\x0b\x03e\xb4\x0b\x02e\xa8\x15\x1dd\xb4\x17\x1df\xb4\x0b\x02e\xa8\x14\x01y\xb7\x17\x1df\xb4\x0b\x04y\xb7\x0b\x01"
++Sent decoded: fin=1 opcode=1 data=b'30,0,0,0,255,0,0,0,255,255,255,255,0,255,255,255,0,255,255,255,0,255,0,0,100,800,200,30,2,100,30,30,30,30,30,5,3,0'


In [54]:
# Same command as SFSystemCommunicator.stop_audiovis_feedback().
ws.send('stop_led_cycle')

++Sent raw: b'\x81\x8e\x8c\xd2\x1c\xfd\xff\xa6s\x8d\xd3\xbey\x99\xd3\xb1e\x9e\xe0\xb7'
++Sent decoded: fin=1 opcode=1 data=b'stop_led_cycle'


20

In [263]:
# Scratch: count the comma-separated %d fields in a format string.
len("%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d".split(','))

25