In [1]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob

# Function to extract the participant ID from the filename
# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'participant' in part:
#             participant_id = part.replace('participant', '')
#             return int(participant_id)  # Convert to integer
#     return None  # If no participant ID found

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [3]:
print(np.unique(labels))
labels.dtype

[0 1 2 3 4 5 6 7 8 9]


dtype('int64')

In [4]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# --- Training loop for 50 epochs ---
import mlflow
import mlflow.pytorch
num_epochs = 50

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()
        train_acc = correct / total
    mlflow.log_metric("train_loss", total_loss, step=epoch)
    mlflow.log_metric("train_accuracy", train_acc, step=epoch)
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)
            val_loss += loss.item()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        val_loss = val_loss / len(test_loader)
        val_acc = correct / total
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("val_accuracy", val_acc, step=epoch)
        mlflow.pytorch.log_model(model, "models/last_model")
end_mlflow_experiment()       

Epoch 1/50: 100%|██████████| 226/226 [06:46<00:00,  1.80s/it]


Epoch 1, Loss: 1.6811, Accuracy: 0.4106
Validation Loss: 2.1019, Validation Accuracy: 0.5363


Epoch 2/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 2, Loss: 1.1208, Accuracy: 0.6152
Validation Loss: 0.9974, Validation Accuracy: 0.6206


Epoch 3/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 3, Loss: 0.8540, Accuracy: 0.7085
Validation Loss: 0.8121, Validation Accuracy: 0.6694


Epoch 4/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 4, Loss: 0.6722, Accuracy: 0.7685
Validation Loss: 1.1750, Validation Accuracy: 0.6667


Epoch 5/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 5, Loss: 0.4779, Accuracy: 0.8408
Validation Loss: 0.5102, Validation Accuracy: 0.6450


Epoch 6/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 6, Loss: 0.3168, Accuracy: 0.8997
Validation Loss: 0.4131, Validation Accuracy: 0.6400


Epoch 7/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 7, Loss: 0.2067, Accuracy: 0.9352
Validation Loss: 0.2342, Validation Accuracy: 0.6722


Epoch 8/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 8, Loss: 0.1698, Accuracy: 0.9459
Validation Loss: 0.2920, Validation Accuracy: 0.6489


Epoch 9/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 9, Loss: 0.0976, Accuracy: 0.9728
Validation Loss: 0.3448, Validation Accuracy: 0.6828


Epoch 10/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 10, Loss: 0.0988, Accuracy: 0.9678
Validation Loss: 0.0145, Validation Accuracy: 0.6717


Epoch 11/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 11, Loss: 0.0966, Accuracy: 0.9694
Validation Loss: 0.0269, Validation Accuracy: 0.6789


Epoch 12/50: 100%|██████████| 226/226 [03:02<00:00,  1.24it/s]


Epoch 12, Loss: 0.0787, Accuracy: 0.9773
Validation Loss: 0.0101, Validation Accuracy: 0.6644


Epoch 13/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 13, Loss: 0.0749, Accuracy: 0.9771
Validation Loss: 0.1872, Validation Accuracy: 0.6539


Epoch 14/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 14, Loss: 0.0984, Accuracy: 0.9691
Validation Loss: 0.0201, Validation Accuracy: 0.6700


Epoch 15/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 15, Loss: 0.0307, Accuracy: 0.9911
Validation Loss: 0.0160, Validation Accuracy: 0.6661


Epoch 16/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 16, Loss: 0.0404, Accuracy: 0.9881
Validation Loss: 0.0456, Validation Accuracy: 0.6334


Epoch 17/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 17, Loss: 0.0445, Accuracy: 0.9877
Validation Loss: 0.0050, Validation Accuracy: 0.6561


Epoch 18/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 18, Loss: 0.0733, Accuracy: 0.9760
Validation Loss: 0.0102, Validation Accuracy: 0.6689


Epoch 19/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 19, Loss: 0.0456, Accuracy: 0.9852
Validation Loss: 0.0029, Validation Accuracy: 0.6683


Epoch 20/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 20, Loss: 0.0418, Accuracy: 0.9885
Validation Loss: 0.0058, Validation Accuracy: 0.6667


Epoch 21/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 21, Loss: 0.0353, Accuracy: 0.9903
Validation Loss: 0.0094, Validation Accuracy: 0.6894


Epoch 22/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 22, Loss: 0.0650, Accuracy: 0.9800
Validation Loss: 0.0029, Validation Accuracy: 0.6722


Epoch 23/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 23, Loss: 0.0349, Accuracy: 0.9893
Validation Loss: 0.0023, Validation Accuracy: 0.6750


