In [1]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel
import mlflow
import mlflow.pytorch

root_folder = 's1_s2_s3_matfiles_5gest'  # Update with your folder path
num_epochs = 500

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
    
#     for part in parts:
#         if 'participant' in part:
#             participant_number = part.replace('participant', '')
#             return int(participant_number)  # Convert to integer
#     return None  # If no participant label found

def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model_5gest.pth'
last_model_path = 'last_model_5gest.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    mlflow.log_metric("train_loss", epoch_loss, step=epoch)
    mlflow.log_metric("train_accuracy", epoch_acc, step=epoch)

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    mlflow.log_metric("val_loss", val_loss, step=epoch)
    mlflow.log_metric("val_accuracy", val_acc, step=epoch)

    torch.save(model.state_dict(), last_model_path)
    mlflow.pytorch.log_model(model, "models/last_model")
    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')
        mlflow.pytorch.log_model(model, "models/best_model")

end_mlflow_experiment()        

100%|██████████| 15/15 [01:10<00:00,  4.73s/it]


Epoch 445/500, Loss: 0.2721, Accuracy: 0.9046
Validation Loss: 1.0659, Validation Accuracy: 0.7295


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 446/500, Loss: 0.2906, Accuracy: 0.8921
Validation Loss: 1.0691, Validation Accuracy: 0.7228


100%|██████████| 15/15 [01:09<00:00,  4.66s/it]


Epoch 447/500, Loss: 0.2697, Accuracy: 0.9087
Validation Loss: 0.9601, Validation Accuracy: 0.7583




New best model saved with accuracy: 0.7583


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 448/500, Loss: 0.3078, Accuracy: 0.8877
Validation Loss: 1.0821, Validation Accuracy: 0.7228


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 449/500, Loss: 0.2960, Accuracy: 0.8899
Validation Loss: 1.0338, Validation Accuracy: 0.7395


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 450/500, Loss: 0.3267, Accuracy: 0.8882
Validation Loss: 0.9626, Validation Accuracy: 0.7517


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 451/500, Loss: 0.3013, Accuracy: 0.8932
Validation Loss: 1.0483, Validation Accuracy: 0.7206


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 452/500, Loss: 0.2987, Accuracy: 0.8996
Validation Loss: 1.0119, Validation Accuracy: 0.7384


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 453/500, Loss: 0.2709, Accuracy: 0.9024
Validation Loss: 1.0422, Validation Accuracy: 0.7350


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 454/500, Loss: 0.2957, Accuracy: 0.8888
Validation Loss: 1.1866, Validation Accuracy: 0.6907


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 455/500, Loss: 0.3493, Accuracy: 0.8852
Validation Loss: 1.2772, Validation Accuracy: 0.6785


100%|██████████| 15/15 [01:09<00:00,  4.60s/it]


Epoch 456/500, Loss: 0.3161, Accuracy: 0.8818
Validation Loss: 1.1284, Validation Accuracy: 0.6973


100%|██████████| 15/15 [01:08<00:00,  4.60s/it]


Epoch 457/500, Loss: 0.2829, Accuracy: 0.8943
Validation Loss: 1.1274, Validation Accuracy: 0.7095


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 458/500, Loss: 0.3587, Accuracy: 0.8749
Validation Loss: 1.1019, Validation Accuracy: 0.7151


100%|██████████| 15/15 [01:09<00:00,  4.60s/it]


Epoch 459/500, Loss: 0.2850, Accuracy: 0.8979
Validation Loss: 1.1995, Validation Accuracy: 0.6929


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 460/500, Loss: 0.3329, Accuracy: 0.8774
Validation Loss: 1.2021, Validation Accuracy: 0.6918


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 461/500, Loss: 0.3684, Accuracy: 0.8705
Validation Loss: 1.1296, Validation Accuracy: 0.7118


100%|██████████| 15/15 [01:08<00:00,  4.58s/it]


Epoch 462/500, Loss: 0.3907, Accuracy: 0.8594
Validation Loss: 1.2328, Validation Accuracy: 0.6818


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 463/500, Loss: 0.2856, Accuracy: 0.8904
Validation Loss: 1.1815, Validation Accuracy: 0.6885


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 464/500, Loss: 0.3109, Accuracy: 0.8943
Validation Loss: 1.0553, Validation Accuracy: 0.7195


100%|██████████| 15/15 [01:09<00:00,  4.60s/it]


Epoch 465/500, Loss: 0.2869, Accuracy: 0.8949
Validation Loss: 1.1275, Validation Accuracy: 0.7228


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 466/500, Loss: 0.3157, Accuracy: 0.8854
Validation Loss: 1.1557, Validation Accuracy: 0.6973


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 467/500, Loss: 0.2992, Accuracy: 0.8910
Validation Loss: 1.0793, Validation Accuracy: 0.7217


100%|██████████| 15/15 [01:09<00:00,  4.60s/it]


Epoch 468/500, Loss: 0.2925, Accuracy: 0.8857
Validation Loss: 1.0093, Validation Accuracy: 0.7384


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 469/500, Loss: 0.3099, Accuracy: 0.8960
Validation Loss: 1.0498, Validation Accuracy: 0.7361


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 470/500, Loss: 0.3077, Accuracy: 0.8860
Validation Loss: 1.1066, Validation Accuracy: 0.7018


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 471/500, Loss: 0.2944, Accuracy: 0.8924
Validation Loss: 1.0662, Validation Accuracy: 0.7206


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 472/500, Loss: 0.2603, Accuracy: 0.9051
Validation Loss: 1.0336, Validation Accuracy: 0.7306


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 473/500, Loss: 0.2673, Accuracy: 0.9112
Validation Loss: 1.2425, Validation Accuracy: 0.6940


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 474/500, Loss: 0.2629, Accuracy: 0.9057
Validation Loss: 0.9973, Validation Accuracy: 0.7461


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 475/500, Loss: 0.2459, Accuracy: 0.9148
Validation Loss: 1.0161, Validation Accuracy: 0.7339


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 476/500, Loss: 0.2461, Accuracy: 0.9060
Validation Loss: 1.1167, Validation Accuracy: 0.7217


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 477/500, Loss: 0.2578, Accuracy: 0.9076
Validation Loss: 1.0018, Validation Accuracy: 0.7517


100%|██████████| 15/15 [01:08<00:00,  4.60s/it]


