In [1]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel
import mlflow
import mlflow.pytorch

root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path
num_epochs = 500

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
    
#     for part in parts:
#         if 'participant' in part:
#             participant_number = part.replace('participant', '')
#             return int(participant_number)  # Convert to integer
#     return None  # If no participant label found

def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model_10gest.pth'
last_model_path = 'last_model_10gest.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    mlflow.log_metric("train_loss", epoch_loss, step=epoch)
    mlflow.log_metric("train_accuracy", epoch_acc, step=epoch)

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    mlflow.log_metric("val_loss", val_loss, step=epoch)
    mlflow.log_metric("val_accuracy", val_acc, step=epoch)

    torch.save(model.state_dict(), last_model_path)
    mlflow.pytorch.log_model(model, "models/last_model")
    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')
        mlflow.pytorch.log_model(model, "models/best_model")

end_mlflow_experiment()        

Validation Loss: 1.6370, Validation Accuracy: 0.5702


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 452/500, Loss: 0.6301, Accuracy: 0.7702
Validation Loss: 1.5770, Validation Accuracy: 0.5774


100%|██████████| 29/29 [02:20<00:00,  4.84s/it]


Epoch 453/500, Loss: 0.5942, Accuracy: 0.7827
Validation Loss: 1.6848, Validation Accuracy: 0.5724


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 454/500, Loss: 0.5793, Accuracy: 0.7903
Validation Loss: 1.5963, Validation Accuracy: 0.5763


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 455/500, Loss: 0.5888, Accuracy: 0.7851
Validation Loss: 1.4837, Validation Accuracy: 0.6029


100%|██████████| 29/29 [02:20<00:00,  4.83s/it]


Epoch 456/500, Loss: 0.5691, Accuracy: 0.7923
Validation Loss: 1.6231, Validation Accuracy: 0.5724


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 457/500, Loss: 0.5600, Accuracy: 0.7970
Validation Loss: 1.5940, Validation Accuracy: 0.5818


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 458/500, Loss: 0.5571, Accuracy: 0.8020
Validation Loss: 1.5625, Validation Accuracy: 0.5923


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 459/500, Loss: 0.5372, Accuracy: 0.8027
Validation Loss: 1.5738, Validation Accuracy: 0.5879


100%|██████████| 29/29 [02:20<00:00,  4.84s/it]


Epoch 460/500, Loss: 0.5435, Accuracy: 0.7993
Validation Loss: 1.5707, Validation Accuracy: 0.5890


100%|██████████| 29/29 [02:20<00:00,  4.84s/it]


Epoch 461/500, Loss: 0.5326, Accuracy: 0.8068
Validation Loss: 1.5738, Validation Accuracy: 0.5874


100%|██████████| 29/29 [02:20<00:00,  4.86s/it]


Epoch 462/500, Loss: 0.5363, Accuracy: 0.8047
Validation Loss: 1.5696, Validation Accuracy: 0.5868


100%|██████████| 29/29 [02:19<00:00,  4.83s/it]


Epoch 463/500, Loss: 0.5389, Accuracy: 0.8035
Validation Loss: 1.5649, Validation Accuracy: 0.5840


100%|██████████| 29/29 [02:20<00:00,  4.84s/it]


Epoch 464/500, Loss: 0.5333, Accuracy: 0.8047
Validation Loss: 1.5557, Validation Accuracy: 0.5829


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 465/500, Loss: 0.5495, Accuracy: 0.8034
Validation Loss: 1.5791, Validation Accuracy: 0.5885


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 466/500, Loss: 0.5484, Accuracy: 0.7967
Validation Loss: 1.5420, Validation Accuracy: 0.5962


100%|██████████| 29/29 [02:20<00:00,  4.83s/it]


Epoch 467/500, Loss: 0.5677, Accuracy: 0.7960
Validation Loss: 1.5587, Validation Accuracy: 0.5868


100%|██████████| 29/29 [02:20<00:00,  4.86s/it]


Epoch 468/500, Loss: 0.5535, Accuracy: 0.7936
Validation Loss: 1.6182, Validation Accuracy: 0.5729


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 469/500, Loss: 0.5719, Accuracy: 0.7906
Validation Loss: 1.5138, Validation Accuracy: 0.5946


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 470/500, Loss: 0.5873, Accuracy: 0.7875
Validation Loss: 1.6278, Validation Accuracy: 0.5763


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 471/500, Loss: 0.6047, Accuracy: 0.7798
Validation Loss: 1.7082, Validation Accuracy: 0.5596


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 472/500, Loss: 0.6182, Accuracy: 0.7695
Validation Loss: 1.6509, Validation Accuracy: 0.5729


