# Setup: imports and configuration

In [197]:
import os
import time
from os import listdir

import numpy as np
import pandas as pd
import serial
from matplotlib import pyplot as plt
from tqdm import tqdm
%matplotlib widget

In [198]:
# from tensorflow.keras.utils import to_categorical
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler, LabelEncoder, PolynomialFeatures
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import learning_curve, StratifiedKFold, LearningCurveDisplay
from sklearn.metrics import classification_report, f1_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.pipeline import Pipeline, FunctionTransformer
from sklearn.compose import make_column_transformer

from IPython.display import clear_output

In [199]:
# Recordings are ./data/<YYYY-MM-DD_HH-MM-SS>.palm, i.e. 24-character file names.
# listdir() returns entries in arbitrary, OS-dependent order, so sort before indexing:
# otherwise palms_list[10] may select a different recording on another machine.
palms_list = [f'./data/{s}' for s in sorted(listdir('./data/')) if len(s) == 24]
palm_file = palms_list[10]

model_name = None # LogisticRegression
# NOTE: class 3 (Thumb) is up-weighted x10; in the reports this trades Thumb precision for recall.
model = LogisticRegression(max_iter=5000,class_weight={0: 1, 1: 1, 2: 1, 3: 10, 4: 1, 5: 1})
# comet.ml credentials: read from the environment instead of hardcoding them in the notebook
API_KEY = os.environ.get('COMET_API_KEY', '')
PROJECT_NAME = os.environ.get('COMET_PROJECT_NAME', '')
WORKSPACE = os.environ.get('COMET_WORKSPACE', '')
# do list
DO_DRAW_PLOTS = False # Draw plots
DO_HYPEROPT = False # Hyperparameter search with hyperopt
DO_POLYNOMIAL_FEATURES = False # Polynomial features
DO_MMSCALER = True # MinMax scaling of the data
DO_PLOT_LEARNING_CURVE = False # Plot the learning curve
DO_OFFLINE_INFERENCE = False # Offline inference check
DO_ONLINE_INFERENCE = False # Online inference check
DO_LOG = False # Log the experiment to comet.ml

In [246]:
# Summary of the latest results. print_report and the trained model are defined further
# down, so on a fresh kernel this cell only prints a hint instead of silently doing nothing.
try:
    print_report()
except NameError:
    print('Report unavailable: run the notebook once so that the model and data are defined.')

              precision    recall  f1-score   support

     Neutral       0.98      0.96      0.97     10615
        Open       0.98      0.94      0.96       987
      Pistol       0.95      0.91      0.93       933
       Thumb       0.71      0.99      0.82      1011
          OK       0.97      0.85      0.91      1057
        Grab       0.99      1.00      0.99      1076

    accuracy                           0.95     15679
   macro avg       0.93      0.94      0.93     15679
weighted avg       0.96      0.95      0.96     15679

              precision    recall  f1-score   support

     Neutral       0.97      0.92      0.94      2593
        Open       0.96      0.90      0.93       251
      Pistol       0.88      0.96      0.92       277
       Thumb       0.65      0.84      0.73       262
          OK       0.71      0.95      0.81       239
        Grab       0.97      0.93      0.95       267

    accuracy                           0.92      3889
   macro avg       0.86

# IO utils


In [201]:
def read_omg_csv(path_palm_data: str, 
                 n_omg_channels: int, 
                 n_acc_channels: int = 0, 
                 n_gyr_channels: int = 0, 
                 n_mag_channels: int = 0, 
                 n_enc_channels: int = 0,
                 button_ch: bool = True, 
                 sync_ch: bool = True, 
                 timestamp_ch: bool = True) -> pd.DataFrame:
    
    '''
    Reads CSV data for OMG data
    NB: data must be separated by " " separator; the first and the last line of the
    file are skipped.

        Parameters:
                path_palm_data  (str): path to csv data file
                n_omg_channels  (int): Number of OMG channels
                n_acc_channels  (int): Number of Accelerometer channels, default = 0
                n_gyr_channels  (int): Number of Gyroscope channels, default = 0
                n_mag_channels  (int): Number of Magnetometer channels, default = 0
                n_enc_channels  (int): Number of Encoder channels, default = 0
                button_ch      (bool): If button channel is present, default = True
                sync_ch        (bool): If synchronization channel is present, default = True
                timestamp_ch   (bool): If timestamp channel is present, default = True

        Returns:
                df_raw (pd.DataFrame): Parsed pandas Dataframe with OMG data; columns are
                '0'..'N-1' (OMG), then ACC*, GYR*, MAG*, ENC*, then BUTTON, SYNC, ts

        Raises:
                ValueError: if the number of columns in the file does not match the
                requested channel layout
    '''
    
    df_raw = pd.read_csv(path_palm_data, sep=' ', 
                         header=None, 
                         skipfooter=1, 
                         skiprows=1, 
                         engine='python')

    columns = [str(i) for i in range(n_omg_channels)]
    sensor_groups = (('ACC', n_acc_channels), ('GYR', n_gyr_channels),
                     ('MAG', n_mag_channels), ('ENC', n_enc_channels))
    for label, label_count in sensor_groups:
        columns.extend(f'{label}{i}' for i in range(label_count))

    # Optional service channels, always in this order at the end of a row
    for name, present in (('BUTTON', button_ch), ('SYNC', sync_ch), ('ts', timestamp_ch)):
        if present:
            columns.append(name)

    if len(columns) != df_raw.shape[1]:
        raise ValueError(
            f'{path_palm_data}: file has {df_raw.shape[1]} columns, '
            f'but the channel layout describes {len(columns)}'
        )
    df_raw.columns = columns
    
    return df_raw

In [202]:
def print_report():
    '''Print per-gesture classification reports of the global `model` for the
    train split and then the test split, each under a "=== Train/Test ===" banner.'''
    # Predict both splits up front (train first), then print the two reports.
    splits = {
        'Train': (y_train, model.predict(X_train)),
        'Test': (y_test, model.predict(X_test)),
    }
    for title, (y_true, y_hat) in splits.items():
        print('='*25, title, '='*25)
        print(classification_report(y_true, y_hat, target_names=GESTURES))