Epoch 478/500, Loss: 0.2605, Accuracy: 0.9132
Validation Loss: 1.1020, Validation Accuracy: 0.7328


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 479/500, Loss: 0.2787, Accuracy: 0.9046
Validation Loss: 1.0232, Validation Accuracy: 0.7450


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 480/500, Loss: 0.2587, Accuracy: 0.9207
Validation Loss: 1.1259, Validation Accuracy: 0.7228


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 481/500, Loss: 0.2490, Accuracy: 0.9190
Validation Loss: 1.0917, Validation Accuracy: 0.7295


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 482/500, Loss: 0.2158, Accuracy: 0.9229
Validation Loss: 1.1962, Validation Accuracy: 0.7040


100%|██████████| 15/15 [01:09<00:00,  4.60s/it]


Epoch 483/500, Loss: 0.2689, Accuracy: 0.9043
Validation Loss: 1.0689, Validation Accuracy: 0.7328


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 484/500, Loss: 0.2200, Accuracy: 0.9193
Validation Loss: 1.0527, Validation Accuracy: 0.7328


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 485/500, Loss: 0.2082, Accuracy: 0.9337
Validation Loss: 1.0237, Validation Accuracy: 0.7461


100%|██████████| 15/15 [01:08<00:00,  4.58s/it]


Epoch 486/500, Loss: 0.2029, Accuracy: 0.9290
Validation Loss: 1.0668, Validation Accuracy: 0.7384


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 487/500, Loss: 0.2012, Accuracy: 0.9334
Validation Loss: 1.0568, Validation Accuracy: 0.7428


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 488/500, Loss: 0.2067, Accuracy: 0.9301
Validation Loss: 1.0942, Validation Accuracy: 0.7361


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 489/500, Loss: 0.1954, Accuracy: 0.9262
Validation Loss: 1.0746, Validation Accuracy: 0.7417


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 490/500, Loss: 0.2029, Accuracy: 0.9348
Validation Loss: 1.0812, Validation Accuracy: 0.7328


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 491/500, Loss: 0.2074, Accuracy: 0.9417
Validation Loss: 1.0867, Validation Accuracy: 0.7295


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 492/500, Loss: 0.1984, Accuracy: 0.9295
Validation Loss: 1.1100, Validation Accuracy: 0.7306


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 493/500, Loss: 0.1934, Accuracy: 0.9337
Validation Loss: 1.0741, Validation Accuracy: 0.7406


100%|██████████| 15/15 [01:09<00:00,  4.63s/it]


Epoch 494/500, Loss: 0.1775, Accuracy: 0.9345
Validation Loss: 1.0769, Validation Accuracy: 0.7406


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 495/500, Loss: 0.1875, Accuracy: 0.9329
Validation Loss: 1.0795, Validation Accuracy: 0.7395


100%|██████████| 15/15 [01:09<00:00,  4.62s/it]


Epoch 496/500, Loss: 0.1784, Accuracy: 0.9359
Validation Loss: 1.0797, Validation Accuracy: 0.7395


100%|██████████| 15/15 [01:09<00:00,  4.64s/it]


Epoch 497/500, Loss: 0.1831, Accuracy: 0.9406
Validation Loss: 1.0879, Validation Accuracy: 0.7350


100%|██████████| 15/15 [01:08<00:00,  4.59s/it]


Epoch 498/500, Loss: 0.1859, Accuracy: 0.9348
Validation Loss: 1.0927, Validation Accuracy: 0.7350


100%|██████████| 15/15 [01:08<00:00,  4.60s/it]


Epoch 499/500, Loss: 0.1834, Accuracy: 0.9356
Validation Loss: 1.0724, Validation Accuracy: 0.7439


100%|██████████| 15/15 [01:09<00:00,  4.61s/it]


Epoch 500/500, Loss: 0.1895, Accuracy: 0.9393
Validation Loss: 1.0957, Validation Accuracy: 0.7328




In [3]:
print(best_val_acc)
sds

0.7583148558758315


NameError: name 'sds' is not defined

In [None]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel

def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5gest'  # Update with your folder path
num_epochs = 500
# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=16,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 278/500, Loss: 0.6421, Accuracy: 0.7662
Validation Loss: 1.0330, Validation Accuracy: 0.6741


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 279/500, Loss: 0.6037, Accuracy: 0.7764
Validation Loss: 1.0538, Validation Accuracy: 0.6608


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 280/500, Loss: 0.6451, Accuracy: 0.7603
Validation Loss: 1.0149, Validation Accuracy: 0.6574


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 281/500, Loss: 0.6593, Accuracy: 0.7548
Validation Loss: 0.9941, Validation Accuracy: 0.6840


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 282/500, Loss: 0.6090, Accuracy: 0.7753
Validation Loss: 1.0158, Validation Accuracy: 0.6630


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 283/500, Loss: 0.6885, Accuracy: 0.7537
Validation Loss: 1.0696, Validation Accuracy: 0.6707


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 284/500, Loss: 0.6055, Accuracy: 0.7825
Validation Loss: 0.9852, Validation Accuracy: 0.6929


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 285/500, Loss: 0.6229, Accuracy: 0.7664
Validation Loss: 0.9284, Validation Accuracy: 0.6951


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 286/500, Loss: 0.5820, Accuracy: 0.7850
Validation Loss: 1.0500, Validation Accuracy: 0.6663


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 287/500, Loss: 0.6361, Accuracy: 0.7795
Validation Loss: 0.9601, Validation Accuracy: 0.6885


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 288/500, Loss: 0.5757, Accuracy: 0.7881
Validation Loss: 1.0131, Validation Accuracy: 0.6763


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 289/500, Loss: 0.5954, Accuracy: 0.7856
Validation Loss: 1.0185, Validation Accuracy: 0.6696


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 290/500, Loss: 0.5702, Accuracy: 0.7889
Validation Loss: 0.9901, Validation Accuracy: 0.6707


100%|██████████| 15/15 [01:06<00:00,  4.40s/it]


