In [26]:
from data_loader import EMGDataManager

In [27]:
file_path = "data/emg_handposes/sub-01/ses-01/emg_trial_data.pkl"

In [28]:
manager = EMGDataManager(file_path=file_path, sample_rate=250, cutoff_freq=20.0, selected_channels=[1, 4, 7])

In [29]:
data_by_label = manager.load_and_group()
print("Loaded data for labels:", list(data_by_label.keys()))

Loaded data for labels: ['Rest', 'fist', 'flat', 'okay', 'two']


In [30]:
processed_data_by_label = manager.preprocess(data_by_label)
print("\nPreprocessing complete. Sample shapes per label:")
for label, trials in processed_data_by_label.items():
    if len(trials) > 0:
        print(f"  Label '{label}': {len(trials)} trial(s), Sample trial shape: {trials[0].shape}")


Preprocessing complete. Sample shapes per label:
  Label 'Rest': 40 trial(s), Sample trial shape: (3, 1050)
  Label 'fist': 10 trial(s), Sample trial shape: (3, 1050)
  Label 'flat': 10 trial(s), Sample trial shape: (3, 1050)
  Label 'okay': 10 trial(s), Sample trial shape: (3, 1050)
  Label 'two': 10 trial(s), Sample trial shape: (3, 1050)


In [31]:
X_windowed, y_windowed = manager.window_data(processed_data_by_label, window_length_sec=0.42, overlap=0)
print("\nWindowing complete.")
print(f"Windowed data shape: {X_windowed.shape}")