In [203]:
# Confusion-matrix plot for the test split
def draw_confusion_matrix():
    '''Plot the confusion matrix of the global `model` on (X_test, y_test).

    Axis labels use the gesture names from GESTURES (indexed by encoded class id),
    matching the classification reports, instead of raw integer ids.

    Returns:
        ConfusionMatrixDisplay: the plotted display, for further customisation.
    '''
    predictions = model.predict(X_test)
    cm = confusion_matrix(y_test, predictions, labels=model.classes_)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=[GESTURES[label] for label in model.classes_],
    )
    disp.plot()
    return disp

# Abbreviations
OMG — optomyography  
ACC — accelerometer  
GYR — gyroscope  
ENC — encoders (fingers of the prosthesis or of the glove)  
model — the trained gesture classifier  

# Data

In [204]:
# Raw recording layout: 50 OMG, 3 ACC, 3 GYR, 6 encoder channels, no magnetometer
channel_layout = dict(n_omg_channels=50,
                      n_acc_channels=3,
                      n_gyr_channels=3,
                      n_enc_channels=6,
                      n_mag_channels=0)
gestures = read_omg_csv(palm_file, **channel_layout)

print(gestures.shape)
gestures.head()

(19568, 65)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,GYR2,ENC0,ENC1,ENC2,ENC3,ENC4,ENC5,BUTTON,SYNC,ts
0,12,8,5,6,2,4,5,7,14,8,...,-13,0,0,0,0,0,0,0,0,555777
1,13,8,4,5,3,0,5,4,9,10,...,-16,0,0,0,0,0,0,0,0,555810
2,12,5,9,5,0,0,6,5,10,9,...,-14,0,0,0,0,0,0,0,0,555843
3,10,7,6,4,1,0,4,6,7,8,...,-12,0,0,0,0,0,0,0,0,555876
4,12,6,6,7,3,2,6,7,9,12,...,-12,0,0,0,0,0,0,0,0,555909


In [205]:
# [X features] '0', ..., '49' - OMG sensor channels
#              'ACC0', 'ACC1', 'ACC2' - accelerometer (may be used as model features)
#              'GYR0', 'GYR1', 'GYR2' - gyroscope     (may be used as model features)
# 'BUTTON' - unused
# 'SYNC' - synchronisation of the data with the protocol (epoch number)
# 'ts' - timestamp
gestures.keys()

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
       '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24',
       '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36',
       '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48',
       '49', 'ACC0', 'ACC1', 'ACC2', 'GYR0', 'GYR1', 'GYR2', 'ENC0', 'ENC1',
       'ENC2', 'ENC3', 'ENC4', 'ENC5', 'BUTTON', 'SYNC', 'ts'],
      dtype='object')

In [206]:
# Feature-name lists per sensor group
OMG_CH = list(map(str, range(50)))

ACC_CH = ['ACC0', 'ACC1', 'ACC2']
GYR_CH = ['GYR0', 'GYR1', 'GYR2']

ENC_CH = ['ENC0', 'ENC1', 'ENC2', 'ENC3', 'ENC4', 'ENC5']
BUTTON_SYNC_TS_CH = ['BUTTON', 'SYNC', 'ts']
# Sanity check: together the groups cover every column of the recording
assert sum(len(group) for group in (OMG_CH, ACC_CH, GYR_CH, ENC_CH, BUTTON_SYNC_TS_CH)) == gestures.shape[-1]

# Sensor channels used as model inputs
ALL_CH = [*OMG_CH, *ACC_CH, *GYR_CH]

print(f"OMG_CH: {list(OMG_CH)}")
print(f"ACC_CH: {list(ACC_CH)}")
print(f"GYR_CH: {list(GYR_CH)}")
print(f"ENC_CH: {list(ENC_CH)}")
print(f"BUTTON_SYNC_TS_CH: {list(BUTTON_SYNC_TS_CH)}")

OMG_CH: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49']
ACC_CH: ['ACC0', 'ACC1', 'ACC2']
GYR_CH: ['GYR0', 'GYR1', 'GYR2']
ENC_CH: ['ENC0', 'ENC1', 'ENC2', 'ENC3', 'ENC4', 'ENC5']
BUTTON_SYNC_TS_CH: ['BUTTON', 'SYNC', 'ts']


### OMG_CH

In [207]:
if DO_DRAW_PLOTS:
    # Raw OMG sensor readings over time
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(gestures[OMG_CH].values)
    ax.set_title('OMG')
    ax.set_xlabel('Timesteps')
    fig.tight_layout()

### ACC GYR

In [208]:
if DO_DRAW_PLOTS:
    # Accelerometer readings over time
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(gestures[ACC_CH].values)
    ax.set_title('ACC')
    ax.set_xlabel('Timesteps')
    fig.tight_layout()

In [209]:
if DO_DRAW_PLOTS:
    # Gyroscope readings over time
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(gestures[GYR_CH].values)
    ax.set_title('GYR')
    ax.set_xlabel('Timesteps')
    fig.tight_layout()

### Protocol

In [210]:
# Целевые признаки сгибаний и разгибаний пальцев руки
gestures_protocol = pd.read_csv(f'{palm_file}.protocol.csv', index_col=0)
gestures_protocol

Unnamed: 0_level_0,Thumb,Index,Middle,Ring,Pinky,Thumb_stretch,Index_stretch,Middle_stretch,Ring_stretch,Pinky_stretch,Pronation
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...
176,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
177,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0
178,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0
179,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


In [211]:
# Encode the 10 finger flags (bent / stretched per finger) into a single target label.
# 'Pronation' is intentionally left out; it is extracted separately as y_pronation.
FINGER_COLS = [
    "Thumb","Index","Middle","Ring","Pinky",
    'Thumb_stretch','Index_stretch','Middle_stretch','Ring_stretch','Pinky_stretch',
]
# One hashable key per epoch: the tuple of its finger flags rendered as a string
finger_state = gestures_protocol[FINGER_COLS].apply(lambda row: str(tuple(row)), axis=1)