Epoch 291/500, Loss: 0.5562, Accuracy: 0.7897
Validation Loss: 0.9422, Validation Accuracy: 0.6929


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 292/500, Loss: 0.5682, Accuracy: 0.7922
Validation Loss: 0.9785, Validation Accuracy: 0.6874


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 293/500, Loss: 0.5553, Accuracy: 0.8003
Validation Loss: 0.9732, Validation Accuracy: 0.6829


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 294/500, Loss: 0.5781, Accuracy: 0.7911
Validation Loss: 0.9512, Validation Accuracy: 0.6918


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 295/500, Loss: 0.5501, Accuracy: 0.7900
Validation Loss: 0.9597, Validation Accuracy: 0.6885


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 296/500, Loss: 0.5603, Accuracy: 0.8042
Validation Loss: 0.9578, Validation Accuracy: 0.6874


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 297/500, Loss: 0.5427, Accuracy: 0.8006
Validation Loss: 0.9602, Validation Accuracy: 0.6863


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 298/500, Loss: 0.5365, Accuracy: 0.8011
Validation Loss: 0.9609, Validation Accuracy: 0.6863


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 299/500, Loss: 0.5539, Accuracy: 0.8017
Validation Loss: 0.9648, Validation Accuracy: 0.6863


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 300/500, Loss: 0.5677, Accuracy: 0.7953
Validation Loss: 0.9685, Validation Accuracy: 0.6840


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 301/500, Loss: 0.5467, Accuracy: 0.8055
Validation Loss: 0.9610, Validation Accuracy: 0.6874


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 302/500, Loss: 0.5436, Accuracy: 0.7994
Validation Loss: 0.9713, Validation Accuracy: 0.6885


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 303/500, Loss: 0.5493, Accuracy: 0.7964
Validation Loss: 0.9654, Validation Accuracy: 0.6851


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 304/500, Loss: 0.5648, Accuracy: 0.7895
Validation Loss: 0.9514, Validation Accuracy: 0.6929


100%|██████████| 15/15 [01:05<00:00,  4.39s/it]


Epoch 305/500, Loss: 0.5612, Accuracy: 0.7920
Validation Loss: 0.9887, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:05<00:00,  4.39s/it]


Epoch 306/500, Loss: 0.5816, Accuracy: 0.7914
Validation Loss: 1.0256, Validation Accuracy: 0.6674


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 307/500, Loss: 0.5790, Accuracy: 0.7842
Validation Loss: 0.9958, Validation Accuracy: 0.6829


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 308/500, Loss: 0.5813, Accuracy: 0.7881
Validation Loss: 0.9539, Validation Accuracy: 0.6863


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 309/500, Loss: 0.5816, Accuracy: 0.7770
Validation Loss: 0.9258, Validation Accuracy: 0.6940


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 310/500, Loss: 0.5826, Accuracy: 0.7911
Validation Loss: 1.0000, Validation Accuracy: 0.6840


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 311/500, Loss: 0.5764, Accuracy: 0.7789
Validation Loss: 0.9554, Validation Accuracy: 0.7029


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 312/500, Loss: 0.5701, Accuracy: 0.7917
Validation Loss: 1.0934, Validation Accuracy: 0.6408


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 313/500, Loss: 0.5690, Accuracy: 0.7856
Validation Loss: 1.0158, Validation Accuracy: 0.6663


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 314/500, Loss: 0.5684, Accuracy: 0.7928
Validation Loss: 0.9842, Validation Accuracy: 0.7007


100%|██████████| 15/15 [01:05<00:00,  4.33s/it]


Epoch 315/500, Loss: 0.5982, Accuracy: 0.7834
Validation Loss: 1.0732, Validation Accuracy: 0.6541


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 316/500, Loss: 0.6270, Accuracy: 0.7584
Validation Loss: 1.0704, Validation Accuracy: 0.6663


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 317/500, Loss: 0.6690, Accuracy: 0.7576
Validation Loss: 1.0595, Validation Accuracy: 0.6630


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 318/500, Loss: 0.5873, Accuracy: 0.7800
Validation Loss: 0.9232, Validation Accuracy: 0.7106


100%|██████████| 15/15 [01:04<00:00,  4.32s/it]


Epoch 319/500, Loss: 0.5818, Accuracy: 0.7828
Validation Loss: 1.0632, Validation Accuracy: 0.6619


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 320/500, Loss: 0.6251, Accuracy: 0.7767
Validation Loss: 0.9722, Validation Accuracy: 0.6951


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 321/500, Loss: 0.6184, Accuracy: 0.7742
Validation Loss: 1.1004, Validation Accuracy: 0.6508


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 322/500, Loss: 0.6132, Accuracy: 0.7620
Validation Loss: 1.1867, Validation Accuracy: 0.6186


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 323/500, Loss: 0.7224, Accuracy: 0.7273
Validation Loss: 1.1493, Validation Accuracy: 0.6253


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 324/500, Loss: 0.6466, Accuracy: 0.7531
Validation Loss: 1.1600, Validation Accuracy: 0.6475


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 325/500, Loss: 0.6250, Accuracy: 0.7614
Validation Loss: 1.1456, Validation Accuracy: 0.6397


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 326/500, Loss: 0.6350, Accuracy: 0.7559
Validation Loss: 1.2452, Validation Accuracy: 0.6253


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 327/500, Loss: 0.6594, Accuracy: 0.7612
Validation Loss: 1.1485, Validation Accuracy: 0.6330


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 328/500, Loss: 0.6116, Accuracy: 0.7645
Validation Loss: 1.2825, Validation Accuracy: 0.5920


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 329/500, Loss: 0.6929, Accuracy: 0.7393
Validation Loss: 1.0733, Validation Accuracy: 0.6585


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 330/500, Loss: 0.6075, Accuracy: 0.7798
Validation Loss: 1.0931, Validation Accuracy: 0.6541


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 331/500, Loss: 0.6957, Accuracy: 0.7462
Validation Loss: 1.2646, Validation Accuracy: 0.6109


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 332/500, Loss: 0.6104, Accuracy: 0.7717
Validation Loss: 0.9995, Validation Accuracy: 0.6818


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 333/500, Loss: 0.5957, Accuracy: 0.7739
Validation Loss: 1.2191, Validation Accuracy: 0.6208


100%|██████████| 15/15 [01:04<00:00,  4.30s/it]


Epoch 334/500, Loss: 0.6813, Accuracy: 0.7393
Validation Loss: 0.9754, Validation Accuracy: 0.6752


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 335/500, Loss: 0.6101, Accuracy: 0.7667
Validation Loss: 1.0846, Validation Accuracy: 0.6519


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 336/500, Loss: 0.6234, Accuracy: 0.7756
Validation Loss: 1.0542, Validation Accuracy: 0.6696


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 337/500, Loss: 0.6142, Accuracy: 0.7767
Validation Loss: 1.1417, Validation Accuracy: 0.6353


100%|██████████| 15/15 [01:05<00:00,  4.33s/it]