[[ -0.068  -4.869  -4.547   8.221   2.924  -7.951  -7.065   5.731  10.777  -0.13  -12.734   4.706
    5.681  -2.116  -6.358  -2.798   7.948   2.444  -6.598   1.022   8.48    0.09  -10.427  -3.215
    4.524   3.259  -2.1     1.466  -0.51    3.968   2.723  -9.129  -3.18    9.199  -1.683  -2.208
   -2.192   2.72   -0.125  -4.105  -0.797   7.51    6.157  -4.658  -3.003   0.969  -1.232  -4.32
   -4.495   0.143   3.467   7.366   2.558   2.654   0.671  -5.544  -7.296  -2.159   4.88    0.917
   -4.76    1.902   2.777   0.716   3.242   0.701  -2.986   3.213  -3.782  -3.972   3.331   1.273
   -3.902  -0.776   0.988  -1.012   4.405   9.042  -5.567  -7.291  -1.248   1.651   2.693   0.424
   -0.506  -0.676   1.985   2.078  -2.322  -3.783   2.083   2.275  -2.383  -2.139  -3.09    5.385
    0.399   3.527   5.495  -5.005  -7.187  -3.118   4.109   1.515  -1.261]
 [  0.067   5.474   8.334   3.436  -4.026  -2.402   1.857  -0.846  -4.588  -2.082   4.26    3.158
   -3.088  -5.063   5.94    3.894  -2.818  -

In [32]:
X_balanced, y_balanced = manager.balance_data(X_windowed, y_windowed, target_label="Rest" )
print("\nBalancing complete.")
print(f"Balanced data shape: {X_balanced.shape}")
print(f"Balanced data labels: {len(y_balanced)}")



Balancing complete.
Balanced data shape: (500, 3, 105)
Balanced data labels: 500


In [33]:
features = manager.extract_features(X_balanced)
print("\nFeature extraction complete.")
print(f"Features shape: {features.shape}")


Feature extraction complete.
Features shape: (500, 12)


In [34]:
if len(features) > 0:
    print("\nExample feature vector:")
    print(features[444])
    print("Corresponding label:", y_windowed[444])
else:
    print("No features extracted; check your data and windowing parameters.")



Example feature vector:
[  26.48    61.      46.    3412.677    9.073   64.      56.    1301.834   11.165   62.      55.
 1576.036]
Corresponding label: fist


In [35]:
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier

In [36]:
X_train, X_test, y_train, y_test = train_test_split(features, y_balanced, test_size=0.2, random_state=42)

In [37]:
X_scaler = StandardScaler()
X_train_scaled = X_scaler.fit_transform(X_train)
X_test_scaled = X_scaler.transform(X_test)



In [38]:
label_map = {0: "fist", 1: "flat", 2: "okay", 3: "two", 4:"rest"}
class_names = [label_map[i] for i in sorted(label_map.keys())]

In [39]:
param_grid_svm = {
    'C': [0.01, 0.1, 1, 10, 100],
    'gamma': [0.001, 0.01, 0.1, 1],
    'kernel': ['rbf', 'linear']
}

param_grid_knn = {
    'n_neighbors': [1, 3, 5, 7, 9, 11],
    'weights': ['uniform', 'distance'],
    'p': [1, 2] 
}

In [40]:
svm_clf = SVC()
grid_search = GridSearchCV(
    estimator=svm_clf,
    param_grid=param_grid_svm,
    scoring='accuracy',  
    cv=5,                
    verbose=1,           
    n_jobs=-1            
)

grid_search.fit(X_train_scaled, y_train)

best_svm = grid_search.best_estimator_
print("Best Hyperparameters:", grid_search.best_params_)
print("Best CV Accuracy:", grid_search.best_score_)

Fitting 5 folds for each of 40 candidates, totalling 200 fits
Best Hyperparameters: {'C': 10, 'gamma': 0.001, 'kernel': 'linear'}
Best CV Accuracy: 0.8724999999999999


In [41]:
import pickle
from sklearn.pipeline import make_pipeline

best_pipeline = make_pipeline(
    StandardScaler(), 
    best_svm
)

with open('hand_pose_classifier.pkl', 'wb') as f:
    pickle.dump({
        'pipeline': best_pipeline,
        'label_map': label_map,
        'feature_params': {
            'sample_rate': manager.sample_rate,
            'selected_channels': manager.selected_channels,
            'window_length_sec': 0.42  
        },
        'scaler': X_scaler,
        'X_test': X_test_scaled,
        'y_test': y_test
    }, f)

with open('scaler.pkl', 'wb') as f:
    pickle.dump(X_scaler, f)

print("Saved model pipeline and scaler to disk")

Saved model pipeline and scaler to disk


In [42]:
import os
import time
import mujoco
import mujoco.viewer
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [43]:
model = mujoco.MjModel.from_xml_path("/Users/jmalegaonkar/Desktop/EMGrip/Adroit/Adroit_hand.xml")
data = mujoco.MjData(model)
# Debug print to understand model structure
print(f"Number of actuators (controls): {model.nu}")
print("Actuator names:", [model.actuator(i).name for i in range(model.nu)])

Number of actuators (controls): 24
Actuator names: ['A_WRJ1', 'A_WRJ0', 'A_FFJ3', 'A_FFJ2', 'A_FFJ1', 'A_FFJ0', 'A_MFJ3', 'A_MFJ2', 'A_MFJ1', 'A_MFJ0', 'A_RFJ3', 'A_RFJ2', 'A_RFJ1', 'A_RFJ0', 'A_LFJ4', 'A_LFJ3', 'A_LFJ2', 'A_LFJ1', 'A_LFJ0', 'A_THJ4', 'A_THJ3', 'A_THJ2', 'A_THJ1', 'A_THJ0']


In [44]:
joint_map = {}
for i in range(model.nu):
    joint_id = model.actuator(i).trnid[0]
    joint_name = model.joint(joint_id).name
    joint_map[joint_name] = i

print("\nJoint mapping:")
for name, idx in joint_map.items():
    print(f"{name}: {idx}")


Joint mapping:
WRJ1: 0
WRJ0: 1
FFJ3: 2
FFJ2: 3
FFJ1: 4
FFJ0: 5
MFJ3: 6
MFJ2: 7
MFJ1: 8
MFJ0: 9
RFJ3: 10
RFJ2: 11
RFJ1: 12
RFJ0: 13
LFJ4: 14
LFJ3: 15
LFJ2: 16
LFJ1: 17
LFJ0: 18
THJ4: 19
THJ3: 20
THJ2: 21
THJ1: 22
THJ0: 23


In [45]:
def get_desired_configuration(command):
    desired_q = np.zeros(model.nu)
    try:
        if command.lower() == "rest":
            desired_q[:] = 0.0
            
        elif command.lower() == "fist":
            # Close all fingers
            for name in joint_map:
                if "FJ2" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "FJ1" in name:
                    desired_q[joint_map[name]] = 1.0
                elif "THJ3" in name:
                    desired_q[joint_map[name]] = 1.3
                elif "THJ0" in name: 
                    desired_q[joint_map[name]] = -1.57
                elif "THJ1" in name: 
                    desired_q[joint_map[name]] = -0.52
                elif "THJ2" in name:
                    desired_q[joint_map[name]] = 0.1
                else:
                    desired_q[joint_map[name]] = 0.0
                
                    
        elif command.lower() == "okay":
            for name in joint_map:
                # Index finger flexion
                if "FFJ1" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "FFJ2" in name:
                    desired_q[joint_map[name]] = 1.0
                # Thumb opposition
                elif "THJ4" in name:
                    desired_q[joint_map[name]] = 0.2  # Abduction
                elif "THJ3" in name:
                    desired_q[joint_map[name]] = 1.0  # Flexion
                elif "THJ0" in name:
                    desired_q[joint_map[name]] = -1.0  # Tip flexion
                elif "THJ1" in name:
                    desired_q[joint_map[name]] = -0.075
                elif "THJ2" in name:
                    desired_q[joint_map[name]] = 0.135
                else:
                    desired_q[joint_map[name]] = 0.0
        elif command.lower() == "two":
            for name in joint_map:
                if "FFJ3" in name:
                    desired_q[joint_map[name]] = 0.44
                elif "MFJ3" in name:
                    desired_q[joint_map[name]] = -0.44
                elif "RFJ1" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "RFJ2" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "LFJ2" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "LFJ1" in name:
                    desired_q[joint_map[name]] = 1.6
                elif "THJ4" in name:
                    desired_q[joint_map[name]] = 1.0
                elif "THJ3" in name:
                    desired_q[joint_map[name]] = 1.3
                elif "THJ2" in name:
                    desired_q[joint_map[name]] = 0.3
                elif "THJ1" in name:
                    desired_q[joint_map[name]] = -0.5
                elif "THJ0" in name:
                    desired_q[joint_map[name]] = -1.57
                else:
                    desired_q[joint_map[name]] = 0.0
        elif command.lower() == "flat":
            for name in joint_map:
                if "FFJ3" in name:
                    desired_q[joint_map[name]] = 0.44
                elif "MFJ3" in name:
                    desired_q[joint_map[name]] = 0.09
                elif "RFJ3" in name:
                    desired_q[joint_map[name]] = -0.4
                elif "LFJ3" in name:
                    desired_q[joint_map[name]] = -0.44
                else:
                    desired_q[joint_map[name]] = 0.0
            
        else:
            return None
    
            
    except KeyError as e:
        print(f"Missing joint in configuration: {e}")
        return None
        
    return desired_q

In [46]:
def show_hand(command):
    print(command)
    frames = []

    if command.lower() == "exit":
        return
    
    q_desired = get_desired_configuration(command)
    if q_desired is None:
        print("Invalid command")
        return
        
    # Reset simulation
    mujoco.mj_resetData(model, data)

    # Create renderer
    renderer = mujoco.Renderer(model, height=480, width=640)

    # Animation parameters
    duration = 2.0  # Time to reach the desired configuration
    start_time = time.time()
    frames = []  # Reset frames for new command

    # Initial joint positions
    q_start = data.qpos[:model.nu].copy()

    while time.time() - start_time < duration:
        # Interpolation factor (0 to 1)
        alpha = (time.time() - start_time) / duration
        alpha = min(alpha, 1.0)  # Clamp to 1.0
        
        # Interpolate between start and desired positions
        data.qpos[:model.nu] = q_start + alpha * (q_desired - q_start)
        
        # Update model state
        mujoco.mj_forward(model, data)
        
        # Render frame
        renderer.update_scene(data, camera=-1)
        frame = renderer.render()
        frames.append(frame)

    # Show video of the action
    media.show_video(frames, fps=60)
    time.sleep(5)  # Pause before next command

In [47]:
y_pred_svm = best_svm.predict(X_test_scaled[22].reshape(1, -1))
out = str(y_pred_svm[0])
print(out)

flat


In [None]:
show_hand(out)