le = LabelEncoder()
gestures_protocol['gesture'] = le.fit_transform(finger_state)
# Only 6 unique combinations of finger bending/stretching occur
display(np.c_[le.transform(le.classes_), le.classes_])

array([[0, '(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)'],
       [1, '(0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0)'],
       [2, '(0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0)'],
       [3, '(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)'],
       [4, '(1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)'],
       [5, '(1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0)']],
      dtype=object)

In [212]:
# Human-readable names for encoded labels 0..5, read off the finger-flag table above:
# 0 all relaxed, 1 all stretched, 2 middle/ring/pinky bent, 3 thumb bent,
# 4 thumb+index bent, 5 all fingers bent.
GESTURES = ['Neutral', 'Open', 'Pistol', 'Thumb', 'OK', 'Grab']
assert len(GESTURES) == len(le.classes_), 'GESTURES must name every encoded gesture label'
GESTURES

['Neutral', 'Open', 'Pistol', 'Thumb', 'OK', 'Grab']

In [213]:
# Encoded gesture of every epoch (one gesture per epoch)
gestures_protocol.gesture

epoch
0      0
1      0
2      3
3      5
4      1
      ..
176    3
177    5
178    1
179    4
180    2
Name: gesture, Length: 181, dtype: int32

In [214]:
# Per-sample target: SYNC holds the protocol epoch of each timestep, so one vectorized
# label lookup expands the per-epoch gesture to every sample (KeyError on unknown epochs).
y_cmd = gestures_protocol['gesture'].loc[gestures['SYNC']].to_numpy()
print(y_cmd.shape)
y_cmd

(19568,)


array([0, 0, 0, ..., 0, 0, 0])

In [215]:
# Per-sample hand pronation flag, expanded from the per-epoch protocol via SYNC (vectorized lookup)
y_pronation = gestures_protocol['Pronation'].loc[gestures['SYNC']].to_numpy()
print(y_pronation.shape)
y_pronation

(19568,)


array([0., 0., 0., ..., 0., 0., 0.])

# Data preprocessing

In [216]:
# Keep only sensor channels (OMG, ACC, GYR); encoders and service columns are dropped
gestures = gestures.loc[:, ALL_CH]

## OMG_CH

### Features generation

In [217]:
#pd.concat([gestures[OMG_CH].iloc[:,[0]].shift(i).rename(lambda x:str(x)+f'_shift_{i}',axis=1) for i in range(0,4)],axis=1)

In [218]:
def generate_features(df, columns, windows=(5, 15, 25), lags=range(1, 10)):
    '''Build per-channel time-series features for `columns` of `df`.

    For every column, in this order: lags (`{col}_lag_{k}`), rolling mean
    (`{col}_moving_average_w{w}`), first difference (`{col}_derivative`), rolling
    std / min / max (`{col}_moving_std_w{w}`, `..._moving_min_w{w}`, `..._moving_max_w{w}`)
    and `{col}_lag_diff` = lag(lags[0]) - lag(lags[-1]).
    Leading NaNs from shifting/rolling are back-filled, then forward-filled, and
    +/-inf is replaced by 0.

    NOTE(review): bfill copies values from later rows into the first rows, i.e. a small
    look-ahead at the very start of the series.

    Parameters:
        df (pd.DataFrame): source data
        columns (list[str]): columns to expand
        windows (iterable[int]): rolling window sizes
        lags (sequence[int]): shift amounts
    Returns:
        pd.DataFrame: new feature columns, indexed like `df`
    '''
    lags = list(lags)
    methods = {
        'moving_average': {'func': lambda x, window: x.rolling(window).mean(), 'window': True},
        'derivative': {'func': lambda x: x.diff(), 'window': False},
        'moving_std': {'func': lambda x, window: x.rolling(window).std(), 'window': True},
        'moving_min': {'func': lambda x, window: x.rolling(window).min(), 'window': True},
        'moving_max': {'func': lambda x, window: x.rolling(window).max(), 'window': True},
        # Difference between the shortest and the longest lag of the same channel
        'lag_diff': {'func': lambda x: x.shift(lags[0]) - x.shift(lags[-1]), 'window': False},
        #'pct_change': {'func': lambda x: x.pct_change()},
    }

    # Collect all series first and concatenate once: concatenating inside the loop
    # re-copies the growing frame on every step (quadratic in the number of features).
    parts = []
    for column in tqdm(columns):
        series = df[column]
        parts.extend(series.shift(lag).rename(f'{column}_lag_{lag}') for lag in lags)

        for method, method_info in methods.items():
            if method == 'lag_diff' and not lags:
                continue  # undefined without lags
            if method_info['window']:
                parts.extend(
                    method_info['func'](series, window).rename(f'{column}_{method}_w{window}')
                    for window in windows
                )
            else:
                parts.append(method_info['func'](series).rename(f'{column}_{method}'))

    features_df = pd.concat(parts, axis=1) if parts else pd.DataFrame(index=df.index)
    features_df = features_df.bfill().ffill()
    features_df = features_df.replace([np.inf, -np.inf], 0)

    return features_df

# Append the engineered OMG features to the sensor channels
new_features = generate_features(gestures, columns=OMG_CH)
gestures = pd.concat([gestures, new_features], axis='columns')
gestures