Epoch 338/500, Loss: 0.6108, Accuracy: 0.7642
Validation Loss: 1.0438, Validation Accuracy: 0.6707


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 339/500, Loss: 0.5628, Accuracy: 0.7989
Validation Loss: 1.0142, Validation Accuracy: 0.6796


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 340/500, Loss: 0.5760, Accuracy: 0.7892
Validation Loss: 0.9477, Validation Accuracy: 0.6984


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 341/500, Loss: 0.5437, Accuracy: 0.7983
Validation Loss: 1.0614, Validation Accuracy: 0.6563


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 342/500, Loss: 0.5750, Accuracy: 0.7809
Validation Loss: 1.0472, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 343/500, Loss: 0.5560, Accuracy: 0.7961
Validation Loss: 1.0215, Validation Accuracy: 0.6874


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 344/500, Loss: 0.5689, Accuracy: 0.7859
Validation Loss: 0.9539, Validation Accuracy: 0.7051


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 345/500, Loss: 0.5354, Accuracy: 0.8050
Validation Loss: 0.9558, Validation Accuracy: 0.6929


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 346/500, Loss: 0.5518, Accuracy: 0.7953
Validation Loss: 1.0490, Validation Accuracy: 0.6785


100%|██████████| 15/15 [01:04<00:00,  4.32s/it]


Epoch 347/500, Loss: 0.6113, Accuracy: 0.7845
Validation Loss: 1.0031, Validation Accuracy: 0.6829


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 348/500, Loss: 0.5355, Accuracy: 0.8022
Validation Loss: 1.1117, Validation Accuracy: 0.6608


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 349/500, Loss: 0.5599, Accuracy: 0.7856
Validation Loss: 1.0184, Validation Accuracy: 0.6707


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 350/500, Loss: 0.5348, Accuracy: 0.8044
Validation Loss: 1.0083, Validation Accuracy: 0.7007


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 351/500, Loss: 0.5140, Accuracy: 0.8125
Validation Loss: 0.9718, Validation Accuracy: 0.7029


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 352/500, Loss: 0.5073, Accuracy: 0.8222
Validation Loss: 1.0777, Validation Accuracy: 0.6707


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 353/500, Loss: 0.5265, Accuracy: 0.8133
Validation Loss: 1.0583, Validation Accuracy: 0.6829


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 354/500, Loss: 0.4995, Accuracy: 0.8219
Validation Loss: 1.0423, Validation Accuracy: 0.6851


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 355/500, Loss: 0.4655, Accuracy: 0.8275
Validation Loss: 1.0483, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:05<00:00,  4.39s/it]


Epoch 356/500, Loss: 0.4865, Accuracy: 0.8158
Validation Loss: 1.0049, Validation Accuracy: 0.6962


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 357/500, Loss: 0.4798, Accuracy: 0.8286
Validation Loss: 0.9860, Validation Accuracy: 0.7062


100%|██████████| 15/15 [01:05<00:00,  4.39s/it]


Epoch 358/500, Loss: 0.4652, Accuracy: 0.8330
Validation Loss: 1.0168, Validation Accuracy: 0.7007


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 359/500, Loss: 0.4870, Accuracy: 0.8277
Validation Loss: 1.0061, Validation Accuracy: 0.6984


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 360/500, Loss: 0.4751, Accuracy: 0.8325
Validation Loss: 1.0219, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 361/500, Loss: 0.4770, Accuracy: 0.8297
Validation Loss: 1.0109, Validation Accuracy: 0.6940


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 362/500, Loss: 0.4709, Accuracy: 0.8355
Validation Loss: 1.0042, Validation Accuracy: 0.6962


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 363/500, Loss: 0.4666, Accuracy: 0.8350
Validation Loss: 1.0023, Validation Accuracy: 0.6973


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 364/500, Loss: 0.4621, Accuracy: 0.8300
Validation Loss: 1.0026, Validation Accuracy: 0.6973


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 365/500, Loss: 0.4820, Accuracy: 0.8294
Validation Loss: 1.0053, Validation Accuracy: 0.6962


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 366/500, Loss: 0.4677, Accuracy: 0.8336
Validation Loss: 1.0125, Validation Accuracy: 0.6951


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 367/500, Loss: 0.4674, Accuracy: 0.8388
Validation Loss: 0.9897, Validation Accuracy: 0.7007


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 368/500, Loss: 0.4670, Accuracy: 0.8358
Validation Loss: 0.9998, Validation Accuracy: 0.6962


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 369/500, Loss: 0.4966, Accuracy: 0.8219
Validation Loss: 1.0442, Validation Accuracy: 0.6807


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 370/500, Loss: 0.4748, Accuracy: 0.8305
Validation Loss: 1.0147, Validation Accuracy: 0.6885


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 371/500, Loss: 0.4801, Accuracy: 0.8311
Validation Loss: 1.0069, Validation Accuracy: 0.6840


100%|██████████| 15/15 [01:04<00:00,  4.32s/it]


Epoch 372/500, Loss: 0.4644, Accuracy: 0.8291
Validation Loss: 1.0278, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 373/500, Loss: 0.4606, Accuracy: 0.8269
Validation Loss: 0.9764, Validation Accuracy: 0.7162


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 374/500, Loss: 0.4884, Accuracy: 0.8236
Validation Loss: 0.9829, Validation Accuracy: 0.7162


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 375/500, Loss: 0.4793, Accuracy: 0.8291
Validation Loss: 1.1167, Validation Accuracy: 0.6796


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 376/500, Loss: 0.5063, Accuracy: 0.8161
Validation Loss: 0.9816, Validation Accuracy: 0.7084


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 377/500, Loss: 0.5278, Accuracy: 0.8147
Validation Loss: 1.0747, Validation Accuracy: 0.6863


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 378/500, Loss: 0.5041, Accuracy: 0.8266
Validation Loss: 0.9643, Validation Accuracy: 0.7173


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 379/500, Loss: 0.4864, Accuracy: 0.8219
Validation Loss: 0.9487, Validation Accuracy: 0.7217


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 380/500, Loss: 0.4774, Accuracy: 0.8247
Validation Loss: 1.1115, Validation Accuracy: 0.6718


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 381/500, Loss: 0.5239, Accuracy: 0.8117
Validation Loss: 0.9470, Validation Accuracy: 0.6984


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 382/500, Loss: 0.5276, Accuracy: 0.8067
Validation Loss: 1.1443, Validation Accuracy: 0.6519


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 383/500, Loss: 0.4880, Accuracy: 0.8164
Validation Loss: 1.0264, Validation Accuracy: 0.6829


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 384/500, Loss: 0.5233, Accuracy: 0.8128
Validation Loss: 1.0827, Validation Accuracy: 0.6663


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 385/500, Loss: 0.5139, Accuracy: 0.8139
Validation Loss: 1.0124, Validation Accuracy: 0.6907


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 386/500, Loss: 0.5276, Accuracy: 0.8153
Validation Loss: 1.0090, Validation Accuracy: 0.6918


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 387/500, Loss: 0.5330, Accuracy: 0.8086
Validation Loss: 0.9660, Validation Accuracy: 0.7173


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 388/500, Loss: 0.5322, Accuracy: 0.8108
Validation Loss: 1.0480, Validation Accuracy: 0.6729


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 389/500, Loss: 0.5306, Accuracy: 0.8028
Validation Loss: 0.9906, Validation Accuracy: 0.6962


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 390/500, Loss: 0.5255, Accuracy: 0.8075
Validation Loss: 1.1327, Validation Accuracy: 0.6741


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 391/500, Loss: 0.5346, Accuracy: 0.8061
Validation Loss: 0.9248, Validation Accuracy: 0.7118


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 392/500, Loss: 0.5351, Accuracy: 0.8086
Validation Loss: 1.0641, Validation Accuracy: 0.6785


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 393/500, Loss: 0.5494, Accuracy: 0.8000
Validation Loss: 1.0653, Validation Accuracy: 0.6818