100%|██████████| 29/29 [02:22<00:00,  4.90s/it]


Epoch 473/500, Loss: 0.6155, Accuracy: 0.7706
Validation Loss: 1.7083, Validation Accuracy: 0.5641


100%|██████████| 29/29 [02:22<00:00,  4.91s/it]


Epoch 474/500, Loss: 0.6331, Accuracy: 0.7687
Validation Loss: 1.7257, Validation Accuracy: 0.5585


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 475/500, Loss: 0.6428, Accuracy: 0.7681
Validation Loss: 1.6216, Validation Accuracy: 0.5696


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 476/500, Loss: 0.6135, Accuracy: 0.7777
Validation Loss: 1.5692, Validation Accuracy: 0.5946


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 477/500, Loss: 0.6349, Accuracy: 0.7683
Validation Loss: 1.5397, Validation Accuracy: 0.6029


100%|██████████| 29/29 [02:20<00:00,  4.85s/it]


Epoch 478/500, Loss: 0.6333, Accuracy: 0.7658
Validation Loss: 1.6883, Validation Accuracy: 0.5679


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 479/500, Loss: 0.6499, Accuracy: 0.7552
Validation Loss: 1.5453, Validation Accuracy: 0.5790


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 480/500, Loss: 0.6417, Accuracy: 0.7601
Validation Loss: 1.6893, Validation Accuracy: 0.5663


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 481/500, Loss: 0.6394, Accuracy: 0.7647
Validation Loss: 1.5919, Validation Accuracy: 0.5630


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 482/500, Loss: 0.6330, Accuracy: 0.7704
Validation Loss: 1.6656, Validation Accuracy: 0.5663


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 483/500, Loss: 0.6437, Accuracy: 0.7602
Validation Loss: 1.6409, Validation Accuracy: 0.5829


100%|██████████| 29/29 [02:21<00:00,  4.86s/it]


Epoch 484/500, Loss: 0.6018, Accuracy: 0.7837
Validation Loss: 1.6875, Validation Accuracy: 0.5674


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 485/500, Loss: 0.6180, Accuracy: 0.7715
Validation Loss: 1.6340, Validation Accuracy: 0.5663


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 486/500, Loss: 0.5961, Accuracy: 0.7871
Validation Loss: 1.6739, Validation Accuracy: 0.5774


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 487/500, Loss: 0.5805, Accuracy: 0.7864
Validation Loss: 1.6417, Validation Accuracy: 0.5630


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 488/500, Loss: 0.5639, Accuracy: 0.7918
Validation Loss: 1.6081, Validation Accuracy: 0.5785


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 489/500, Loss: 0.5597, Accuracy: 0.7989
Validation Loss: 1.6587, Validation Accuracy: 0.5790


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 490/500, Loss: 0.5440, Accuracy: 0.8035
Validation Loss: 1.5637, Validation Accuracy: 0.5918


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 491/500, Loss: 0.5206, Accuracy: 0.8057
Validation Loss: 1.5179, Validation Accuracy: 0.5907


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 492/500, Loss: 0.5077, Accuracy: 0.8163
Validation Loss: 1.5961, Validation Accuracy: 0.5862


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 493/500, Loss: 0.5218, Accuracy: 0.8107
Validation Loss: 1.6076, Validation Accuracy: 0.5851


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 494/500, Loss: 0.5038, Accuracy: 0.8157
Validation Loss: 1.5890, Validation Accuracy: 0.5890


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 495/500, Loss: 0.5046, Accuracy: 0.8161
Validation Loss: 1.5916, Validation Accuracy: 0.5918


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 496/500, Loss: 0.5091, Accuracy: 0.8153
Validation Loss: 1.5894, Validation Accuracy: 0.5912


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 497/500, Loss: 0.5000, Accuracy: 0.8189
Validation Loss: 1.6005, Validation Accuracy: 0.5890


100%|██████████| 29/29 [02:21<00:00,  4.89s/it]


Epoch 498/500, Loss: 0.5060, Accuracy: 0.8140
Validation Loss: 1.5837, Validation Accuracy: 0.5896


100%|██████████| 29/29 [02:21<00:00,  4.87s/it]


Epoch 499/500, Loss: 0.5123, Accuracy: 0.8131
Validation Loss: 1.6212, Validation Accuracy: 0.5918


100%|██████████| 29/29 [02:21<00:00,  4.88s/it]


Epoch 500/500, Loss: 0.5210, Accuracy: 0.8115
Validation Loss: 1.5613, Validation Accuracy: 0.5929




In [3]:
print(best_val_acc)

0.6039933444259568