100%|██████████| 50/50 [00:17<00:00,  2.79it/s]


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,49_moving_std_w5,49_moving_std_w15,49_moving_std_w25,49_moving_min_w5,49_moving_min_w15,49_moving_min_w25,49_moving_max_w5,49_moving_max_w15,49_moving_max_w25,49_lag_diff
0,12,8,5,6,2,4,5,7,14,8,...,1.303840,2.614975,2.291288,27.0,20.0,20.0,30.0,30.0,30.0,0.0
1,13,8,4,5,3,0,5,4,9,10,...,1.303840,2.614975,2.291288,27.0,20.0,20.0,30.0,30.0,30.0,0.0
2,12,5,9,5,0,0,6,5,10,9,...,1.303840,2.614975,2.291288,27.0,20.0,20.0,30.0,30.0,30.0,0.0
3,10,7,6,4,1,0,4,6,7,8,...,1.303840,2.614975,2.291288,27.0,20.0,20.0,30.0,30.0,30.0,0.0
4,12,6,6,7,3,2,6,7,9,12,...,1.303840,2.614975,2.291288,27.0,20.0,20.0,30.0,30.0,30.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19563,3,6,7,4,2,0,6,3,13,8,...,3.049590,2.414243,2.244994,24.0,22.0,22.0,31.0,31.0,31.0,2.0
19564,26,6,5,3,2,1,4,4,9,7,...,2.738613,2.386470,2.244994,24.0,22.0,22.0,31.0,31.0,31.0,0.0
19565,14,7,6,4,2,0,6,5,13,8,...,1.816590,2.350279,2.244994,27.0,22.0,22.0,31.0,31.0,31.0,5.0
19566,30,6,6,4,2,0,5,5,10,6,...,1.303840,2.350279,2.158703,27.0,22.0,22.0,30.0,31.0,31.0,3.0


### Features dropout

In [219]:
def dropout_features(df, target='command', weak_edge=0.02, strong_edge=0.95, y=None):
    '''Select features to drop using absolute Pearson correlation.

    Two rules:
      * strong: a column whose |corr| with any column to its left exceeds `strong_edge`
        is redundant and dropped (the left-most member of a correlated group is kept);
      * weak: a column whose |corr| with the target is <= `weak_edge` is dropped.

    Parameters:
        df          (pd.DataFrame): candidate features; NOT modified
        target      (str): temporary name of the target column in the correlation matrix
        weak_edge   (float): threshold for the weak-target-correlation rule
        strong_edge (float): threshold for the feature-feature correlation rule
        y           (array-like): target aligned with df rows; defaults to the global y_cmd
    Returns:
        list[str]: sorted names of columns to drop (never contains `target`)
    '''
    if y is None:
        y = y_cmd  # backward-compatible default: notebook-level label vector
    # Work on a copy with the target attached instead of mutating the caller's frame
    corr_df = df.assign(**{target: np.asarray(y)}).corr().abs()
    
    # Strongly correlated features (upper triangle: compare each column with those before it)
    upper = corr_df.where(np.triu(np.ones(corr_df.shape), k=1).astype(bool))
    drop_strong = [column for column in upper.columns
                   if column != target and (upper[column] > strong_edge).any()]
    
    # Features weakly correlated with the target (NaN correlations, e.g. constant columns, are kept)
    target_corr = corr_df[target]
    drop_weak = list(target_corr[target_corr <= weak_edge].index)
    
    return sorted(set(drop_weak + drop_strong))

target = 'command'
to_drop = dropout_features(gestures, target=target, y=y_cmd)
print(f'Droped features: {len(to_drop)}')
gestures = gestures.drop(columns=to_drop)
print(f'Remained features: {len(gestures.columns)}')

Droped features: 916
Remained features: 290


In [220]:
# Column groups of the selected feature set (recomputed: feature selection dropped columns).
# Plain comprehensions keep the frame's column order, like DataFrame.filter did.
ALL_COLS = list(gestures.columns)
OMG_CH = [c for c in ALL_COLS if c.isdigit()]             # raw OMG channels that survived
OMG_DERIV = [c for c in ALL_COLS if 'derivative' in c]    # first differences of OMG channels
ACC_CH = [c for c in ALL_COLS if 'ACC' in c]
GYR_CH = [c for c in ALL_COLS if 'GYR' in c]
PCT_CH = [c for c in ALL_COLS if 'pct_change' in c]       # empty unless pct_change features exist

## Protocol

In [222]:
# Shift the target to the jump in the derivatives: at every label change, look up to 35
# samples ahead for the row with the largest summed |derivative| and keep the previous
# label until that row (presumably because the recorded movement lags the protocol
# label change — TODO confirm).
# NOTE(review): mutates y_cmd in place, so re-running this cell shifts already-shifted
# labels; y_pronation is not shifted. idxmax() returns an index label, which equals the
# row position only because `gestures` keeps its default RangeIndex.
id_max = 0
cur_gesture = 0
for i in range(y_cmd.shape[0]):
    if i < id_max: # skip samples already relabelled up to id_max
        continue
    prev_gesture = cur_gesture # previous gesture
    cur_gesture = y_cmd[i] # current gesture
    if cur_gesture != prev_gesture: # gesture changed: search ahead for the largest jump
        id_max = gestures[OMG_DERIV][i:i+35].abs().sum(axis=1).idxmax() # |derivatives| -> sum over channels -> index of the maximum
        y_cmd[i:id_max] = prev_gesture # relabel samples before id_max with the previous gesture

In [223]:
if DO_DRAW_PLOTS: # After the label shift
    # Sensor channels, their derivatives and the shifted command on a shared time axis
    fig, (ax_omg, ax_deriv, ax_cmd) = plt.subplots(3, 1, sharex=True, figsize=(10, 5))

    ax_omg.plot(gestures[OMG_CH].values)
    ax_omg.grid()
    ax_omg.set_title('OMG_CH')

    ax_deriv.plot(gestures[OMG_DERIV].values)
    ax_deriv.grid()
    ax_deriv.set_title('OMG_DERIV')

    ax_cmd.plot(y_cmd)
    ax_cmd.grid()
    ax_cmd.set_yticks(np.arange(len(GESTURES)))
    ax_cmd.set_yticklabels(GESTURES)
    ax_cmd.set_title('Command')
    ax_cmd.set_xlabel('Timesteps')

    fig.suptitle('OMG and Protocol')
    fig.tight_layout()

# Train-test split