100%|██████████| 15/15 [01:05<00:00,  4.35s/it]


Epoch 394/500, Loss: 0.5421, Accuracy: 0.8083
Validation Loss: 1.1518, Validation Accuracy: 0.6364


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 395/500, Loss: 0.5281, Accuracy: 0.7992
Validation Loss: 1.1777, Validation Accuracy: 0.6630


100%|██████████| 15/15 [01:05<00:00,  4.36s/it]


Epoch 396/500, Loss: 0.6077, Accuracy: 0.7623
Validation Loss: 0.9762, Validation Accuracy: 0.7062


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 397/500, Loss: 0.5568, Accuracy: 0.8075
Validation Loss: 0.9974, Validation Accuracy: 0.7040


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 398/500, Loss: 0.5542, Accuracy: 0.7947
Validation Loss: 1.0370, Validation Accuracy: 0.6763


100%|██████████| 15/15 [01:04<00:00,  4.33s/it]


Epoch 399/500, Loss: 0.5844, Accuracy: 0.7939
Validation Loss: 0.9407, Validation Accuracy: 0.7129


100%|██████████| 15/15 [01:05<00:00,  4.38s/it]


Epoch 400/500, Loss: 0.5015, Accuracy: 0.8031
Validation Loss: 1.0167, Validation Accuracy: 0.6896


100%|██████████| 15/15 [01:06<00:00,  4.40s/it]


Epoch 401/500, Loss: 0.5196, Accuracy: 0.8064
Validation Loss: 1.0167, Validation Accuracy: 0.6996


100%|██████████| 15/15 [01:04<00:00,  4.31s/it]


Epoch 402/500, Loss: 0.4875, Accuracy: 0.8277
Validation Loss: 0.9658, Validation Accuracy: 0.6973


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 403/500, Loss: 0.5000, Accuracy: 0.8164
Validation Loss: 1.0324, Validation Accuracy: 0.6874


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 404/500, Loss: 0.5171, Accuracy: 0.8092
Validation Loss: 1.2213, Validation Accuracy: 0.6552


100%|██████████| 15/15 [01:05<00:00,  4.39s/it]


Epoch 405/500, Loss: 0.5101, Accuracy: 0.8053
Validation Loss: 1.1166, Validation Accuracy: 0.6452


100%|██████████| 15/15 [01:05<00:00,  4.37s/it]


Epoch 406/500, Loss: 0.5039, Accuracy: 0.8117
Validation Loss: 1.0961, Validation Accuracy: 0.6696


100%|██████████| 15/15 [01:05<00:00,  4.34s/it]


Epoch 407/500, Loss: 0.4931, Accuracy: 0.8211
Validation Loss: 0.9566, Validation Accuracy: 0.7029


 67%|██████▋   | 10/15 [00:46<00:23,  4.67s/it]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "best_model_5gest")

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob

# Function to extract the participant ID from the filename
# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'participant' in part:
#             participant_id = part.replace('participant', '')
#             return int(participant_id)  # Convert to integer
#     return None  # If no participant ID found

# # Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.3, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTForImageClassification, ViTConfig
from tqdm import tqdm

# Define the number of classes (in this case, the number of participants)
num_classes = len(np.unique(labels))

config = ViTConfig(
    num_hidden_layers=6,  # Reduced number of layers
    hidden_size=512,      # Smaller hidden size
    num_attention_heads=4,  # Fewer attention heads
    patch_size=16,        # Smaller patch size
    intermediate_size=1024,  # Reduced intermediate layer size
    image_size=224,       # Input image size
    num_labels=num_classes,         # Number of output classes
    hidden_dropout_prob=0.3,  # Adding dropout to regularize
    attention_probs_dropout_prob=0.3
)

# Initialize ViT model from scratch with this config
model = ViTForImageClassification(config)

# Define optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 30  # Set the number of epochs

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader):
        # Move data to the same device as the model
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images).logits
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Evaluation after every epoch (optional)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy after epoch {epoch+1}: {accuracy:.2f}%")

# Save the trained model
torch.save(model.state_dict(), 'vit_emg_model.pth')


100%|██████████| 99/99 [01:39<00:00,  1.01s/it]


Epoch [1/30], Loss: 2.3801
Test Accuracy after epoch 1: 9.61%


100%|██████████| 99/99 [01:39<00:00,  1.00s/it]


Epoch [2/30], Loss: 2.3265
Test Accuracy after epoch 2: 13.09%


100%|██████████| 99/99 [01:39<00:00,  1.00s/it]


Epoch [3/30], Loss: 2.3073
Test Accuracy after epoch 3: 14.05%


100%|██████████| 99/99 [01:39<00:00,  1.00s/it]


Epoch [4/30], Loss: 2.2720


KeyboardInterrupt: 

In [None]:
print(f"Model output classes: {model.config.num_labels}")
print(f"Unique labels in dataset: {np.unique(train_labels)}")

Model output classes: 10
Unique labels in dataset: [0 1 2 3 4 5 6 7 8 9]


In [None]:
# --- Training loop for 50 epochs ---
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 1, Loss: 2.2416, Accuracy: 0.1794
Test accuracy: 0.2553


Epoch 2/50: 100%|██████████| 90/90 [01:12<00:00,  1.25it/s]