Epoch 24/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 24, Loss: 0.0293, Accuracy: 0.9903
Validation Loss: 0.0291, Validation Accuracy: 0.6750


Epoch 25/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 25, Loss: 0.0302, Accuracy: 0.9910
Validation Loss: 0.0017, Validation Accuracy: 0.6894


Epoch 26/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 26, Loss: 0.0518, Accuracy: 0.9825
Validation Loss: 0.0021, Validation Accuracy: 0.6894


Epoch 27/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 27, Loss: 0.0509, Accuracy: 0.9849
Validation Loss: 0.1400, Validation Accuracy: 0.6661


Epoch 28/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 28, Loss: 0.0482, Accuracy: 0.9854
Validation Loss: 0.0026, Validation Accuracy: 0.6828


Epoch 29/50: 100%|██████████| 226/226 [03:02<00:00,  1.24it/s]


Epoch 29, Loss: 0.0404, Accuracy: 0.9885
Validation Loss: 0.0021, Validation Accuracy: 0.6800


Epoch 30/50: 100%|██████████| 226/226 [03:02<00:00,  1.24it/s]


Epoch 30, Loss: 0.0430, Accuracy: 0.9868
Validation Loss: 0.0025, Validation Accuracy: 0.6883


Epoch 31/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 31, Loss: 0.0321, Accuracy: 0.9908
Validation Loss: 0.0084, Validation Accuracy: 0.6650


Epoch 32/50: 100%|██████████| 226/226 [03:00<00:00,  1.25it/s]


Epoch 32, Loss: 0.0185, Accuracy: 0.9953
Validation Loss: 0.0014, Validation Accuracy: 0.6722


Epoch 33/50: 100%|██████████| 226/226 [03:02<00:00,  1.24it/s]


Epoch 33, Loss: 0.0612, Accuracy: 0.9834
Validation Loss: 0.0085, Validation Accuracy: 0.6606


Epoch 34/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 34, Loss: 0.0221, Accuracy: 0.9925
Validation Loss: 0.0022, Validation Accuracy: 0.6694


Epoch 35/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 35, Loss: 0.0270, Accuracy: 0.9924
Validation Loss: 0.0146, Validation Accuracy: 0.6728


Epoch 36/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 36, Loss: 0.0241, Accuracy: 0.9924
Validation Loss: 0.0016, Validation Accuracy: 0.6705


Epoch 37/50: 100%|██████████| 226/226 [03:02<00:00,  1.24it/s]


Epoch 37, Loss: 0.0561, Accuracy: 0.9843
Validation Loss: 0.0031, Validation Accuracy: 0.6633


Epoch 38/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 38, Loss: 0.0467, Accuracy: 0.9874
Validation Loss: 0.5411, Validation Accuracy: 0.6717


Epoch 39/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 39, Loss: 0.0240, Accuracy: 0.9932
Validation Loss: 0.0012, Validation Accuracy: 0.6772


Epoch 40/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 40, Loss: 0.0278, Accuracy: 0.9915
Validation Loss: 0.0193, Validation Accuracy: 0.6750


Epoch 41/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 41, Loss: 0.0247, Accuracy: 0.9928
Validation Loss: 0.3684, Validation Accuracy: 0.6611


Epoch 42/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 42, Loss: 0.0294, Accuracy: 0.9914
Validation Loss: 0.1834, Validation Accuracy: 0.6600


Epoch 43/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 43, Loss: 0.0505, Accuracy: 0.9847
Validation Loss: 0.0108, Validation Accuracy: 0.6761


Epoch 44/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 44, Loss: 0.0234, Accuracy: 0.9936
Validation Loss: 0.0023, Validation Accuracy: 0.6600


Epoch 45/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 45, Loss: 0.0223, Accuracy: 0.9935
Validation Loss: 0.0026, Validation Accuracy: 0.6423


Epoch 46/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 46, Loss: 0.0313, Accuracy: 0.9900
Validation Loss: 0.0011, Validation Accuracy: 0.6500


Epoch 47/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 47, Loss: 0.0257, Accuracy: 0.9922
Validation Loss: 0.0045, Validation Accuracy: 0.6816


Epoch 48/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 48, Loss: 0.0078, Accuracy: 0.9985
Validation Loss: 0.0030, Validation Accuracy: 0.6900


Epoch 49/50: 100%|██████████| 226/226 [03:01<00:00,  1.25it/s]


Epoch 49, Loss: 0.0043, Accuracy: 0.9988
Validation Loss: 0.0005, Validation Accuracy: 0.6866


Epoch 50/50: 100%|██████████| 226/226 [03:01<00:00,  1.24it/s]


Epoch 50, Loss: 0.0017, Accuracy: 0.9996
Validation Loss: 0.0005, Validation Accuracy: 0.6900