In [224]:
# Per-recording metadata, including the train/test boundary (last_train_idx)
meta_path = './data/meta_information.csv'
df_meta = pd.read_csv(meta_path, index_col=0)
df_meta.head()

Unnamed: 0,montage,pilote_id,last_train_idx,len(train),len(test)
0,2023-05-15_16-16-08.palm,1,23337,23337,5810
1,2023-05-15_17-12-24.palm,1,23336,23336,5803
2,2023-06-05_16-12-38.palm,1,17939,17939,4431
3,2023-06-05_17-53-01.palm,1,17771,17771,4435
4,2023-06-20_14-43-11.palm,1,17936,17936,4441


In [225]:
# Узнаем последний индекс обучающей выборки для нужного файла
last_train_idx = df_meta[df_meta['montage'] == palm_file.split('/')[-1]].to_dict(orient='records')[0]['last_train_idx']
last_train_idx

15679

In [226]:
# Разделяем данные на обучающую и тестовую по индексу полученному с /data/meta_information.csv
X_train = gestures.values[:last_train_idx]
y_train = y_cmd[:last_train_idx]

X_test = gestures.values[last_train_idx:]
y_test = y_cmd[last_train_idx:]

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

(15679, 290) (15679,)
(3889, 290) (3889,)


# Modeling

In [227]:
if DO_HYPEROPT:
    # Hyperopt (optional dependency, imported only when the search is enabled)
    from hyperopt import hp, fmin, tpe, Trials
    from hyperopt.pyll.base import scope
    if model_name is None:
        raise ValueError('DO_HYPEROPT requires model_name to be an estimator class, e.g. LogisticRegression')
    # Hyperparameter search space (uncomment the entries that apply to model_name)
    solver = ['lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga']
    criterion = ['gini','entropy','log_loss']
    space={
           #'criterion': hp.choice('criterion',criterion),
           #'n_estimators': scope.int(hp.quniform('n_estimators', 1, 350,1)),
           #'max_depth': scope.int(hp.quniform('max_depth', 1, 35,1)),
           #'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 15,1)),
           #'C': hp.uniform('C', 0.01, 10),
           #'solver': hp.choice('solver',solver),
           #'max_iter': scope.int(hp.quniform('max_iter', 100, 10000, 100))
          }
    # Objective to minimise
    def hyperopt_rf(params, X=X_train, y=y_train):
        '''Return minus the micro-F1 of model_name(**params) on a chronological hold-out.

        The model is fit on the first 75% of (X, y) and scored on the remaining 25%,
        without shuffling, because rows are consecutive timesteps.
        '''
        candidate = model_name(**params) # model with the parameters under test
        split = int(len(y) * 0.75)
        candidate.fit(X[:split], y[:split])
        y_pred = candidate.predict(X[split:])
        score = f1_score(y[split:], y_pred, average='micro')
        return -score
    # Hyperopt search
    trials = Trials() # records every evaluation
    best=fmin(hyperopt_rf, # objective
            space=space, # search space
            algo=tpe.suggest, # Tree-structured Parzen Estimator
            max_evals=10, # maximum number of evaluations
            timeout=180, # time budget, seconds
            trials=trials, # evaluation log
            )
    # Turn `best` into constructor kwargs: hp.choice returns indices, quniform returns floats
    for key,val in best.items():
        if key == 'criterion': # index into the criterion list
            best[key] = criterion[val]
        elif key == 'solver': # index into the solver list
            best[key] = solver[val]
        elif val % 1 == 0: # integral float -> int
            best[key] = int(val)
    print(f"Наилучшие значения гиперпараметров {best}")

In [228]:
if DO_POLYNOMIAL_FEATURES:
    # Degree-2 polynomial expansion: fitted on the training split only, then applied to test
    poly = PolynomialFeatures(degree=2)
    X_train, X_test = poly.fit_transform(X_train), poly.transform(X_test)

In [229]:
if DO_MMSCALER:
    # MinMax scaling: fitted on the training split only, then applied to test
    scaler = MinMaxScaler()
    X_train, X_test = scaler.fit_transform(X_train), scaler.transform(X_test)

In [230]:
# Создание модели
if model_name != None:
    model = model_name(**best)
    
model.fit(X=X_train, y=y_train)

In [231]:
# Predictions on both splits and their classification reports, labelled with gesture names
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

print(classification_report(y_train, y_pred_train, target_names=GESTURES))
print(classification_report(y_test, y_pred_test, target_names=GESTURES))

              precision    recall  f1-score   support

           0       0.98      0.96      0.97     10615
           1       0.98      0.94      0.96       987
           2       0.95      0.91      0.93       933
           3       0.71      0.99      0.82      1011
           4       0.97      0.85      0.91      1057
           5       0.99      1.00      0.99      1076

    accuracy                           0.95     15679
   macro avg       0.93      0.94      0.93     15679
weighted avg       0.96      0.95      0.96     15679

              precision    recall  f1-score   support

           0       0.97      0.92      0.94      2593
           1       0.96      0.90      0.93       251
           2       0.88      0.96      0.92       277
           3       0.65      0.84      0.73       262
           4       0.71      0.95      0.81       239
           5       0.97      0.93      0.95       267

    accuracy                           0.92      3889
   macro avg       0.86

In [232]:
# Функция построения графика кривой обучения
def plot_learning_curve(model, X, y, cv=4, scoring="f1_micro"):
    # Вычисляем координаты для построения кривой обучения
    skf = StratifiedKFold(n_splits=cv, shuffle=True) # кросс валидация
    train_sizes, train_scores, valid_scores = learning_curve(
        estimator=model,  # модель
        X=X,  # матрица наблюдений X
        y=y,  # вектор ответов y
        cv=skf, # кросс валидатор
        scoring=scoring,  # метрика
        train_sizes=np.linspace(0.05,1,10), # Кол-во разбиений
    )
    display = LearningCurveDisplay(train_sizes=train_sizes,train_scores=train_scores,test_scores=valid_scores,score_name=scoring)
    display.plot()
    plt.title('Learning Curve')
    plt.yticks(np.arange(0,1.05,0.05))
    plt.grid()
    plt.show()
    return display