Epoch 2, Loss: 1.8933, Accuracy: 0.3556
Test accuracy: 0.3955


Epoch 3/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 3, Loss: 1.6035, Accuracy: 0.4693
Test accuracy: 0.4320


Epoch 4/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 4, Loss: 1.3069, Accuracy: 0.5781
Test accuracy: 0.4600


Epoch 5/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 5, Loss: 1.0768, Accuracy: 0.6511
Test accuracy: 0.5470


Epoch 6/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 6, Loss: 0.6955, Accuracy: 0.7943
Test accuracy: 0.5330


Epoch 7/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 7, Loss: 0.4704, Accuracy: 0.8635
Test accuracy: 0.5259


Epoch 8/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 8, Loss: 0.3046, Accuracy: 0.9210
Test accuracy: 0.5498


Epoch 9/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 9, Loss: 0.2185, Accuracy: 0.9438
Test accuracy: 0.5512


Epoch 10/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 10, Loss: 0.1618, Accuracy: 0.9582
Test accuracy: 0.4867


Epoch 11/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 11, Loss: 0.1546, Accuracy: 0.9596
Test accuracy: 0.5456


Epoch 12/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 12, Loss: 0.1435, Accuracy: 0.9603
Test accuracy: 0.5652


Epoch 13/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 13, Loss: 0.0731, Accuracy: 0.9849
Test accuracy: 0.4923


Epoch 14/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 14, Loss: 0.1267, Accuracy: 0.9645
Test accuracy: 0.4979


Epoch 15/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 15, Loss: 0.0674, Accuracy: 0.9835
Test accuracy: 0.5302


Epoch 16/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 16, Loss: 0.0767, Accuracy: 0.9796
Test accuracy: 0.5666


Epoch 17/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 17, Loss: 0.0595, Accuracy: 0.9849
Test accuracy: 0.5891


Epoch 18/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 18, Loss: 0.0322, Accuracy: 0.9940
Test accuracy: 0.5428


Epoch 19/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 19, Loss: 0.0548, Accuracy: 0.9839
Test accuracy: 0.5680


Epoch 20/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 20, Loss: 0.0505, Accuracy: 0.9856
Test accuracy: 0.5666


Epoch 21/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 21, Loss: 0.0864, Accuracy: 0.9744
Test accuracy: 0.5722


Epoch 22/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 22, Loss: 0.1016, Accuracy: 0.9698
Test accuracy: 0.4769


Epoch 23/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 23, Loss: 0.0759, Accuracy: 0.9772
Test accuracy: 0.5330


Epoch 24/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 24, Loss: 0.0339, Accuracy: 0.9933
Test accuracy: 0.5498


Epoch 25/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 25, Loss: 0.1114, Accuracy: 0.9705
Test accuracy: 0.5652


Epoch 26/50: 100%|██████████| 90/90 [01:11<00:00,  1.27it/s]


Epoch 26, Loss: 0.0147, Accuracy: 0.9972
Test accuracy: 0.5863


Epoch 27/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 27, Loss: 0.0083, Accuracy: 0.9993
Test accuracy: 0.5806


Epoch 28/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 28, Loss: 0.0066, Accuracy: 0.9993
Test accuracy: 0.5849


Epoch 29/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 29, Loss: 0.0051, Accuracy: 0.9996
Test accuracy: 0.5849


Epoch 30/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 30, Loss: 0.0035, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 31/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 31, Loss: 0.0031, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 32/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 32, Loss: 0.0028, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 33/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 33, Loss: 0.0026, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 34/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 34, Loss: 0.0024, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 35/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 35, Loss: 0.0022, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 36/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 36, Loss: 0.0021, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 37/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 37, Loss: 0.0019, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 38/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 38, Loss: 0.0018, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 39/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 39, Loss: 0.0017, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 40/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 40, Loss: 0.0016, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 41/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 41, Loss: 0.0015, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 42/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 42, Loss: 0.0014, Accuracy: 1.0000
Test accuracy: 0.5933


Epoch 43/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 43, Loss: 0.0013, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 44/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 44, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 45/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 45, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 46/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 46, Loss: 0.0011, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 47/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 47, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 48/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 48, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 49/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 49, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 50/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 50, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5975


In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 45/45 [00:36<00:00,  1.25it/s]


Epoch 1, Loss: 1.4741, Accuracy: 0.3455
Test accuracy: 0.4566


Epoch 2/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 2, Loss: 1.1747, Accuracy: 0.5407
Test accuracy: 0.5770


Epoch 3/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 3, Loss: 0.8848, Accuracy: 0.6713
Test accuracy: 0.6751


Epoch 4/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 4, Loss: 0.6088, Accuracy: 0.8062
Test accuracy: 0.7171


Epoch 5/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 5, Loss: 0.4332, Accuracy: 0.8581
Test accuracy: 0.6303


Epoch 6/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 6, Loss: 0.2952, Accuracy: 0.9094
Test accuracy: 0.7423


Epoch 7/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 7, Loss: 0.1660, Accuracy: 0.9565
Test accuracy: 0.7171


Epoch 8/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 8, Loss: 0.1073, Accuracy: 0.9719
Test accuracy: 0.7003


Epoch 9/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 9, Loss: 0.1327, Accuracy: 0.9621
Test accuracy: 0.7283


Epoch 10/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 10, Loss: 0.0993, Accuracy: 0.9747
Test accuracy: 0.7563


Epoch 11/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 11, Loss: 0.0724, Accuracy: 0.9831
Test accuracy: 0.6947


Epoch 12/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 12, Loss: 0.0463, Accuracy: 0.9902
Test accuracy: 0.7115


Epoch 13/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 13, Loss: 0.0474, Accuracy: 0.9881
Test accuracy: 0.7283


Epoch 14/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 14, Loss: 0.0936, Accuracy: 0.9747
Test accuracy: 0.7115


Epoch 15/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 15, Loss: 0.1236, Accuracy: 0.9621
Test accuracy: 0.6499


Epoch 16/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 16, Loss: 0.1496, Accuracy: 0.9537
Test accuracy: 0.6779


Epoch 17/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 17, Loss: 0.0414, Accuracy: 0.9909
Test accuracy: 0.7619


Epoch 18/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 18, Loss: 0.0188, Accuracy: 0.9965
Test accuracy: 0.7759


Epoch 19/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 19, Loss: 0.0172, Accuracy: 0.9986
Test accuracy: 0.7731


Epoch 20/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 20, Loss: 0.0108, Accuracy: 0.9993
Test accuracy: 0.7675


