In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob

# Function to extract the participant ID from the filename
# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'participant' in part:
#             participant_id = part.replace('participant', '')
#             return int(participant_id)  # Convert to integer
#     return None  # If no participant ID found

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

In [3]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = '17participants'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [4]:
print(np.unique(labels))
labels.dtype

[0 1 2 3 4 5 6 7 8 9]


dtype('int64')

In [5]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# --- Training loop for 50 epochs ---
import mlflow
import mlflow.pytorch
num_epochs = 50

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()
        train_acc = correct / total
    mlflow.log_metric("train_loss", total_loss, step=epoch)
    mlflow.log_metric("train_accuracy", train_acc, step=epoch)
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)
            val_loss += loss.item()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        val_loss = val_loss / len(test_loader)
        val_acc = correct / total
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_accuracy", val_acc, step=epoch)
        mlflow.pytorch.log_model(model, "models/last_model")
end_mlflow_experiment()       

2024/09/23 16:27:15 INFO mlflow.tracking.fluent: Experiment with name '17participants' does not exist. Creating a new experiment.
Epoch 1/50: 100%|██████████| 90/90 [02:13<00:00,  1.48s/it]


Epoch 1, Loss: 1.9316, Accuracy: 0.3110
Validation Loss: 1.7092, Validation Accuracy: 0.3843


Epoch 2/50: 100%|██████████| 90/90 [02:07<00:00,  1.42s/it]


Epoch 2, Loss: 1.4439, Accuracy: 0.5086
Validation Loss: 1.1441, Validation Accuracy: 0.5105


Epoch 3/50: 100%|██████████| 90/90 [02:04<00:00,  1.38s/it]


Epoch 3, Loss: 1.1834, Accuracy: 0.5995
Validation Loss: 2.4363, Validation Accuracy: 0.5316


Epoch 4/50: 100%|██████████| 90/90 [02:03<00:00,  1.37s/it]


Epoch 4, Loss: 0.9688, Accuracy: 0.6655
Validation Loss: 0.3694, Validation Accuracy: 0.5596


Epoch 5/50: 100%|██████████| 90/90 [02:04<00:00,  1.38s/it]


Epoch 5, Loss: 0.6801, Accuracy: 0.7883
Validation Loss: 0.3220, Validation Accuracy: 0.5736


Epoch 6/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 6, Loss: 0.5193, Accuracy: 0.8256
Validation Loss: 0.1002, Validation Accuracy: 0.5792


Epoch 7/50: 100%|██████████| 90/90 [02:01<00:00,  1.35s/it]


Epoch 7, Loss: 0.3164, Accuracy: 0.9063
Validation Loss: 0.0771, Validation Accuracy: 0.6045


Epoch 8/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 8, Loss: 0.2343, Accuracy: 0.9333
Validation Loss: 0.0319, Validation Accuracy: 0.5470


Epoch 9/50: 100%|██████████| 90/90 [02:00<00:00,  1.33s/it]


Epoch 9, Loss: 0.2102, Accuracy: 0.9414
Validation Loss: 0.0400, Validation Accuracy: 0.5680


Epoch 10/50: 100%|██████████| 90/90 [02:04<00:00,  1.38s/it]


Epoch 10, Loss: 0.1657, Accuracy: 0.9537
Validation Loss: 0.0466, Validation Accuracy: 0.5694


Epoch 11/50: 100%|██████████| 90/90 [02:03<00:00,  1.38s/it]


Epoch 11, Loss: 0.1142, Accuracy: 0.9698
Validation Loss: 0.0199, Validation Accuracy: 0.5947


Epoch 12/50: 100%|██████████| 90/90 [02:03<00:00,  1.37s/it]


Epoch 12, Loss: 0.1139, Accuracy: 0.9691
Validation Loss: 0.0110, Validation Accuracy: 0.5820


Epoch 13/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 13, Loss: 0.0899, Accuracy: 0.9726
Validation Loss: 0.0166, Validation Accuracy: 0.5863