if DO_PLOT_LEARNING_CURVE:
    learn_curve_plot = plot_learning_curve(model,X_train,y_train)

## Prediction

In [233]:
# Predictions of the trained model on both splits (used by the plots and logging below)
y_pred_train, y_pred_test = model.predict(X_train), model.predict(X_test)

In [234]:
if DO_DRAW_PLOTS:
    # True vs predicted target on the training data
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(y_train, c='C0', label='y_true')
    ax.plot(y_pred_train, c='C1', label='y_pred')

    ax.set_yticks(np.arange(len(GESTURES)))
    ax.set_yticklabels(GESTURES)
    ax.grid()
    ax.set_xlabel('Timesteps')
    ax.legend()
    ax.set_title('Train')
    fig.tight_layout()

In [235]:
if DO_DRAW_PLOTS:
    # True vs predicted target on the test data
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(y_test, c='C0', label='y_true')
    ax.plot(y_pred_test, c='C1', label='y_pred')

    ax.set_yticks(np.arange(len(GESTURES)))
    ax.set_yticklabels(GESTURES)
    ax.grid()
    ax.set_xlabel('Timesteps')
    ax.legend()
    ax.set_title('Test')
    fig.tight_layout()

## Metrics

In [236]:
# Train and test classification reports of the final model, with gesture names
print_report()

              precision    recall  f1-score   support

     Neutral       0.98      0.96      0.97     10615
        Open       0.98      0.94      0.96       987
      Pistol       0.95      0.91      0.93       933
       Thumb       0.71      0.99      0.82      1011
          OK       0.97      0.85      0.91      1057
        Grab       0.99      1.00      0.99      1076

    accuracy                           0.95     15679
   macro avg       0.93      0.94      0.93     15679
weighted avg       0.96      0.95      0.96     15679

              precision    recall  f1-score   support

     Neutral       0.97      0.92      0.94      2593
        Open       0.96      0.90      0.93       251
      Pistol       0.88      0.96      0.92       277
       Thumb       0.65      0.84      0.73       262
          OK       0.71      0.95      0.81       239
        Grab       0.97      0.93      0.95       267

    accuracy                           0.92      3889
   macro avg       0.86

# Real-time inference

In [237]:
# Lookup table for the packet checksum computed by get_crc (256 entries, indexed by byte value)
crc_table = [
    54, 181, 83, 241, 89, 16, 164, 217, 34, 169, 220, 160, 11, 252,
    111, 241, 33, 70, 99, 240, 234, 215, 60, 206, 68, 126, 152, 81,
    113, 187, 14, 21, 164, 172, 251, 16, 248, 215, 236, 90, 49, 53,
    179, 156, 101, 55, 65, 130, 161, 22, 218, 79, 24, 168, 152, 205,
    115, 141, 23, 86, 141, 58, 122, 215, 252, 48, 69, 115, 138, 66,
    88, 37, 63, 104, 176, 46, 139, 246, 222, 184, 103, 92, 154, 174,
    97, 141, 195, 166, 227, 150, 140, 48, 121, 243, 13, 131, 210, 199,
    45, 75, 180, 104, 97, 82, 251, 90, 132, 111, 229, 175, 146, 216,
    153, 86, 166, 33, 184, 100, 225, 248, 186, 54, 89, 39, 2, 214, 2,
    114, 197, 6, 35, 188, 245, 64, 220, 37, 123, 132, 190, 60, 189, 53,
    215, 185, 238, 145, 99, 226, 79, 54, 102, 118, 210, 116, 51, 247,
    0, 191, 42, 45, 2, 132, 106, 52, 63, 159, 229, 157, 78, 165, 50,
    18, 108, 193, 166, 253, 3, 243, 126, 111, 199, 152, 36, 114, 147,
    57, 87, 14, 16, 160, 128, 97, 189, 51, 115, 142, 8, 70, 71, 55, 42,
    193, 65, 207, 122, 158, 26, 21, 72, 139, 33, 230, 230, 116, 134, 5,
    213, 165, 107, 41, 134, 219, 190, 26, 29, 136, 174, 190, 108, 185, 172,
    137, 239, 164, 208, 207, 206, 98, 207, 12, 0, 174, 64, 20, 90, 49, 12,
    67, 112, 109, 78, 114, 165, 244, 183, 121
]

def get_crc(data, count):
    '''Table-driven 8-bit checksum of the first `count` items of `data`.

    Starting from 0, each byte replaces the running value with crc_table[running ^ byte].
    Raises IndexError if `data` has fewer than `count` items.
    '''
    result = 0x00
    for idx in range(count):
        result = crc_table[result ^ data[idx]]
    return result

def drv_abs_one(ser, control):
    '''Build a command packet from `control` and write it to `ser` if a port is given.

    Packet: [80, 0x01, 0xBB, 6, *control, crc] as uint8, where crc is get_crc over the
    9 bytes after the first one (0x01, 0xBB, 6 and six control values).
    NOTE(review): the fixed 6 / 9 assume exactly 6 control values — confirm with the protocol.

    Returns:
        (pack, bytes_written): bytes_written is None when `ser` is None, so callers
        can always unpack two values.
    '''
    pack = np.array([80, 0x01, 0xBB, 6] + list(control) + [0], dtype='uint8')
    pack[-1] = get_crc(pack[1:], 9)
    if ser is None:
        return pack, None
    return pack, ser.write(bytearray(pack))

## Inference Utils: Put your code here