In [None]:
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data    

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 113/113 [03:05<00:00,  1.64s/it]


Epoch 1, Loss: 1.0156, Accuracy: 0.5933
Test accuracy: 0.6951


Epoch 2/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 2, Loss: 0.6677, Accuracy: 0.7506
Test accuracy: 0.7816


Epoch 3/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 3, Loss: 0.4857, Accuracy: 0.8227
Test accuracy: 0.7627


Epoch 4/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 4, Loss: 0.3735, Accuracy: 0.8713
Test accuracy: 0.8104


Epoch 5/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 5, Loss: 0.2388, Accuracy: 0.9223
Test accuracy: 0.7938


Epoch 6/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 6, Loss: 0.1578, Accuracy: 0.9503
Test accuracy: 0.7993


Epoch 7/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 7, Loss: 0.1318, Accuracy: 0.9603
Test accuracy: 0.7860


Epoch 8/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 8, Loss: 0.0938, Accuracy: 0.9734
Test accuracy: 0.8226


Epoch 9/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 9, Loss: 0.0645, Accuracy: 0.9803
Test accuracy: 0.7672


Epoch 10/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 10, Loss: 0.0746, Accuracy: 0.9761
Test accuracy: 0.7960


Epoch 11/20: 100%|██████████| 113/113 [01:29<00:00,  1.26it/s]


Epoch 11, Loss: 0.0330, Accuracy: 0.9917
Test accuracy: 0.8149


Epoch 12/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 12, Loss: 0.0445, Accuracy: 0.9889
Test accuracy: 0.8104


Epoch 13/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 13, Loss: 0.0661, Accuracy: 0.9792
Test accuracy: 0.8115


Epoch 14/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 14, Loss: 0.0471, Accuracy: 0.9850
Test accuracy: 0.8259


Epoch 15/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 15, Loss: 0.0220, Accuracy: 0.9945
Test accuracy: 0.8082


Epoch 16/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 16, Loss: 0.0149, Accuracy: 0.9969
Test accuracy: 0.8193


Epoch 17/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 17, Loss: 0.0469, Accuracy: 0.9859
Test accuracy: 0.7971


Epoch 18/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 18, Loss: 0.0401, Accuracy: 0.9870
Test accuracy: 0.7716


Epoch 19/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 19, Loss: 0.0242, Accuracy: 0.9931
Test accuracy: 0.8049


Epoch 20/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 20, Loss: 0.0091, Accuracy: 0.9978
Test accuracy: 0.8060


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 226/226 [06:29<00:00,  1.72s/it]


Epoch 1, Loss: 1.6487, Accuracy: 0.4242
Test accuracy: 0.5774


Epoch 2/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 2, Loss: 1.1021, Accuracy: 0.6264
Test accuracy: 0.6051


Epoch 3/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 3, Loss: 0.8969, Accuracy: 0.6896
Test accuracy: 0.6794


Epoch 4/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 4, Loss: 0.6698, Accuracy: 0.7712
Test accuracy: 0.6589


Epoch 5/20: 100%|██████████| 226/226 [02:58<00:00,  1.26it/s]


Epoch 5, Loss: 0.5011, Accuracy: 0.8303
Test accuracy: 0.6744


Epoch 6/20: 100%|██████████| 226/226 [02:58<00:00,  1.27it/s]


Epoch 6, Loss: 0.3302, Accuracy: 0.8935
Test accuracy: 0.6384


Epoch 7/20:  92%|█████████▏| 208/226 [02:45<00:14,  1.26it/s]


KeyboardInterrupt: 

In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import glob

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

In [None]:

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_20part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
all_labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)
print(f"Shape of X_train: {X_train.shape}")

Shape of X_train: (5701, 327680)


In [None]:
pca = PCA(n_components = 2000)
# Step 2: Fit PCA on the dataset
pca.fit(X_train)
# Step 3: Calculate the cumulative variance explained by the components
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
# Step 4: Find the number of components that explain at least 95% variance
n_components_95 = np.argmax(cumulative_variance >= 0.95) + 1

print(f"Number of components to retain 95% of variance: {n_components_95}")

# Print explained variance for each component
explained_variance = pca.explained_variance_ratio_
print(f"Explained variance by each component: {explained_variance}")

# Calculate cumulative explained variance
cumulative_variance = np.cumsum(explained_variance)
print(f"Cumulative explained variance: {cumulative_variance}")

Number of components to retain 95% of variance: 1
Explained variance by each component: [0.02734777 0.02461403 0.00818816 0.0069905  0.00584958 0.00445974
 0.00403911 0.00369412 0.00349787 0.00316976 0.00300401 0.00289899
 0.00279481 0.00275903 0.00264407 0.00261821 0.00259447 0.00257349
 0.00248961 0.00239943 0.0023721  0.00232195 0.00227801 0.00223757
 0.0022159  0.00220089 0.00212229 0.00206722 0.00201354 0.00198255
 0.00194987 0.00192753 0.00191069 0.00189586 0.00187392 0.00184805
 0.00183598 0.00181051 0.00178506 0.00175982 0.0017539  0.00172137
 0.00170902 0.00169332 0.00167088 0.00165842 0.00165174 0.0016362
 0.00163023 0.00160764 0.00159447 0.00155964 0.00155086 0.00152713
 0.00146956 0.00145834 0.00141764 0.00141322 0.00140813 0.00138163
 0.00136974 0.00136623 0.00134497 0.00133781 0.00131482 0.00130781
 0.00129883 0.00128668 0.00127945 0.00126588 0.00125526 0.00123351
 0.00122743 0.00121431 0.00120742 0.00119956 0.00119493 0.00118896
 0.0011854  0.00117166 0.00116383 0.001154

In [None]:


# Initialize PCA for dimensionality reduction
#pca = PCA(n_components=4)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.3065


In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

# Specify session folders
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']

# Target length for EMG signals (adjust based on your data)
target_length = 10240

# Collect all .mat files and their corresponding labels
all_mat_files = []
all_labels = []

for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=100)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.4065


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, GridSearchCV
from tqdm import tqdm

# Function to extract label from filename
def extract_label_from_filename(filename):
    basename = os.path.basename(filename)
    return int(basename.split('_')[2].replace('gesture', ''))