Epoch 14/50: 100%|██████████| 90/90 [02:04<00:00,  1.39s/it]


Epoch 14, Loss: 0.0895, Accuracy: 0.9747
Validation Loss: 0.0133, Validation Accuracy: 0.6073


Epoch 15/50: 100%|██████████| 90/90 [02:01<00:00,  1.35s/it]


Epoch 15, Loss: 0.0654, Accuracy: 0.9828
Validation Loss: 0.0195, Validation Accuracy: 0.5806


Epoch 16/50: 100%|██████████| 90/90 [02:00<00:00,  1.34s/it]


Epoch 16, Loss: 0.0879, Accuracy: 0.9751
Validation Loss: 0.0333, Validation Accuracy: 0.5806


Epoch 17/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 17, Loss: 0.0565, Accuracy: 0.9867
Validation Loss: 0.0054, Validation Accuracy: 0.5820


Epoch 18/50: 100%|██████████| 90/90 [02:01<00:00,  1.35s/it]


Epoch 18, Loss: 0.0658, Accuracy: 0.9839
Validation Loss: 0.0077, Validation Accuracy: 0.5778


Epoch 19/50: 100%|██████████| 90/90 [02:04<00:00,  1.39s/it]


Epoch 19, Loss: 0.0404, Accuracy: 0.9888
Validation Loss: 0.0574, Validation Accuracy: 0.5722


Epoch 20/50: 100%|██████████| 90/90 [02:04<00:00,  1.39s/it]


Epoch 20, Loss: 0.0970, Accuracy: 0.9712
Validation Loss: 0.0305, Validation Accuracy: 0.6059


Epoch 21/50: 100%|██████████| 90/90 [02:06<00:00,  1.41s/it]


Epoch 21, Loss: 0.0329, Accuracy: 0.9923
Validation Loss: 0.0050, Validation Accuracy: 0.5947


Epoch 22/50: 100%|██████████| 90/90 [02:01<00:00,  1.35s/it]


Epoch 22, Loss: 0.0193, Accuracy: 0.9961
Validation Loss: 0.0078, Validation Accuracy: 0.6059


Epoch 23/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 23, Loss: 0.0205, Accuracy: 0.9965
Validation Loss: 0.0052, Validation Accuracy: 0.5736


Epoch 24/50: 100%|██████████| 90/90 [02:03<00:00,  1.37s/it]


Epoch 24, Loss: 0.1034, Accuracy: 0.9684
Validation Loss: 0.0142, Validation Accuracy: 0.5961


Epoch 25/50: 100%|██████████| 90/90 [02:04<00:00,  1.39s/it]


Epoch 25, Loss: 0.1166, Accuracy: 0.9649
Validation Loss: 0.0075, Validation Accuracy: 0.6115


Epoch 26/50: 100%|██████████| 90/90 [02:04<00:00,  1.38s/it]


Epoch 26, Loss: 0.0209, Accuracy: 0.9954
Validation Loss: 0.0063, Validation Accuracy: 0.6115


Epoch 27/50: 100%|██████████| 90/90 [01:59<00:00,  1.33s/it]


Epoch 27, Loss: 0.0237, Accuracy: 0.9937
Validation Loss: 0.0054, Validation Accuracy: 0.5975


Epoch 28/50: 100%|██████████| 90/90 [01:59<00:00,  1.32s/it]


Epoch 28, Loss: 0.0181, Accuracy: 0.9961
Validation Loss: 0.0036, Validation Accuracy: 0.6045


Epoch 29/50: 100%|██████████| 90/90 [02:02<00:00,  1.36s/it]


Epoch 29, Loss: 0.0385, Accuracy: 0.9895
Validation Loss: 0.0017, Validation Accuracy: 0.5806


Epoch 30/50: 100%|██████████| 90/90 [02:04<00:00,  1.38s/it]