In [238]:
def preprocessing(x): # Input preprocessing
    '''Apply to one sample the transforms that were fitted on X_train.

    The model is trained after optional PolynomialFeatures (DO_POLYNOMIAL_FEATURES)
    and MinMaxScaler (DO_MMSCALER), in that order; predicting on untransformed
    samples would feed the model values on a different scale.

    Parameters:
        x: 1-D feature vector with the same columns as the training matrix
    Returns:
        1-D transformed feature vector

    NOTE(review): the online loop passes only the 50 raw OMG values, which does not
    match the engineered training features — a rolling feature buffer is still TODO.
    '''
    sample = np.asarray(x, dtype=float).reshape(1, -1)  # transformers expect a 2-D batch
    if DO_POLYNOMIAL_FEATURES:
        sample = poly.transform(sample)
    if DO_MMSCALER:
        sample = scaler.transform(sample)
    return sample[0]

def inference(x): # Model prediction for one sample
    '''Predict with the global `model` on the single preprocessed sample `x`
    (wrapped as a batch of one); returns the model's prediction array.'''
    return model.predict([x])

def postprocessing(x, prev, alpha=0.1): # Post-processing: smooth with the previous output
    '''Exponentially smooth the model output: y = alpha * x + (1 - alpha) * prev.

    This is simple exponential smoothing (an exponential moving average), not a
    Holt-Winters filter: there is no trend or seasonal component.

    Parameters:
        x: current model output
        prev: previous smoothed output, or None on the first call (x is passed through)
        alpha (float): weight of the new value, default 0.1
    Returns:
        smoothed output, same type/shape as x
    '''
    if prev is None:
        return x
    return x*alpha + prev*(1 - alpha)

def commands(x):
    '''Convert model output to integer commands: clip x/100 to [0, 1], rescale by 100
    and round, i.e. integers in 0..100.'''
    fraction = np.clip(x / 100, 0, 1)
    return np.round(fraction * 100).astype(int)

## Inference

### Offline (dataset)

In [239]:
# Held-out part of the feature matrix, replayed row by row as a simulated stream
df_sim = gestures.iloc[last_train_idx:, :]
print(df_sim.shape)

(3889, 290)


In [240]:
if DO_OFFLINE_INFERENCE:
    TIMEOUT = 0.033
    DEBUG = False

    ts_list = []

    y_previous = None
    y_dct = {
        'omg_sample':[],
        'enc_sample':[],
        'sample_preprocessed':[],

        'y_predicted':[],
        'y_postprocessed':[],
        'y_commands':[],
    }
    # df_sim rows contain only the model's feature columns: ENC/BUTTON/SYNC/ts were dropped
    # when `gestures` was reduced to ALL_CH. The whole row is therefore the model input and
    # there are no encoder values to log (splitting it like the raw serial stream failed).
    # Convert to NumPy once instead of on every iteration.
    sim_rows = df_sim.to_numpy()
    no_encoders = np.empty(0)

    for i, sample in enumerate(sim_rows):

        # [Data reading]
        ts_start = time.time()
        # [Sim data]
        omg_sample = sample
        enc_sample = no_encoders
        # [/Sim data]
        # [/Data Reading]

        # [Data preprocessing]
        sample_preprocessed = preprocessing(omg_sample)
        # [/Data preprocessing]

        # [Inference]
        y_predicted         = inference(sample_preprocessed)
        # [/Inference]

        # [Inference Postprocessing]
        y_postprocessed     = postprocessing(y_predicted, y_previous)
        # [/Inference Postprocessing]

        # [Commands composition]
        y_commands          = commands(y_postprocessed)
        # [/Commands composition]

        # [Commands sending]
        # NO COMMANDS SENDING IN SIMULATION
        # [/Commands sending]

        # [Data logging]
        y_dct['omg_sample'].append(omg_sample)
        y_dct['enc_sample'].append(enc_sample)
        y_dct['sample_preprocessed'].append(sample_preprocessed)
        y_dct['y_predicted'].append(y_predicted)
        y_dct['y_postprocessed'].append(y_postprocessed)
        y_dct['y_commands'].append(y_commands)
        # [/Data logging]

        y_previous = y_postprocessed

        if DEBUG:
            clear_output(wait=True)

            # sanity check: the simulated sample has one value per df_sim column
            print(f'SAMPLE SIZE: {len(sample)}, FEATURES: {df_sim.shape[1]}')
            print(y_commands)

        # Duration of one inference cycle
        ts_list.append(time.time() - ts_start)

    ts_mean = np.mean(ts_list) if ts_list else 0.0
    ts_max = np.max(ts_list) if ts_list else 0.0
    if ts_max > TIMEOUT:
        print('Calculation cycle takes more than TIMEOUT')
    print(f'Timeout: {TIMEOUT}')
    print(f'   mean: {round(ts_mean,5)}')
    print(f'    max: {round(ts_max,5)}')

In [241]:
if DO_OFFLINE_INFERENCE:
    # Turn every logged list into one array (first axis = timestep)
    for name in list(y_dct):
        y_dct[name] = np.stack(y_dct[name])
        print(f"{name}.shape = {y_dct[name].shape}")

In [242]:
if DO_OFFLINE_INFERENCE and DO_DRAW_PLOTS:
    # Ground truth and each stage of the simulated inference pipeline on one axis
    fig, ax = plt.subplots(figsize=(10, 3))

    series = [
        (y_test, 'C0', 'y_cmd'),
        (y_dct['y_predicted'], 'C1', 'y_predicted'),
        (y_dct['y_postprocessed'], 'C2', 'y_postprocessed'),
        (y_dct['y_commands'], 'C3', 'y_commands'),
    ]
    for values, color, label in series:
        ax.plot(values, c=color, label=label)

    ax.set_title('Ground truth vs predicted vs postprocessed vs commands')
    ax.set_yticks(np.arange(len(GESTURES)))
    ax.set_yticklabels(GESTURES)
    ax.legend()
    ax.grid()
    ax.set_xlabel('Timesteps')

    fig.tight_layout()

In [243]:
if DO_OFFLINE_INFERENCE:
    # Per-gesture metrics of the simulated pipeline output against the test labels
    offline_report = classification_report(y_test, y_dct['y_commands'], target_names=GESTURES)
    print(offline_report)

### Online (prosthesis or virtual hand)