# Function to process a batch of files
def process_batch(batch_files, target_length):
    data_emg_list = []
    labels_list = []
    for mat_file in tqdm(batch_files, desc="Processing Batch"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(mat_file)
            data_emg = mat_data['data_emg']
            label = extract_label_from_filename(mat_file)

            # Handle NaNs and Infinities
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            
            # Pad or truncate the data to the target length
            if data.shape[0] < target_length:
                padded_data = np.pad(data, ((0, target_length - data.shape[0]), (0, 0)), 'constant', constant_values=0)
            else:
                padded_data = data[:target_length]
            
            # Append to lists
            data_emg_list.append(padded_data)
            labels_list.append(label)
    
    # Convert to numpy arrays
    X_batch = np.array(data_emg_list)
    y_batch = np.array(labels_list)
    
    # Reshape the data
    X_batch_flat = X_batch.reshape(X_batch.shape[0], -1)
    
    return X_batch_flat, y_batch

# Path to the data
target_length = 10240  # Desired length for each EMG signal
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']
all_mat_files = []
all_labels = []

# Collect all .mat files and their corresponding labels
for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=4)  # Reduce to 100 components, adjust as needed

# Initialize an SVM with stochastic gradient descent (SGD)
svm_model = SGDClassifier(loss='hinge', penalty='l2', max_iter=1000)

# Process training data in batches
batch_size = 50

for i in range(0, len(train_files), batch_size):
    batch_files = train_files[i:i+batch_size]
    X_batch, y_batch = process_batch(batch_files, target_length)
    
    # Apply PCA
    X_batch_pca = pca.fit_transform(X_batch)
    
    # Incrementally fit the SVM model, passing the complete set of classes every time
    svm_model.partial_fit(X_batch_pca, y_batch, classes=np.unique(train_labels))

# Process and test the model on test data
X_test, y_test = process_batch(test_files, target_length)
X_test_pca = pca.transform(X_test)

# Inference on test data
y_pred = svm_model.predict(X_test_pca)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# --- Hyperparameter Tuning with GridSearchCV ---

# Define the parameter grid for tuning
param_grid = {
    'alpha': [0.0001, 0.001, 0.01],
    'loss': ['hinge', 'log'],
    'penalty': ['l2', 'l1', 'elasticnet'],
    'max_iter': [1000, 2000, 3000]
}

# Initialize GridSearchCV with SGDClassifier
grid_search = GridSearchCV(SGDClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)

# Fit grid search
grid_search.fit(X_test_pca, y_test)

# Print best parameters and score
print(f"Best parameters found: {grid_search.best_params_}")
print(f"Best cross-validation accuracy: {grid_search.best_score_ * 100:.2f}%")


Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.05it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.17it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.18it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.02it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.53it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.70it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.16it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.64it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.37it/s]
Processing

Test Accuracy: 5.94%
Fitting 5 folds for each of 54 candidates, totalling 270 fits
Best parameters found: {'alpha': 0.01, 'loss': 'hinge', 'max_iter': 3000, 'penalty': 'l2'}
Best cross-validation accuracy: 11.03%


135 fits failed out of a total of 270.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
7 fits failed with the following error:
Traceback (most recent call last):
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 888, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 1466, in wrapper
    estimator._validate_params()
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 666, in _validate_params
    validate_parameter_constraints(
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packag

In [None]:
import os

def extract_label_from_filename(filepath):
    # Get the base name of the file (e.g., 'session1_participant1_gesture10_trial1.dat')
    filename = os.path.basename(filepath)
    
    # Split the filename by underscores
    parts = filename.split('_')
    
    # Find the part that contains 'gesture' and extract the gesture number
    for part in parts:
        if 'gesture' in part:
            # Assuming the format is 'gesture<number>', we can remove 'gesture' and get the number
            gesture_number = part.replace('gesture', '')
            return int(gesture_number)  # Convert to integer if needed
    
    return None  # If no gesture label found


In [None]:
import os
import scipy.io as sio
import numpy as np
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm  # Import tqdm for progress bars

# Path where the .mat files are stored
#data_folder = 'Session1'

# Initialize lists to hold data
data_emg_list = []
labels_list = []  # Assuming you have labels corresponding to the gestures

# Loop through each .mat file in the folder with a progress bar
for session_num in range(1, 3):  # Looping through sessions (1 to 3)
    session_folder = os.path.join(f'Session{session_num}_matfiles')
    
    # Get all .mat files in the current session folder
    mat_files = sorted(os.listdir(session_folder))
    
    # Use tqdm to show progress
    for mat_file in tqdm(mat_files, desc=f"Processing Session {session_num}"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(os.path.join(session_folder, mat_file))
            
            # Extract the EMG data (you may need to adjust variable names if different)
            data_emg = mat_data['data_emg']  # This should match the name in your .mat files
            label = extract_label_from_filename(mat_file)  # You would need to define this function
            # Check for NaNs and infinities in the loaded data
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            # Append the data and label to lists
            data_emg_list.append(data)
            labels_list.append(label)


Processing Session 1: 100%|██████████| 5107/5107 [02:12<00:00, 38.46it/s]
Processing Session 2: 100%|██████████| 5109/5109 [02:24<00:00, 35.36it/s]


In [None]:
target_length = 10240  # The desired length for all samples
padded_data_emg_list = [
    np.pad(emg_data, ((0, target_length - emg_data.shape[0]), (0, 0)), 'constant', constant_values=0)
    if emg_data.shape[0] < target_length else emg_data
    for emg_data in data_emg_list
]

In [None]:

# Convert lists to numpy arrays
X = np.array(padded_data_emg_list)  # Data (e.g., EMG signals)
y = np.array(labels_list)  # Labels (e.g., gesture IDs)

# Flatten the data if needed (depends on the structure of your data)
X_flat = X.reshape(X.shape[0], -1)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_flat, y, test_size=0.2, random_state=42)

# Apply PCA
print("Applying PCA...")
pca = PCA(n_components=50)  # Adjust the number of components as needed
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# Train an SVM model with progress tracking
print("Training SVM...")
svm_model = SVC(kernel='linear')

# You can wrap the fitting process with tqdm to monitor the training as well
svm_model.fit(X_train_pca, y_train)

# Make predictions and evaluate the model
y_pred = svm_model.predict(X_test_pca)
accuracy = accuracy_score(y_test, y_pred)
print(f'PCA + SVM Accuracy: {accuracy:.4f}')

# Optionally: Use X_train_pca as input to a Vision Transformer model
# You can now use the X_train_pca for further processing or to feed into a Vision Transformer model.


Applying PCA...


In [None]:
# nan_count = np.isnan(data_emg).sum()
# inf_count = np.isinf(data_emg).sum()
# print(f"NaNs found: {nan_count}, Infinities found: {inf_count}")