Epoch 30, Loss: 0.0408, Accuracy: 0.9867
Validation Loss: 0.0023, Validation Accuracy: 0.5975


Epoch 31/50: 100%|██████████| 90/90 [02:05<00:00,  1.39s/it]


Epoch 31, Loss: 0.0742, Accuracy: 0.9814
Validation Loss: 0.0959, Validation Accuracy: 0.6101


Epoch 32/50: 100%|██████████| 90/90 [02:05<00:00,  1.39s/it]


Epoch 32, Loss: 0.0944, Accuracy: 0.9737
Validation Loss: 0.0041, Validation Accuracy: 0.5891


Epoch 33/50: 100%|██████████| 90/90 [02:01<00:00,  1.35s/it]


Epoch 33, Loss: 0.0556, Accuracy: 0.9856
Validation Loss: 0.0066, Validation Accuracy: 0.5947


Epoch 34/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 34, Loss: 0.0166, Accuracy: 0.9958
Validation Loss: 0.0050, Validation Accuracy: 0.5526


Epoch 35/50: 100%|██████████| 90/90 [01:09<00:00,  1.29it/s]


Epoch 35, Loss: 0.0327, Accuracy: 0.9912
Validation Loss: 0.0027, Validation Accuracy: 0.5666


Epoch 36/50: 100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


Epoch 36, Loss: 0.0251, Accuracy: 0.9926
Validation Loss: 0.0012, Validation Accuracy: 0.5933


Epoch 37/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 37, Loss: 0.0063, Accuracy: 0.9993
Validation Loss: 0.0018, Validation Accuracy: 0.6073


Epoch 38/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 38, Loss: 0.0039, Accuracy: 0.9996
Validation Loss: 0.0031, Validation Accuracy: 0.6199


Epoch 39/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 39, Loss: 0.0022, Accuracy: 1.0000
Validation Loss: 0.0029, Validation Accuracy: 0.6213


Epoch 40/50: 100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


Epoch 40, Loss: 0.0019, Accuracy: 1.0000
Validation Loss: 0.0015, Validation Accuracy: 0.6199


Epoch 41/50: 100%|██████████| 90/90 [01:11<00:00,  1.27it/s]


Epoch 41, Loss: 0.0018, Accuracy: 1.0000
Validation Loss: 0.0018, Validation Accuracy: 0.6199


Epoch 42/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 42, Loss: 0.0016, Accuracy: 1.0000
Validation Loss: 0.0017, Validation Accuracy: 0.6185


Epoch 43/50: 100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


Epoch 43, Loss: 0.0015, Accuracy: 1.0000
Validation Loss: 0.0014, Validation Accuracy: 0.6213


Epoch 44/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 44, Loss: 0.0014, Accuracy: 1.0000
Validation Loss: 0.0015, Validation Accuracy: 0.6241


Epoch 45/50: 100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


Epoch 45, Loss: 0.0013, Accuracy: 1.0000
Validation Loss: 0.0010, Validation Accuracy: 0.6241


Epoch 46/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 46, Loss: 0.0012, Accuracy: 1.0000
Validation Loss: 0.0012, Validation Accuracy: 0.6255


Epoch 47/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 47, Loss: 0.0011, Accuracy: 1.0000
Validation Loss: 0.0011, Validation Accuracy: 0.6241


Epoch 48/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 48, Loss: 0.0011, Accuracy: 1.0000
Validation Loss: 0.0014, Validation Accuracy: 0.6255


Epoch 49/50: 100%|██████████| 90/90 [01:10<00:00,  1.28it/s]


Epoch 49, Loss: 0.0010, Accuracy: 1.0000
Validation Loss: 0.0010, Validation Accuracy: 0.6255


Epoch 50/50: 100%|██████████| 90/90 [01:10<00:00,  1.27it/s]


Epoch 50, Loss: 0.0010, Accuracy: 1.0000
Validation Loss: 0.0011, Validation Accuracy: 0.6269