In [244]:
if DO_ONLINE_INFERENCE:
    TIMEOUT = 0.033
    DEBUG = True
    
    # ser_port = None
    # ser_port = '/dev/ttyACM0'
    ser_port = '/dev/cu.usbmodem3498365F31351'

    # The loop below reads from and writes to the port, so it cannot run without one
    if ser_port is None:
        raise RuntimeError('Online inference needs a serial port: set ser_port')

    ser = serial.Serial(port=ser_port, timeout=2*TIMEOUT)
    try:
        ser.write('T1#\r\n'.encode('utf-8')) # T1 for Timestamp activate
        ser.write('M2#\r\n'.encode('utf-8')) # M2 for Mode == 2 = send samples
        ser.write('S2#\r\n'.encode('utf-8')) # SYNC to 2 for sanity check

        # flush buffers
        ser.reset_input_buffer()
        ser.read()

        i = 0
        while ser.in_waiting:
            print(f'Flushing buffers {i}: {ser.in_waiting}', end='    \r')
            ser.read_all()
            time.sleep(0.005)
            i += 1
        # discard two more lines (presumably partial/leftover lines after the flush — TODO confirm)
        ser.readline()
        ser.readline()

        i = 0
        ts_old = time.time()
        ts_diff = 0

        y_previous = None
        while True:

            # [Data reading]
            s = ser.readline()
            ts_start = time.time()

            try:
                sample = np.array(s.decode('UTF-8')\
                                   .replace('\r\n', "")\
                                   .split(' ')
                                 ).astype(int)
                [omg_sample, acc_sample, enc_sample, [button, sync, ts]] = np.array_split(sample, [50, 56, 62])

            except Exception as e:
                # Malformed or partial line (e.g. a read timeout returns b''): skip it rather
                # than running inference on stale or, on the first iteration, undefined values
                print(e)
                continue

            # [/Data Reading]

            # [Data preprocessing]
            sample_preprocessed = preprocessing(omg_sample)
            # [/Data preprocessing]

            # [Inference]
            y_predicted         = inference(sample_preprocessed)
            # [/Inference]

            # [Inference Postprocessing]
            y_postprocessed     = postprocessing(y_predicted, y_previous)
            # [/Inference Postprocessing]

            # [Commands composition]
            y_commands          = commands(y_postprocessed)
            # [/Commands composition]

            # [Commands sending]
            pack, _ = drv_abs_one(ser, list(y_commands)) # + [0]
            # [/Commands sending]

            y_previous = y_postprocessed

            if DEBUG:
                clear_output(wait=True)

                # sanity check: iteration should increase monotonically, TIMEDIFF approximately 32-34 ms, CYCLETIME < TIMEOUT, WAITING should be == 0
                print(f'ITERATION:\t{i}\tTIMEDIFF:\t{(ts_start - ts_old)*1000: .0f}\tCYCLETIME:\t{ts_diff*1000:.0f}\tWAITING:\t{ser.in_waiting}')
                print('INPUT:\n', s)

                # sanity check: Sizes of SAMPLE=65, OMG=50, ACC=6, ENCODERS=6
                print(f'SAMPLE SIZE: {len(sample)}, OMG: {len(omg_sample)}, ACC: {len(acc_sample)}, ENCODERS: {len(enc_sample)}')
                print(f'BUTTON: {button}, SYNC: {sync}, TS: {ts}')
                print(y_commands)
                print(pack)

            ts_diff = time.time() - ts_start
            assert(ts_diff<TIMEOUT), 'Calculation cycle takes more than TIMEOUT, halting...'
            ts_old = ts_start
            i += 1
    except KeyboardInterrupt:
        # Interrupting the cell is the normal way to stop the endless loop
        print('Online inference stopped')
    finally:
        ser.close()  # release the port so that the next run can open it

# Log the experiment to comet.ml

In [245]:
if DO_LOG:
    from comet_ml import Experiment

    def _safe_log(action, *args, **kwargs):
        '''Call a comet.ml logging method; report failures instead of silently swallowing them.'''
        try:
            action(*args, **kwargs)
        except Exception as exc:  # logging is best-effort and must not abort the notebook
            print(f'comet.ml: {getattr(action, "__name__", action)} failed: {exc}')

    # Experiment of user WORKSPACE in project PROJECT_NAME, authenticated with API_KEY
    exp = Experiment(
        api_key=API_KEY,
        project_name=PROJECT_NAME,
        workspace=WORKSPACE,
    )
    try:
        _safe_log(exp.set_name, type(model).__name__)

        # Model
        _safe_log(exp.log_text, text=type(model).__name__)
        # Parameters (`best` exists only when the hyperparameter search ran)
        _safe_log(exp.log_parameters, model.get_params())
        if DO_HYPEROPT:
            _safe_log(exp.log_parameters, best, prefix='best_')
        # Metrics
        train_metrics = f1_score(y_train, y_pred_train, average='micro')
        _safe_log(exp.log_metric, name='train', value=train_metrics)
        test_metrics = f1_score(y_test, y_pred_test, average='micro')
        _safe_log(exp.log_metric, name='test', value=test_metrics)
        # Inference timing (measured only by the offline-inference cell)
        if DO_OFFLINE_INFERENCE:
            _safe_log(exp.log_metric, name='ts_mean', value=ts_mean)
            _safe_log(exp.log_metric, name='ts_max', value=ts_max)
        # Learning curve (learn_curve_plot exists only when it was computed)
        if DO_PLOT_LEARNING_CURVE:
            learn_curve_plot.plot()
            plt.title('Learning Curve')
            plt.yticks(np.arange(0,1.05,0.05))
            plt.grid()
            _safe_log(exp.log_figure, figure=plt)
        # Confusion matrices
        _safe_log(exp.log_confusion_matrix, y_train.tolist(), y_pred_train.tolist())
        _safe_log(exp.log_confusion_matrix, y_test.tolist(), y_pred_test.tolist())

        exp.display()
    finally:
        exp.end()  # always close the experiment, even if a step above raised