In [1]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel
import mlflow
import mlflow.pytorch

root_folder = 's1_s2_s3_matfiles_5part'  # Update with your folder path
num_epochs = 500

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

def extract_participant_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    
    for part in parts:
        if 'participant' in part:
            participant_number = part.replace('participant', '')
            return int(participant_number)  # Convert to integer
    return None  # If no participant label found

# def extract_label_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'gesture' in part:
#             return int(part.replace('gesture', ''))
#     return None

class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model_5part.pth'
last_model_path = 'last_model_5part.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    mlflow.log_metric("train_loss", epoch_loss, step=epoch)
    mlflow.log_metric("train_accuracy", epoch_acc, step=epoch)

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    mlflow.log_metric("val_loss", val_loss, step=epoch)
    mlflow.log_metric("val_accuracy", val_acc, step=epoch)

    torch.save(model.state_dict(), last_model_path)
    mlflow.pytorch.log_model(model, "models/last_model")
    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')
        mlflow.pytorch.log_model(model, "models/best_model")

end_mlflow_experiment()        

  from .autonotebook import tqdm as notebook_tqdm
2024/09/21 02:16:54 INFO mlflow.tracking.fluent: Experiment with name 's1_s2_s3_matfiles_5part' does not exist. Creating a new experiment.


cuda


100%|██████████| 6/6 [00:57<00:00,  9.61s/it]


Epoch 1/500, Loss: 1.6235, Accuracy: 0.1959
Validation Loss: 1.6262, Validation Accuracy: 0.1989




New best model saved with accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 2/500, Loss: 1.6303, Accuracy: 0.1959
Validation Loss: 1.6102, Validation Accuracy: 0.1933


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 3/500, Loss: 1.6236, Accuracy: 0.1931
Validation Loss: 1.6129, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 4/500, Loss: 1.6204, Accuracy: 0.1966
Validation Loss: 1.6110, Validation Accuracy: 0.2017




New best model saved with accuracy: 0.2017


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 5/500, Loss: 1.6160, Accuracy: 0.1959
Validation Loss: 1.6091, Validation Accuracy: 0.2325




New best model saved with accuracy: 0.2325


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 6/500, Loss: 1.6124, Accuracy: 0.2044
Validation Loss: 1.6109, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 7/500, Loss: 1.6149, Accuracy: 0.1896
Validation Loss: 1.6107, Validation Accuracy: 0.2017


100%|██████████| 6/6 [00:27<00:00,  4.59s/it]


Epoch 8/500, Loss: 1.6130, Accuracy: 0.2022
Validation Loss: 1.6088, Validation Accuracy: 0.2129


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 9/500, Loss: 1.6157, Accuracy: 0.1910
Validation Loss: 1.6083, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 10/500, Loss: 1.6136, Accuracy: 0.2086
Validation Loss: 1.6113, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 11/500, Loss: 1.6125, Accuracy: 0.1959
Validation Loss: 1.6116, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 12/500, Loss: 1.6129, Accuracy: 0.1882
Validation Loss: 1.6094, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 13/500, Loss: 1.6104, Accuracy: 0.2191
Validation Loss: 1.6084, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 14/500, Loss: 1.6100, Accuracy: 0.1966
Validation Loss: 1.6107, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 15/500, Loss: 1.6164, Accuracy: 0.1973
Validation Loss: 1.6108, Validation Accuracy: 0.2297


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 16/500, Loss: 1.6137, Accuracy: 0.2086
Validation Loss: 1.6096, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 17/500, Loss: 1.6157, Accuracy: 0.1791
Validation Loss: 1.6083, Validation Accuracy: 0.2213


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 18/500, Loss: 1.6118, Accuracy: 0.2058
Validation Loss: 1.6098, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 19/500, Loss: 1.6080, Accuracy: 0.2268
Validation Loss: 1.6088, Validation Accuracy: 0.2185


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 20/500, Loss: 1.6114, Accuracy: 0.1987
Validation Loss: 1.6085, Validation Accuracy: 0.2241


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 21/500, Loss: 1.6104, Accuracy: 0.2008
Validation Loss: 1.6079, Validation Accuracy: 0.2017


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 22/500, Loss: 1.6108, Accuracy: 0.1931
Validation Loss: 1.6077, Validation Accuracy: 0.2353




New best model saved with accuracy: 0.2353


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 23/500, Loss: 1.6123, Accuracy: 0.2008
Validation Loss: 1.6084, Validation Accuracy: 0.2045


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 24/500, Loss: 1.6109, Accuracy: 0.2065
Validation Loss: 1.6078, Validation Accuracy: 0.2213


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 25/500, Loss: 1.6097, Accuracy: 0.2149
Validation Loss: 1.6083, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 26/500, Loss: 1.6112, Accuracy: 0.2072
Validation Loss: 1.6075, Validation Accuracy: 0.2269


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 27/500, Loss: 1.6115, Accuracy: 0.1861
Validation Loss: 1.6071, Validation Accuracy: 0.2241


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 28/500, Loss: 1.6088, Accuracy: 0.2135
Validation Loss: 1.6066, Validation Accuracy: 0.2381




New best model saved with accuracy: 0.2381


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 29/500, Loss: 1.6102, Accuracy: 0.2051
Validation Loss: 1.6071, Validation Accuracy: 0.2101


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 30/500, Loss: 1.6093, Accuracy: 0.2015
Validation Loss: 1.6109, Validation Accuracy: 0.2297


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 31/500, Loss: 1.6080, Accuracy: 0.2310
Validation Loss: 1.6086, Validation Accuracy: 0.1989


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 32/500, Loss: 1.6093, Accuracy: 0.2008
Validation Loss: 1.6025, Validation Accuracy: 0.2409




New best model saved with accuracy: 0.2409


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 33/500, Loss: 1.6104, Accuracy: 0.1980
Validation Loss: 1.6076, Validation Accuracy: 0.2157


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 34/500, Loss: 1.6066, Accuracy: 0.2079
Validation Loss: 1.6055, Validation Accuracy: 0.2381


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 35/500, Loss: 1.6084, Accuracy: 0.1959
Validation Loss: 1.6021, Validation Accuracy: 0.2521




New best model saved with accuracy: 0.2521


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 36/500, Loss: 1.6111, Accuracy: 0.1973
Validation Loss: 1.6086, Validation Accuracy: 0.2017


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 37/500, Loss: 1.6071, Accuracy: 0.2247
Validation Loss: 1.6041, Validation Accuracy: 0.2073


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 38/500, Loss: 1.6073, Accuracy: 0.2156
Validation Loss: 1.6104, Validation Accuracy: 0.2185


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 39/500, Loss: 1.6034, Accuracy: 0.2247
Validation Loss: 1.5993, Validation Accuracy: 0.2493


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 40/500, Loss: 1.6045, Accuracy: 0.2226
Validation Loss: 1.6091, Validation Accuracy: 0.2101


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 41/500, Loss: 1.6018, Accuracy: 0.2416
Validation Loss: 1.5979, Validation Accuracy: 0.2353


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 42/500, Loss: 1.6036, Accuracy: 0.2226
Validation Loss: 1.6165, Validation Accuracy: 0.2213


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 43/500, Loss: 1.5970, Accuracy: 0.2535
Validation Loss: 1.5950, Validation Accuracy: 0.2577




New best model saved with accuracy: 0.2577


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 44/500, Loss: 1.5892, Accuracy: 0.2423
Validation Loss: 1.6220, Validation Accuracy: 0.2185


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 45/500, Loss: 1.5761, Accuracy: 0.2718
Validation Loss: 1.5824, Validation Accuracy: 0.2689




New best model saved with accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 46/500, Loss: 1.5587, Accuracy: 0.2767
Validation Loss: 1.6128, Validation Accuracy: 0.2381


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 47/500, Loss: 1.5684, Accuracy: 0.2683
Validation Loss: 1.7128, Validation Accuracy: 0.2437


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 48/500, Loss: 1.5718, Accuracy: 0.2879
Validation Loss: 1.6515, Validation Accuracy: 0.2493


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 49/500, Loss: 1.5633, Accuracy: 0.2844
Validation Loss: 1.5775, Validation Accuracy: 0.2801




New best model saved with accuracy: 0.2801


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 50/500, Loss: 1.5541, Accuracy: 0.2837
Validation Loss: 1.6469, Validation Accuracy: 0.2269


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 51/500, Loss: 1.5827, Accuracy: 0.2556
Validation Loss: 1.6600, Validation Accuracy: 0.2409


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 52/500, Loss: 1.5550, Accuracy: 0.2886
Validation Loss: 1.5824, Validation Accuracy: 0.2885




New best model saved with accuracy: 0.2885


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 53/500, Loss: 1.5555, Accuracy: 0.2837
Validation Loss: 1.5822, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 54/500, Loss: 1.5554, Accuracy: 0.2830
Validation Loss: 1.5921, Validation Accuracy: 0.2577


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 55/500, Loss: 1.5525, Accuracy: 0.2858
Validation Loss: 1.6034, Validation Accuracy: 0.2521


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 56/500, Loss: 1.5509, Accuracy: 0.2978
Validation Loss: 1.5952, Validation Accuracy: 0.2549


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 57/500, Loss: 1.5432, Accuracy: 0.2992
Validation Loss: 1.5812, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 58/500, Loss: 1.5471, Accuracy: 0.2992
Validation Loss: 1.5953, Validation Accuracy: 0.2521


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 59/500, Loss: 1.5389, Accuracy: 0.3062
Validation Loss: 1.5807, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 60/500, Loss: 1.5357, Accuracy: 0.3090
Validation Loss: 1.5726, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 61/500, Loss: 1.5369, Accuracy: 0.3111
Validation Loss: 1.5758, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 62/500, Loss: 1.5296, Accuracy: 0.3230
Validation Loss: 1.5826, Validation Accuracy: 0.2801


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 63/500, Loss: 1.5262, Accuracy: 0.3062
Validation Loss: 1.5753, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 64/500, Loss: 1.5416, Accuracy: 0.2914
Validation Loss: 1.5725, Validation Accuracy: 0.2913




New best model saved with accuracy: 0.2913


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 65/500, Loss: 1.5321, Accuracy: 0.3013
Validation Loss: 1.5725, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 66/500, Loss: 1.5282, Accuracy: 0.3069
Validation Loss: 1.5745, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 67/500, Loss: 1.5308, Accuracy: 0.3174
Validation Loss: 1.5699, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 68/500, Loss: 1.5364, Accuracy: 0.3062
Validation Loss: 1.5677, Validation Accuracy: 0.2885


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 69/500, Loss: 1.5279, Accuracy: 0.3041
Validation Loss: 1.5717, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 70/500, Loss: 1.5244, Accuracy: 0.3230
Validation Loss: 1.5671, Validation Accuracy: 0.2969




New best model saved with accuracy: 0.2969


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 71/500, Loss: 1.5386, Accuracy: 0.3041
Validation Loss: 1.5676, Validation Accuracy: 0.2941


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 72/500, Loss: 1.5274, Accuracy: 0.3125
Validation Loss: 1.5780, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 73/500, Loss: 1.5310, Accuracy: 0.3139
Validation Loss: 1.5688, Validation Accuracy: 0.2801


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 74/500, Loss: 1.5247, Accuracy: 0.3202
Validation Loss: 1.5678, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 75/500, Loss: 1.5284, Accuracy: 0.3048
Validation Loss: 1.5683, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 76/500, Loss: 1.5342, Accuracy: 0.3083
Validation Loss: 1.5718, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 77/500, Loss: 1.5249, Accuracy: 0.3181
Validation Loss: 1.5677, Validation Accuracy: 0.2829


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 78/500, Loss: 1.5191, Accuracy: 0.3125
Validation Loss: 1.5687, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 79/500, Loss: 1.5265, Accuracy: 0.3027
Validation Loss: 1.5700, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 80/500, Loss: 1.5266, Accuracy: 0.3076
Validation Loss: 1.5710, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 81/500, Loss: 1.5269, Accuracy: 0.3230
Validation Loss: 1.5716, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 82/500, Loss: 1.5287, Accuracy: 0.3069
Validation Loss: 1.5712, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 83/500, Loss: 1.5276, Accuracy: 0.3104
Validation Loss: 1.5710, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 84/500, Loss: 1.5242, Accuracy: 0.3090
Validation Loss: 1.5710, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 85/500, Loss: 1.5280, Accuracy: 0.3195
Validation Loss: 1.5709, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 86/500, Loss: 1.5256, Accuracy: 0.3111
Validation Loss: 1.5706, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 87/500, Loss: 1.5330, Accuracy: 0.3181
Validation Loss: 1.5698, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 88/500, Loss: 1.5220, Accuracy: 0.3146
Validation Loss: 1.5689, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 89/500, Loss: 1.5361, Accuracy: 0.3062
Validation Loss: 1.5691, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 90/500, Loss: 1.5281, Accuracy: 0.3125
Validation Loss: 1.5687, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 91/500, Loss: 1.5243, Accuracy: 0.3006
Validation Loss: 1.5713, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 92/500, Loss: 1.5305, Accuracy: 0.3195
Validation Loss: 1.5685, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 93/500, Loss: 1.5293, Accuracy: 0.3174
Validation Loss: 1.5686, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 94/500, Loss: 1.5163, Accuracy: 0.3195
Validation Loss: 1.5778, Validation Accuracy: 0.2829


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 95/500, Loss: 1.5298, Accuracy: 0.3048
Validation Loss: 1.5769, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 96/500, Loss: 1.5318, Accuracy: 0.3069
Validation Loss: 1.5666, Validation Accuracy: 0.2885


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 97/500, Loss: 1.5270, Accuracy: 0.3111
Validation Loss: 1.5789, Validation Accuracy: 0.2801


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 98/500, Loss: 1.5245, Accuracy: 0.3272
Validation Loss: 1.5668, Validation Accuracy: 0.3053




New best model saved with accuracy: 0.3053


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 99/500, Loss: 1.5181, Accuracy: 0.3090
Validation Loss: 1.5740, Validation Accuracy: 0.2857


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 100/500, Loss: 1.5202, Accuracy: 0.3188
Validation Loss: 1.5663, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 101/500, Loss: 1.5319, Accuracy: 0.3125
Validation Loss: 1.5658, Validation Accuracy: 0.2941


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 102/500, Loss: 1.5377, Accuracy: 0.3041
Validation Loss: 1.5961, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 103/500, Loss: 1.5297, Accuracy: 0.3132
Validation Loss: 1.5681, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 104/500, Loss: 1.5267, Accuracy: 0.3167
Validation Loss: 1.5646, Validation Accuracy: 0.2885


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 105/500, Loss: 1.5311, Accuracy: 0.3069
Validation Loss: 1.5889, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 106/500, Loss: 1.5374, Accuracy: 0.2816
Validation Loss: 1.5614, Validation Accuracy: 0.2857


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 107/500, Loss: 1.5261, Accuracy: 0.3055
Validation Loss: 1.5836, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 108/500, Loss: 1.5359, Accuracy: 0.3076
Validation Loss: 1.5617, Validation Accuracy: 0.3137




New best model saved with accuracy: 0.3137


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 109/500, Loss: 1.5239, Accuracy: 0.3364
Validation Loss: 1.5647, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 110/500, Loss: 1.5286, Accuracy: 0.3048
Validation Loss: 1.5636, Validation Accuracy: 0.2857


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 111/500, Loss: 1.5276, Accuracy: 0.3048
Validation Loss: 1.5650, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 112/500, Loss: 1.5342, Accuracy: 0.3090
Validation Loss: 1.5698, Validation Accuracy: 0.2829


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 113/500, Loss: 1.5225, Accuracy: 0.3174
Validation Loss: 1.5623, Validation Accuracy: 0.3053


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 114/500, Loss: 1.5157, Accuracy: 0.3223
Validation Loss: 1.5832, Validation Accuracy: 0.2997


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 115/500, Loss: 1.5319, Accuracy: 0.3055
Validation Loss: 1.6794, Validation Accuracy: 0.2689


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 116/500, Loss: 1.5599, Accuracy: 0.2942
Validation Loss: 1.6514, Validation Accuracy: 0.2241


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 117/500, Loss: 1.5541, Accuracy: 0.2865
Validation Loss: 1.6056, Validation Accuracy: 0.2661


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 118/500, Loss: 1.5359, Accuracy: 0.3279
Validation Loss: 1.5475, Validation Accuracy: 0.3081


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 119/500, Loss: 1.5151, Accuracy: 0.3125
Validation Loss: 1.5563, Validation Accuracy: 0.2969


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 120/500, Loss: 1.5157, Accuracy: 0.3181
Validation Loss: 1.5603, Validation Accuracy: 0.2913


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 121/500, Loss: 1.5181, Accuracy: 0.3139
Validation Loss: 1.5541, Validation Accuracy: 0.3053


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 122/500, Loss: 1.5140, Accuracy: 0.3258
Validation Loss: 1.5553, Validation Accuracy: 0.2941


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 123/500, Loss: 1.5076, Accuracy: 0.3301
Validation Loss: 1.5546, Validation Accuracy: 0.3053


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 124/500, Loss: 1.5297, Accuracy: 0.3230
Validation Loss: 1.6387, Validation Accuracy: 0.2717


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 125/500, Loss: 1.5073, Accuracy: 0.3279
Validation Loss: 1.5472, Validation Accuracy: 0.2885


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 126/500, Loss: 1.5085, Accuracy: 0.3294
Validation Loss: 1.5807, Validation Accuracy: 0.2829


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 127/500, Loss: 1.5111, Accuracy: 0.3301
Validation Loss: 1.5614, Validation Accuracy: 0.2969


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 128/500, Loss: 1.4879, Accuracy: 0.3560
Validation Loss: 1.5404, Validation Accuracy: 0.3137


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 129/500, Loss: 1.4933, Accuracy: 0.3420
Validation Loss: 1.5821, Validation Accuracy: 0.2829


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 130/500, Loss: 1.5256, Accuracy: 0.3441
Validation Loss: 1.6657, Validation Accuracy: 0.2773


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 131/500, Loss: 1.5078, Accuracy: 0.3336
Validation Loss: 1.5509, Validation Accuracy: 0.3109


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 132/500, Loss: 1.4952, Accuracy: 0.3322
Validation Loss: 1.5536, Validation Accuracy: 0.3221




New best model saved with accuracy: 0.3221


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 133/500, Loss: 1.5008, Accuracy: 0.3287
Validation Loss: 1.5269, Validation Accuracy: 0.3221


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 134/500, Loss: 1.4887, Accuracy: 0.3539
Validation Loss: 1.5171, Validation Accuracy: 0.3501




New best model saved with accuracy: 0.3501


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 135/500, Loss: 1.4633, Accuracy: 0.3617
Validation Loss: 1.5733, Validation Accuracy: 0.3361


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 136/500, Loss: 1.5202, Accuracy: 0.3441
Validation Loss: 1.6027, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 137/500, Loss: 1.5192, Accuracy: 0.3069
Validation Loss: 1.5567, Validation Accuracy: 0.3081


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 138/500, Loss: 1.5106, Accuracy: 0.3244
Validation Loss: 1.5200, Validation Accuracy: 0.3025


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 139/500, Loss: 1.4591, Accuracy: 0.3757
Validation Loss: 1.5311, Validation Accuracy: 0.3305


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 140/500, Loss: 1.4868, Accuracy: 0.3659
Validation Loss: 1.6215, Validation Accuracy: 0.3221


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 141/500, Loss: 1.4922, Accuracy: 0.3610
Validation Loss: 1.6358, Validation Accuracy: 0.3361


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 142/500, Loss: 1.5365, Accuracy: 0.3076
Validation Loss: 1.6048, Validation Accuracy: 0.2885


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 143/500, Loss: 1.4547, Accuracy: 0.3820
Validation Loss: 1.5112, Validation Accuracy: 0.3249


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 144/500, Loss: 1.4642, Accuracy: 0.3652
Validation Loss: 1.5894, Validation Accuracy: 0.2969


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 145/500, Loss: 1.4693, Accuracy: 0.3827
Validation Loss: 1.5557, Validation Accuracy: 0.2857


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 146/500, Loss: 1.4528, Accuracy: 0.3813
Validation Loss: 1.6946, Validation Accuracy: 0.2633


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 147/500, Loss: 1.4895, Accuracy: 0.3265
Validation Loss: 1.6194, Validation Accuracy: 0.3053


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 148/500, Loss: 1.4642, Accuracy: 0.3701
Validation Loss: 1.6842, Validation Accuracy: 0.2745


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 149/500, Loss: 1.4855, Accuracy: 0.3673
Validation Loss: 1.5263, Validation Accuracy: 0.3613




New best model saved with accuracy: 0.3613


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 150/500, Loss: 1.4358, Accuracy: 0.3708
Validation Loss: 1.5286, Validation Accuracy: 0.3473


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 151/500, Loss: 1.4139, Accuracy: 0.4073
Validation Loss: 1.5253, Validation Accuracy: 0.3445


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 152/500, Loss: 1.4369, Accuracy: 0.4024
Validation Loss: 1.5374, Validation Accuracy: 0.3697




New best model saved with accuracy: 0.3697


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 153/500, Loss: 1.4374, Accuracy: 0.3673
Validation Loss: 1.5691, Validation Accuracy: 0.3081


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 154/500, Loss: 1.3972, Accuracy: 0.4094
Validation Loss: 1.5717, Validation Accuracy: 0.3305


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 155/500, Loss: 1.4496, Accuracy: 0.3750
Validation Loss: 1.5974, Validation Accuracy: 0.3053


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 156/500, Loss: 1.4127, Accuracy: 0.3904
Validation Loss: 1.6062, Validation Accuracy: 0.2969


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 157/500, Loss: 1.4469, Accuracy: 0.3652
Validation Loss: 1.5216, Validation Accuracy: 0.3249


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 158/500, Loss: 1.3932, Accuracy: 0.4094
Validation Loss: 1.5381, Validation Accuracy: 0.3137


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 159/500, Loss: 1.4207, Accuracy: 0.3897
Validation Loss: 1.5773, Validation Accuracy: 0.3221


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 160/500, Loss: 1.4123, Accuracy: 0.3947
Validation Loss: 1.4837, Validation Accuracy: 0.3277


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 161/500, Loss: 1.4069, Accuracy: 0.3975
Validation Loss: 1.5294, Validation Accuracy: 0.3221


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 162/500, Loss: 1.4042, Accuracy: 0.3940
Validation Loss: 1.5197, Validation Accuracy: 0.3529


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 163/500, Loss: 1.4366, Accuracy: 0.3933
Validation Loss: 1.5629, Validation Accuracy: 0.3529


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 164/500, Loss: 1.4355, Accuracy: 0.3778
Validation Loss: 1.4961, Validation Accuracy: 0.3417


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 165/500, Loss: 1.4278, Accuracy: 0.3841
Validation Loss: 1.4878, Validation Accuracy: 0.3445


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 166/500, Loss: 1.3735, Accuracy: 0.4129
Validation Loss: 1.4908, Validation Accuracy: 0.3529


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 167/500, Loss: 1.3612, Accuracy: 0.4277
Validation Loss: 1.4440, Validation Accuracy: 0.3950




New best model saved with accuracy: 0.3950


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 168/500, Loss: 1.4075, Accuracy: 0.3876
Validation Loss: 1.5470, Validation Accuracy: 0.3697


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 169/500, Loss: 1.3861, Accuracy: 0.4052
Validation Loss: 1.4925, Validation Accuracy: 0.3782


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 170/500, Loss: 1.3652, Accuracy: 0.4017
Validation Loss: 1.5044, Validation Accuracy: 0.3894


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 171/500, Loss: 1.4049, Accuracy: 0.3926
Validation Loss: 1.4944, Validation Accuracy: 0.3613


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 172/500, Loss: 1.3557, Accuracy: 0.4312
Validation Loss: 1.4692, Validation Accuracy: 0.3978




New best model saved with accuracy: 0.3978


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 173/500, Loss: 1.3526, Accuracy: 0.4206
Validation Loss: 1.5216, Validation Accuracy: 0.3641


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 174/500, Loss: 1.3350, Accuracy: 0.4424
Validation Loss: 1.4380, Validation Accuracy: 0.3754


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 175/500, Loss: 1.3165, Accuracy: 0.4649
Validation Loss: 1.5474, Validation Accuracy: 0.3333


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 176/500, Loss: 1.3056, Accuracy: 0.4586
Validation Loss: 1.4420, Validation Accuracy: 0.3725


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 177/500, Loss: 1.3151, Accuracy: 0.4663
Validation Loss: 1.4711, Validation Accuracy: 0.3782


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 178/500, Loss: 1.3334, Accuracy: 0.4438
Validation Loss: 1.4975, Validation Accuracy: 0.3725


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 179/500, Loss: 1.3380, Accuracy: 0.4389
Validation Loss: 1.4786, Validation Accuracy: 0.3641


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 180/500, Loss: 1.3795, Accuracy: 0.4178
Validation Loss: 1.3805, Validation Accuracy: 0.4090




New best model saved with accuracy: 0.4090


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 181/500, Loss: 1.3509, Accuracy: 0.4277
Validation Loss: 1.5461, Validation Accuracy: 0.3529


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 182/500, Loss: 1.3252, Accuracy: 0.4382
Validation Loss: 1.4399, Validation Accuracy: 0.3641


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 183/500, Loss: 1.3370, Accuracy: 0.4452
Validation Loss: 1.5116, Validation Accuracy: 0.3922


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 184/500, Loss: 1.3092, Accuracy: 0.4452
Validation Loss: 1.4906, Validation Accuracy: 0.3838


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 185/500, Loss: 1.3173, Accuracy: 0.4452
Validation Loss: 1.5102, Validation Accuracy: 0.3725


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 186/500, Loss: 1.2902, Accuracy: 0.4600
Validation Loss: 1.4206, Validation Accuracy: 0.4146




New best model saved with accuracy: 0.4146


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 187/500, Loss: 1.2770, Accuracy: 0.4846
Validation Loss: 1.5367, Validation Accuracy: 0.3529


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 188/500, Loss: 1.2704, Accuracy: 0.4923
Validation Loss: 1.5282, Validation Accuracy: 0.3697


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 189/500, Loss: 1.3307, Accuracy: 0.4529
Validation Loss: 1.5248, Validation Accuracy: 0.3866


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


Epoch 190/500, Loss: 1.3006, Accuracy: 0.4537
Validation Loss: 1.4506, Validation Accuracy: 0.3866


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 191/500, Loss: 1.2486, Accuracy: 0.4895
Validation Loss: 1.5803, Validation Accuracy: 0.3417


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 192/500, Loss: 1.2518, Accuracy: 0.4782
Validation Loss: 1.4278, Validation Accuracy: 0.4090


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 193/500, Loss: 1.2355, Accuracy: 0.4965
Validation Loss: 1.4997, Validation Accuracy: 0.3669


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 194/500, Loss: 1.2369, Accuracy: 0.4965
Validation Loss: 1.6124, Validation Accuracy: 0.3389


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 195/500, Loss: 1.3148, Accuracy: 0.4551
Validation Loss: 1.5031, Validation Accuracy: 0.3978


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 196/500, Loss: 1.3113, Accuracy: 0.4473
Validation Loss: 1.4143, Validation Accuracy: 0.4146


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 197/500, Loss: 1.2537, Accuracy: 0.4881
Validation Loss: 1.4343, Validation Accuracy: 0.4146


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 198/500, Loss: 1.2532, Accuracy: 0.5070
Validation Loss: 1.4524, Validation Accuracy: 0.4118


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 199/500, Loss: 1.2370, Accuracy: 0.5126
Validation Loss: 1.4024, Validation Accuracy: 0.4398




New best model saved with accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 200/500, Loss: 1.2191, Accuracy: 0.5014
Validation Loss: 1.4680, Validation Accuracy: 0.3978


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 201/500, Loss: 1.2782, Accuracy: 0.4761
Validation Loss: 1.3723, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 202/500, Loss: 1.2390, Accuracy: 0.4754
Validation Loss: 1.5340, Validation Accuracy: 0.3697


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 203/500, Loss: 1.2353, Accuracy: 0.5070
Validation Loss: 1.5558, Validation Accuracy: 0.3697


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 204/500, Loss: 1.2397, Accuracy: 0.4796
Validation Loss: 1.4105, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 205/500, Loss: 1.2241, Accuracy: 0.5091
Validation Loss: 1.4296, Validation Accuracy: 0.4510




New best model saved with accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 206/500, Loss: 1.2141, Accuracy: 0.5169
Validation Loss: 1.3503, Validation Accuracy: 0.4594




New best model saved with accuracy: 0.4594


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 207/500, Loss: 1.2225, Accuracy: 0.4958
Validation Loss: 1.4328, Validation Accuracy: 0.3894


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 208/500, Loss: 1.2151, Accuracy: 0.5309
Validation Loss: 1.5344, Validation Accuracy: 0.3725


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 209/500, Loss: 1.2248, Accuracy: 0.5028
Validation Loss: 1.4704, Validation Accuracy: 0.3866


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 210/500, Loss: 1.2063, Accuracy: 0.5288
Validation Loss: 1.4266, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 211/500, Loss: 1.2084, Accuracy: 0.5260
Validation Loss: 1.4297, Validation Accuracy: 0.4230


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 212/500, Loss: 1.2003, Accuracy: 0.5190
Validation Loss: 1.4433, Validation Accuracy: 0.4062


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 213/500, Loss: 1.2064, Accuracy: 0.5154
Validation Loss: 1.4193, Validation Accuracy: 0.4314


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 214/500, Loss: 1.1814, Accuracy: 0.5400
Validation Loss: 1.5087, Validation Accuracy: 0.3866


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 215/500, Loss: 1.1843, Accuracy: 0.5379
Validation Loss: 1.4740, Validation Accuracy: 0.4202


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 216/500, Loss: 1.1995, Accuracy: 0.5253
Validation Loss: 1.3953, Validation Accuracy: 0.4426


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 217/500, Loss: 1.1833, Accuracy: 0.5372
Validation Loss: 1.4652, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 218/500, Loss: 1.1814, Accuracy: 0.5337
Validation Loss: 1.5219, Validation Accuracy: 0.3978


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 219/500, Loss: 1.1863, Accuracy: 0.5183
Validation Loss: 1.4305, Validation Accuracy: 0.4286


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 220/500, Loss: 1.1801, Accuracy: 0.5365
Validation Loss: 1.4275, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 221/500, Loss: 1.2011, Accuracy: 0.5154
Validation Loss: 1.4818, Validation Accuracy: 0.4202


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 222/500, Loss: 1.1814, Accuracy: 0.5232
Validation Loss: 1.4577, Validation Accuracy: 0.4314


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 223/500, Loss: 1.1765, Accuracy: 0.5260
Validation Loss: 1.4552, Validation Accuracy: 0.4202


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 224/500, Loss: 1.1753, Accuracy: 0.5225
Validation Loss: 1.5048, Validation Accuracy: 0.4230


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 225/500, Loss: 1.1698, Accuracy: 0.5358
Validation Loss: 1.4025, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 226/500, Loss: 1.1601, Accuracy: 0.5323
Validation Loss: 1.4877, Validation Accuracy: 0.4174


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 227/500, Loss: 1.1631, Accuracy: 0.5225
Validation Loss: 1.4637, Validation Accuracy: 0.4230


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 228/500, Loss: 1.1631, Accuracy: 0.5407
Validation Loss: 1.4081, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 229/500, Loss: 1.1611, Accuracy: 0.5274
Validation Loss: 1.4928, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 230/500, Loss: 1.1551, Accuracy: 0.5372
Validation Loss: 1.4193, Validation Accuracy: 0.4538


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 231/500, Loss: 1.1582, Accuracy: 0.5253
Validation Loss: 1.4376, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 232/500, Loss: 1.1465, Accuracy: 0.5372
Validation Loss: 1.4652, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 233/500, Loss: 1.1728, Accuracy: 0.5323
Validation Loss: 1.4341, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 234/500, Loss: 1.1483, Accuracy: 0.5471
Validation Loss: 1.4680, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 235/500, Loss: 1.1543, Accuracy: 0.5288
Validation Loss: 1.4160, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 236/500, Loss: 1.1420, Accuracy: 0.5527
Validation Loss: 1.4638, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 237/500, Loss: 1.1343, Accuracy: 0.5520
Validation Loss: 1.4260, Validation Accuracy: 0.4566


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 238/500, Loss: 1.1476, Accuracy: 0.5281
Validation Loss: 1.4579, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 239/500, Loss: 1.1355, Accuracy: 0.5358
Validation Loss: 1.4612, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 240/500, Loss: 1.1439, Accuracy: 0.5597
Validation Loss: 1.4470, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 241/500, Loss: 1.1446, Accuracy: 0.5456
Validation Loss: 1.4453, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 242/500, Loss: 1.1448, Accuracy: 0.5456
Validation Loss: 1.4562, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 243/500, Loss: 1.1610, Accuracy: 0.5414
Validation Loss: 1.4656, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 244/500, Loss: 1.1274, Accuracy: 0.5555
Validation Loss: 1.4579, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 245/500, Loss: 1.1435, Accuracy: 0.5456
Validation Loss: 1.4522, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 246/500, Loss: 1.1414, Accuracy: 0.5520
Validation Loss: 1.4474, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 247/500, Loss: 1.1508, Accuracy: 0.5583
Validation Loss: 1.4470, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 248/500, Loss: 1.1445, Accuracy: 0.5372
Validation Loss: 1.4468, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 249/500, Loss: 1.1479, Accuracy: 0.5449
Validation Loss: 1.4468, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 250/500, Loss: 1.1298, Accuracy: 0.5597
Validation Loss: 1.4472, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 251/500, Loss: 1.1467, Accuracy: 0.5456
Validation Loss: 1.4471, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 252/500, Loss: 1.1350, Accuracy: 0.5541
Validation Loss: 1.4475, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 253/500, Loss: 1.1424, Accuracy: 0.5449
Validation Loss: 1.4478, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 254/500, Loss: 1.1401, Accuracy: 0.5316
Validation Loss: 1.4511, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 255/500, Loss: 1.1212, Accuracy: 0.5667
Validation Loss: 1.4497, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 256/500, Loss: 1.1419, Accuracy: 0.5506
Validation Loss: 1.4482, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 257/500, Loss: 1.1457, Accuracy: 0.5520
Validation Loss: 1.4523, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 258/500, Loss: 1.1557, Accuracy: 0.5393
Validation Loss: 1.4536, Validation Accuracy: 0.4426


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 259/500, Loss: 1.1452, Accuracy: 0.5442
Validation Loss: 1.4468, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 260/500, Loss: 1.1268, Accuracy: 0.5576
Validation Loss: 1.4432, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 261/500, Loss: 1.1543, Accuracy: 0.5456
Validation Loss: 1.4681, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 262/500, Loss: 1.1336, Accuracy: 0.5471
Validation Loss: 1.4716, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 263/500, Loss: 1.1348, Accuracy: 0.5372
Validation Loss: 1.4354, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 264/500, Loss: 1.1466, Accuracy: 0.5407
Validation Loss: 1.4626, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 265/500, Loss: 1.1465, Accuracy: 0.5660
Validation Loss: 1.4824, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 266/500, Loss: 1.1560, Accuracy: 0.5534
Validation Loss: 1.4398, Validation Accuracy: 0.4538


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 267/500, Loss: 1.1355, Accuracy: 0.5541
Validation Loss: 1.4568, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:26<00:00,  4.50s/it]


Epoch 268/500, Loss: 1.1440, Accuracy: 0.5449
Validation Loss: 1.4694, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 269/500, Loss: 1.1446, Accuracy: 0.5471
Validation Loss: 1.4529, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 270/500, Loss: 1.1323, Accuracy: 0.5569
Validation Loss: 1.4573, Validation Accuracy: 0.4538


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 271/500, Loss: 1.1694, Accuracy: 0.5253
Validation Loss: 1.4756, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 272/500, Loss: 1.1325, Accuracy: 0.5407
Validation Loss: 1.4353, Validation Accuracy: 0.4566


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 273/500, Loss: 1.1230, Accuracy: 0.5597
Validation Loss: 1.4549, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 274/500, Loss: 1.1704, Accuracy: 0.5260
Validation Loss: 1.4930, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 275/500, Loss: 1.1281, Accuracy: 0.5485
Validation Loss: 1.4859, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 276/500, Loss: 1.1380, Accuracy: 0.5492
Validation Loss: 1.3885, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 277/500, Loss: 1.1633, Accuracy: 0.5456
Validation Loss: 1.5460, Validation Accuracy: 0.4314


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 278/500, Loss: 1.1390, Accuracy: 0.5506
Validation Loss: 1.4695, Validation Accuracy: 0.4426


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 279/500, Loss: 1.1475, Accuracy: 0.5442
Validation Loss: 1.4079, Validation Accuracy: 0.4706




New best model saved with accuracy: 0.4706


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 280/500, Loss: 1.1317, Accuracy: 0.5499
Validation Loss: 1.5059, Validation Accuracy: 0.4146


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 281/500, Loss: 1.1521, Accuracy: 0.5407
Validation Loss: 1.4490, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 282/500, Loss: 1.1306, Accuracy: 0.5520
Validation Loss: 1.4574, Validation Accuracy: 0.4230


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 283/500, Loss: 1.1461, Accuracy: 0.5506
Validation Loss: 1.5362, Validation Accuracy: 0.4202


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 284/500, Loss: 1.1355, Accuracy: 0.5611
Validation Loss: 1.4702, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 285/500, Loss: 1.1405, Accuracy: 0.5344
Validation Loss: 1.4753, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 286/500, Loss: 1.1601, Accuracy: 0.5421
Validation Loss: 1.4976, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 287/500, Loss: 1.1175, Accuracy: 0.5506
Validation Loss: 1.4900, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 288/500, Loss: 1.1295, Accuracy: 0.5407
Validation Loss: 1.5213, Validation Accuracy: 0.4426


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 289/500, Loss: 1.1471, Accuracy: 0.5506
Validation Loss: 1.5562, Validation Accuracy: 0.4118


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 290/500, Loss: 1.1681, Accuracy: 0.5183
Validation Loss: 1.2987, Validation Accuracy: 0.4846




New best model saved with accuracy: 0.4846


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 291/500, Loss: 1.1537, Accuracy: 0.5435
Validation Loss: 1.4439, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 292/500, Loss: 1.1793, Accuracy: 0.5162
Validation Loss: 1.5390, Validation Accuracy: 0.4146


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 293/500, Loss: 1.1379, Accuracy: 0.5435
Validation Loss: 1.4237, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 294/500, Loss: 1.1310, Accuracy: 0.5667
Validation Loss: 1.3970, Validation Accuracy: 0.4594


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 295/500, Loss: 1.1011, Accuracy: 0.5653
Validation Loss: 1.5226, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 296/500, Loss: 1.1465, Accuracy: 0.5428
Validation Loss: 1.4445, Validation Accuracy: 0.4314


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 297/500, Loss: 1.1523, Accuracy: 0.5562
Validation Loss: 1.4169, Validation Accuracy: 0.4566


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 298/500, Loss: 1.1306, Accuracy: 0.5562
Validation Loss: 1.5646, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 299/500, Loss: 1.1348, Accuracy: 0.5618
Validation Loss: 1.5490, Validation Accuracy: 0.4230


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 300/500, Loss: 1.1522, Accuracy: 0.5414
Validation Loss: 1.5243, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 301/500, Loss: 1.1139, Accuracy: 0.5442
Validation Loss: 1.5144, Validation Accuracy: 0.4482


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 302/500, Loss: 1.1584, Accuracy: 0.5386
Validation Loss: 1.4651, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 303/500, Loss: 1.1288, Accuracy: 0.5541
Validation Loss: 1.5094, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 304/500, Loss: 1.1209, Accuracy: 0.5618
Validation Loss: 1.4716, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 305/500, Loss: 1.1580, Accuracy: 0.5288
Validation Loss: 1.5417, Validation Accuracy: 0.4286


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 306/500, Loss: 1.1038, Accuracy: 0.5611
Validation Loss: 1.4251, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 307/500, Loss: 1.1284, Accuracy: 0.5463
Validation Loss: 1.5436, Validation Accuracy: 0.4202


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 308/500, Loss: 1.1441, Accuracy: 0.5534
Validation Loss: 1.4530, Validation Accuracy: 0.4538


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 309/500, Loss: 1.1238, Accuracy: 0.5590
Validation Loss: 1.5656, Validation Accuracy: 0.4258


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 310/500, Loss: 1.1444, Accuracy: 0.5527
Validation Loss: 1.4528, Validation Accuracy: 0.4566


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 311/500, Loss: 1.1318, Accuracy: 0.5562
Validation Loss: 1.4916, Validation Accuracy: 0.4398


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 312/500, Loss: 1.1303, Accuracy: 0.5393
Validation Loss: 1.4880, Validation Accuracy: 0.4370


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


Epoch 313/500, Loss: 1.1159, Accuracy: 0.5597
Validation Loss: 1.4908, Validation Accuracy: 0.4650


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 314/500, Loss: 1.1049, Accuracy: 0.5583
Validation Loss: 1.4131, Validation Accuracy: 0.4594


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 315/500, Loss: 1.1034, Accuracy: 0.5583
Validation Loss: 1.3954, Validation Accuracy: 0.4790


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 316/500, Loss: 1.1136, Accuracy: 0.5463
Validation Loss: 1.3909, Validation Accuracy: 0.4790


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 317/500, Loss: 1.1223, Accuracy: 0.5520
Validation Loss: 1.4547, Validation Accuracy: 0.4454


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 318/500, Loss: 1.1346, Accuracy: 0.5344
Validation Loss: 1.5007, Validation Accuracy: 0.4426


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 319/500, Loss: 1.1610, Accuracy: 0.5407
Validation Loss: 1.5074, Validation Accuracy: 0.4118


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 320/500, Loss: 1.1631, Accuracy: 0.5274
Validation Loss: 1.4848, Validation Accuracy: 0.4006


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 321/500, Loss: 1.1299, Accuracy: 0.5442
Validation Loss: 1.3474, Validation Accuracy: 0.4986




New best model saved with accuracy: 0.4986


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 322/500, Loss: 1.0941, Accuracy: 0.5541
Validation Loss: 1.3669, Validation Accuracy: 0.4874


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 323/500, Loss: 1.0812, Accuracy: 0.5744
Validation Loss: 1.3636, Validation Accuracy: 0.4734


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 324/500, Loss: 1.1367, Accuracy: 0.5421
Validation Loss: 1.4838, Validation Accuracy: 0.4734


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 325/500, Loss: 1.0981, Accuracy: 0.5681
Validation Loss: 1.4265, Validation Accuracy: 0.4734


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 326/500, Loss: 1.0838, Accuracy: 0.5716
Validation Loss: 1.3411, Validation Accuracy: 0.4874


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 327/500, Loss: 1.0600, Accuracy: 0.5864
Validation Loss: 1.3440, Validation Accuracy: 0.4706


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 328/500, Loss: 1.0918, Accuracy: 0.5688
Validation Loss: 1.2513, Validation Accuracy: 0.4902


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 329/500, Loss: 1.0949, Accuracy: 0.5794
Validation Loss: 1.3429, Validation Accuracy: 0.4958


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 330/500, Loss: 1.0985, Accuracy: 0.5688
Validation Loss: 1.3857, Validation Accuracy: 0.4874


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 331/500, Loss: 1.0673, Accuracy: 0.5730
Validation Loss: 1.4306, Validation Accuracy: 0.4594


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 332/500, Loss: 1.1220, Accuracy: 0.5681
Validation Loss: 1.4884, Validation Accuracy: 0.4342


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 333/500, Loss: 1.0885, Accuracy: 0.5695
Validation Loss: 1.3110, Validation Accuracy: 0.4846


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 334/500, Loss: 1.0630, Accuracy: 0.5941
Validation Loss: 1.3910, Validation Accuracy: 0.4734


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 335/500, Loss: 1.0498, Accuracy: 0.5892
Validation Loss: 1.3973, Validation Accuracy: 0.4678


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 336/500, Loss: 1.0509, Accuracy: 0.5885
Validation Loss: 1.3962, Validation Accuracy: 0.4594


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 337/500, Loss: 1.0523, Accuracy: 0.5787
Validation Loss: 1.3563, Validation Accuracy: 0.4902


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 338/500, Loss: 1.0343, Accuracy: 0.6032
Validation Loss: 1.4273, Validation Accuracy: 0.4902


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 339/500, Loss: 1.0330, Accuracy: 0.5941
Validation Loss: 1.3737, Validation Accuracy: 0.5014




New best model saved with accuracy: 0.5014


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


Epoch 340/500, Loss: 1.0129, Accuracy: 0.6096
Validation Loss: 1.3893, Validation Accuracy: 0.4818


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 341/500, Loss: 1.0306, Accuracy: 0.5976
Validation Loss: 1.2644, Validation Accuracy: 0.5182




New best model saved with accuracy: 0.5182


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 342/500, Loss: 1.0196, Accuracy: 0.6004
Validation Loss: 1.2490, Validation Accuracy: 0.5126


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 343/500, Loss: 1.0635, Accuracy: 0.5822
Validation Loss: 1.4129, Validation Accuracy: 0.4846


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 344/500, Loss: 1.0131, Accuracy: 0.6110
Validation Loss: 1.4452, Validation Accuracy: 0.4762


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 345/500, Loss: 1.0603, Accuracy: 0.5948
Validation Loss: 1.3642, Validation Accuracy: 0.4706


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 346/500, Loss: 1.0310, Accuracy: 0.5990
Validation Loss: 1.2869, Validation Accuracy: 0.5266




New best model saved with accuracy: 0.5266


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 347/500, Loss: 1.0025, Accuracy: 0.6201
Validation Loss: 1.3768, Validation Accuracy: 0.5014


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 348/500, Loss: 0.9749, Accuracy: 0.6292
Validation Loss: 1.3140, Validation Accuracy: 0.5210


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 349/500, Loss: 1.0207, Accuracy: 0.6103
Validation Loss: 1.2856, Validation Accuracy: 0.5238


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 350/500, Loss: 1.0079, Accuracy: 0.5955
Validation Loss: 1.4986, Validation Accuracy: 0.4510


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 351/500, Loss: 1.0182, Accuracy: 0.6173
Validation Loss: 1.3689, Validation Accuracy: 0.4930


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 352/500, Loss: 0.9950, Accuracy: 0.6320
Validation Loss: 1.3769, Validation Accuracy: 0.5154


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 353/500, Loss: 0.9496, Accuracy: 0.6348
Validation Loss: 1.3259, Validation Accuracy: 0.5126


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 354/500, Loss: 0.9463, Accuracy: 0.6552
Validation Loss: 1.2799, Validation Accuracy: 0.5238


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 355/500, Loss: 0.9564, Accuracy: 0.6404
Validation Loss: 1.2727, Validation Accuracy: 0.5266


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 356/500, Loss: 0.9881, Accuracy: 0.6271
Validation Loss: 1.3177, Validation Accuracy: 0.5266


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 357/500, Loss: 1.0228, Accuracy: 0.6067
Validation Loss: 1.4072, Validation Accuracy: 0.4958


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 358/500, Loss: 0.9948, Accuracy: 0.6222
Validation Loss: 1.4387, Validation Accuracy: 0.4986


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 359/500, Loss: 0.9677, Accuracy: 0.6362
Validation Loss: 1.3007, Validation Accuracy: 0.5294




New best model saved with accuracy: 0.5294


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 360/500, Loss: 0.9512, Accuracy: 0.6426
Validation Loss: 1.3631, Validation Accuracy: 0.5154


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 361/500, Loss: 0.9352, Accuracy: 0.6622
Validation Loss: 1.4435, Validation Accuracy: 0.4986


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 362/500, Loss: 0.9771, Accuracy: 0.6173
Validation Loss: 1.3027, Validation Accuracy: 0.5182


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 363/500, Loss: 0.9593, Accuracy: 0.6257
Validation Loss: 1.2499, Validation Accuracy: 0.5098


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 364/500, Loss: 0.9392, Accuracy: 0.6355
Validation Loss: 1.2439, Validation Accuracy: 0.5378




New best model saved with accuracy: 0.5378


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 365/500, Loss: 0.9104, Accuracy: 0.6545
Validation Loss: 1.3900, Validation Accuracy: 0.5070


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 366/500, Loss: 0.9258, Accuracy: 0.6608
Validation Loss: 1.2970, Validation Accuracy: 0.5266


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 367/500, Loss: 0.9109, Accuracy: 0.6678
Validation Loss: 1.2681, Validation Accuracy: 0.5434




New best model saved with accuracy: 0.5434


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 368/500, Loss: 0.9150, Accuracy: 0.6538
Validation Loss: 1.2837, Validation Accuracy: 0.5490




New best model saved with accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 369/500, Loss: 0.9066, Accuracy: 0.6601
Validation Loss: 1.3038, Validation Accuracy: 0.5434


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 370/500, Loss: 0.8817, Accuracy: 0.6629
Validation Loss: 1.3258, Validation Accuracy: 0.5294


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 371/500, Loss: 0.9224, Accuracy: 0.6692
Validation Loss: 1.4040, Validation Accuracy: 0.5098


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 372/500, Loss: 0.9104, Accuracy: 0.6538
Validation Loss: 1.2330, Validation Accuracy: 0.5294


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 373/500, Loss: 0.8892, Accuracy: 0.6756
Validation Loss: 1.2120, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 374/500, Loss: 0.8911, Accuracy: 0.6742
Validation Loss: 1.3538, Validation Accuracy: 0.5238


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 375/500, Loss: 0.8862, Accuracy: 0.6650
Validation Loss: 1.2682, Validation Accuracy: 0.5546




New best model saved with accuracy: 0.5546


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 376/500, Loss: 0.8593, Accuracy: 0.6882
Validation Loss: 1.2140, Validation Accuracy: 0.5434


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 377/500, Loss: 0.8444, Accuracy: 0.6959
Validation Loss: 1.2945, Validation Accuracy: 0.5602




New best model saved with accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 378/500, Loss: 0.8449, Accuracy: 0.6763
Validation Loss: 1.3230, Validation Accuracy: 0.5406


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 379/500, Loss: 0.8400, Accuracy: 0.6868
Validation Loss: 1.2066, Validation Accuracy: 0.5742




New best model saved with accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 380/500, Loss: 0.8647, Accuracy: 0.6728
Validation Loss: 1.3256, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 381/500, Loss: 0.8521, Accuracy: 0.6868
Validation Loss: 1.2424, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 382/500, Loss: 0.8388, Accuracy: 0.6917
Validation Loss: 1.2375, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 383/500, Loss: 0.8440, Accuracy: 0.6840
Validation Loss: 1.3008, Validation Accuracy: 0.5462


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 384/500, Loss: 0.8227, Accuracy: 0.7037
Validation Loss: 1.2911, Validation Accuracy: 0.5462


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 385/500, Loss: 0.7988, Accuracy: 0.7184
Validation Loss: 1.2690, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 386/500, Loss: 0.7992, Accuracy: 0.7156
Validation Loss: 1.2998, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 387/500, Loss: 0.8200, Accuracy: 0.6945
Validation Loss: 1.2808, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 388/500, Loss: 0.8215, Accuracy: 0.7015
Validation Loss: 1.2702, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 389/500, Loss: 0.7992, Accuracy: 0.7100
Validation Loss: 1.2444, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 390/500, Loss: 0.8064, Accuracy: 0.6931
Validation Loss: 1.3010, Validation Accuracy: 0.5462


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 391/500, Loss: 0.8069, Accuracy: 0.6882
Validation Loss: 1.2458, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 392/500, Loss: 0.8131, Accuracy: 0.7072
Validation Loss: 1.2906, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 393/500, Loss: 0.8013, Accuracy: 0.7086
Validation Loss: 1.2308, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 394/500, Loss: 0.8009, Accuracy: 0.7100
Validation Loss: 1.2892, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 395/500, Loss: 0.8016, Accuracy: 0.7065
Validation Loss: 1.2353, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 396/500, Loss: 0.7877, Accuracy: 0.7037
Validation Loss: 1.2780, Validation Accuracy: 0.5546


100%|██████████| 6/6 [00:27<00:00,  4.61s/it]


Epoch 397/500, Loss: 0.7837, Accuracy: 0.7205
Validation Loss: 1.2389, Validation Accuracy: 0.5546


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 398/500, Loss: 0.7894, Accuracy: 0.7100
Validation Loss: 1.2740, Validation Accuracy: 0.5546


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 399/500, Loss: 0.7875, Accuracy: 0.7093
Validation Loss: 1.2500, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 400/500, Loss: 0.7594, Accuracy: 0.7303
Validation Loss: 1.2468, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 401/500, Loss: 0.7697, Accuracy: 0.7233
Validation Loss: 1.2301, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 402/500, Loss: 0.7814, Accuracy: 0.7226
Validation Loss: 1.2523, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 403/500, Loss: 0.7767, Accuracy: 0.7296
Validation Loss: 1.2566, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 404/500, Loss: 0.7579, Accuracy: 0.7240
Validation Loss: 1.2406, Validation Accuracy: 0.5798




New best model saved with accuracy: 0.5798


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 405/500, Loss: 0.7853, Accuracy: 0.7135
Validation Loss: 1.2548, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 406/500, Loss: 0.7630, Accuracy: 0.7275
Validation Loss: 1.2529, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 407/500, Loss: 0.7799, Accuracy: 0.7254
Validation Loss: 1.2486, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 408/500, Loss: 0.7758, Accuracy: 0.7233
Validation Loss: 1.2473, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 409/500, Loss: 0.7572, Accuracy: 0.7346
Validation Loss: 1.2452, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 410/500, Loss: 0.7463, Accuracy: 0.7346
Validation Loss: 1.2425, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 411/500, Loss: 0.7770, Accuracy: 0.7093
Validation Loss: 1.2441, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 412/500, Loss: 0.7765, Accuracy: 0.7086
Validation Loss: 1.2457, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 413/500, Loss: 0.7648, Accuracy: 0.7254
Validation Loss: 1.2458, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 414/500, Loss: 0.7837, Accuracy: 0.7233
Validation Loss: 1.2458, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 415/500, Loss: 0.7709, Accuracy: 0.7177
Validation Loss: 1.2459, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 416/500, Loss: 0.7744, Accuracy: 0.7233
Validation Loss: 1.2463, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 417/500, Loss: 0.7882, Accuracy: 0.7156
Validation Loss: 1.2467, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 418/500, Loss: 0.7753, Accuracy: 0.7367
Validation Loss: 1.2475, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 419/500, Loss: 0.7719, Accuracy: 0.7268
Validation Loss: 1.2504, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 420/500, Loss: 0.7775, Accuracy: 0.7261
Validation Loss: 1.2457, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


Epoch 421/500, Loss: 0.7608, Accuracy: 0.7310
Validation Loss: 1.2352, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 422/500, Loss: 0.7447, Accuracy: 0.7268
Validation Loss: 1.2399, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 423/500, Loss: 0.7953, Accuracy: 0.7212
Validation Loss: 1.2523, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 424/500, Loss: 0.7450, Accuracy: 0.7374
Validation Loss: 1.2536, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 425/500, Loss: 0.7861, Accuracy: 0.7184
Validation Loss: 1.2381, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 426/500, Loss: 0.7690, Accuracy: 0.7247
Validation Loss: 1.2527, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 427/500, Loss: 0.7661, Accuracy: 0.7226
Validation Loss: 1.2274, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 428/500, Loss: 0.7742, Accuracy: 0.7219
Validation Loss: 1.2655, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 429/500, Loss: 0.7819, Accuracy: 0.7205
Validation Loss: 1.2422, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 430/500, Loss: 0.7700, Accuracy: 0.7247
Validation Loss: 1.2272, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 431/500, Loss: 0.7958, Accuracy: 0.7191
Validation Loss: 1.2650, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 432/500, Loss: 0.7627, Accuracy: 0.7226
Validation Loss: 1.2202, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 433/500, Loss: 0.7820, Accuracy: 0.7163
Validation Loss: 1.2970, Validation Accuracy: 0.5574


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 434/500, Loss: 0.7897, Accuracy: 0.7086
Validation Loss: 1.2227, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 435/500, Loss: 0.7530, Accuracy: 0.7163
Validation Loss: 1.2481, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 436/500, Loss: 0.7717, Accuracy: 0.7191
Validation Loss: 1.2548, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 437/500, Loss: 0.7805, Accuracy: 0.7219
Validation Loss: 1.2712, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 438/500, Loss: 0.7900, Accuracy: 0.7015
Validation Loss: 1.2408, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 439/500, Loss: 0.7638, Accuracy: 0.7261
Validation Loss: 1.2409, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 440/500, Loss: 0.7900, Accuracy: 0.7086
Validation Loss: 1.2614, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 441/500, Loss: 0.7778, Accuracy: 0.7093
Validation Loss: 1.2788, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 442/500, Loss: 0.7565, Accuracy: 0.7388
Validation Loss: 1.2251, Validation Accuracy: 0.5770


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 443/500, Loss: 0.7319, Accuracy: 0.7409
Validation Loss: 1.2130, Validation Accuracy: 0.5854




New best model saved with accuracy: 0.5854


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 444/500, Loss: 0.7711, Accuracy: 0.7170
Validation Loss: 1.2110, Validation Accuracy: 0.5770


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 445/500, Loss: 0.7630, Accuracy: 0.7212
Validation Loss: 1.1916, Validation Accuracy: 0.5770


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 446/500, Loss: 0.7910, Accuracy: 0.7163
Validation Loss: 1.2228, Validation Accuracy: 0.5910




New best model saved with accuracy: 0.5910


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


Epoch 447/500, Loss: 0.7491, Accuracy: 0.7360
Validation Loss: 1.2007, Validation Accuracy: 0.5910


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 448/500, Loss: 0.8032, Accuracy: 0.7114
Validation Loss: 1.2551, Validation Accuracy: 0.5770


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 449/500, Loss: 0.7670, Accuracy: 0.7233
Validation Loss: 1.2057, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 450/500, Loss: 0.7728, Accuracy: 0.7247
Validation Loss: 1.1890, Validation Accuracy: 0.5798


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 451/500, Loss: 0.7786, Accuracy: 0.7240
Validation Loss: 1.2626, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 452/500, Loss: 0.7674, Accuracy: 0.7275
Validation Loss: 1.4624, Validation Accuracy: 0.5350


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 453/500, Loss: 0.7975, Accuracy: 0.7107
Validation Loss: 1.1675, Validation Accuracy: 0.5854


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 454/500, Loss: 0.7753, Accuracy: 0.7170
Validation Loss: 1.2125, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 455/500, Loss: 0.7589, Accuracy: 0.7121
Validation Loss: 1.3147, Validation Accuracy: 0.5238


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 456/500, Loss: 0.7681, Accuracy: 0.7170
Validation Loss: 1.3063, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 457/500, Loss: 0.7652, Accuracy: 0.7367
Validation Loss: 1.2312, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 458/500, Loss: 0.7898, Accuracy: 0.7177
Validation Loss: 1.2211, Validation Accuracy: 0.5770


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 459/500, Loss: 0.7938, Accuracy: 0.7135
Validation Loss: 1.2217, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.59s/it]


Epoch 460/500, Loss: 0.7848, Accuracy: 0.7100
Validation Loss: 1.2476, Validation Accuracy: 0.5714


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 461/500, Loss: 0.7622, Accuracy: 0.7275
Validation Loss: 1.2101, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 462/500, Loss: 0.7746, Accuracy: 0.7212
Validation Loss: 1.1942, Validation Accuracy: 0.5826


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 463/500, Loss: 0.7782, Accuracy: 0.7008
Validation Loss: 1.1912, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 464/500, Loss: 0.7571, Accuracy: 0.7331
Validation Loss: 1.2377, Validation Accuracy: 0.5854


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 465/500, Loss: 0.7471, Accuracy: 0.7324
Validation Loss: 1.3034, Validation Accuracy: 0.5378


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 466/500, Loss: 0.7427, Accuracy: 0.7338
Validation Loss: 1.2551, Validation Accuracy: 0.5574


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 467/500, Loss: 0.7981, Accuracy: 0.6994
Validation Loss: 1.2849, Validation Accuracy: 0.5574


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 468/500, Loss: 0.7823, Accuracy: 0.7065
Validation Loss: 1.2759, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 469/500, Loss: 0.7566, Accuracy: 0.7191
Validation Loss: 1.2511, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 470/500, Loss: 0.7555, Accuracy: 0.7205
Validation Loss: 1.2674, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 471/500, Loss: 0.7627, Accuracy: 0.7254
Validation Loss: 1.2606, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 472/500, Loss: 0.7473, Accuracy: 0.7275
Validation Loss: 1.2847, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 473/500, Loss: 0.8251, Accuracy: 0.6854
Validation Loss: 1.2547, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 474/500, Loss: 0.7593, Accuracy: 0.7233
Validation Loss: 1.3750, Validation Accuracy: 0.5518


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 475/500, Loss: 0.7524, Accuracy: 0.7240
Validation Loss: 1.1939, Validation Accuracy: 0.5994




New best model saved with accuracy: 0.5994


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 476/500, Loss: 0.7573, Accuracy: 0.7331
Validation Loss: 1.3260, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 477/500, Loss: 0.7358, Accuracy: 0.7374
Validation Loss: 1.3438, Validation Accuracy: 0.5574


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 478/500, Loss: 0.7960, Accuracy: 0.7086
Validation Loss: 1.4409, Validation Accuracy: 0.5546


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 479/500, Loss: 0.7872, Accuracy: 0.7184
Validation Loss: 1.4588, Validation Accuracy: 0.5378


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 480/500, Loss: 0.7426, Accuracy: 0.7437
Validation Loss: 1.4198, Validation Accuracy: 0.5378


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


Epoch 481/500, Loss: 0.7416, Accuracy: 0.7268
Validation Loss: 1.3955, Validation Accuracy: 0.5490


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 482/500, Loss: 0.7864, Accuracy: 0.7072
Validation Loss: 1.3515, Validation Accuracy: 0.5350


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 483/500, Loss: 0.7872, Accuracy: 0.7226
Validation Loss: 1.2875, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 484/500, Loss: 0.7772, Accuracy: 0.7149
Validation Loss: 1.2717, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


Epoch 485/500, Loss: 0.7261, Accuracy: 0.7423
Validation Loss: 1.2652, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 486/500, Loss: 0.7457, Accuracy: 0.7303
Validation Loss: 1.2254, Validation Accuracy: 0.5938


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 487/500, Loss: 0.7035, Accuracy: 0.7486
Validation Loss: 1.3693, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 488/500, Loss: 0.7279, Accuracy: 0.7353
Validation Loss: 1.3054, Validation Accuracy: 0.5630


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Epoch 489/500, Loss: 0.7558, Accuracy: 0.6987
Validation Loss: 1.2290, Validation Accuracy: 0.6078




New best model saved with accuracy: 0.6078


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 490/500, Loss: 0.7310, Accuracy: 0.7423
Validation Loss: 1.2980, Validation Accuracy: 0.5686


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 491/500, Loss: 0.6843, Accuracy: 0.7451
Validation Loss: 1.3162, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 492/500, Loss: 0.6846, Accuracy: 0.7521
Validation Loss: 1.3489, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


Epoch 493/500, Loss: 0.7241, Accuracy: 0.7310
Validation Loss: 1.3239, Validation Accuracy: 0.5462


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


Epoch 494/500, Loss: 0.7193, Accuracy: 0.7479
Validation Loss: 1.3889, Validation Accuracy: 0.5658


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 495/500, Loss: 0.7650, Accuracy: 0.7163
Validation Loss: 1.3342, Validation Accuracy: 0.5574


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 496/500, Loss: 0.7638, Accuracy: 0.7275
Validation Loss: 1.2893, Validation Accuracy: 0.5462


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 497/500, Loss: 0.7853, Accuracy: 0.7015
Validation Loss: 1.2749, Validation Accuracy: 0.5742


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


Epoch 498/500, Loss: 0.7495, Accuracy: 0.7310
Validation Loss: 1.2687, Validation Accuracy: 0.5910


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


Epoch 499/500, Loss: 0.7372, Accuracy: 0.7338
Validation Loss: 1.2822, Validation Accuracy: 0.5602


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


Epoch 500/500, Loss: 0.7152, Accuracy: 0.7381
Validation Loss: 1.3345, Validation Accuracy: 0.5882




In [3]:
print(best_val_acc)

0.6078431372549019


In [1]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel

def extract_participant_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    
    for part in parts:
        if 'participant' in part:
            participant_number = part.replace('participant', '')
            return int(participant_number)  # Convert to integer
    return None  # If no participant label found

# def extract_label_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'gesture' in part:
#             return int(part.replace('gesture', ''))
#     return None
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5part'  # Update with your folder path
num_epochs = 300
# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model_5part.pth'
last_model_path = 'last_model_5part.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    torch.save(model.state_dict(), last_model_path)

    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')


100%|██████████| 12/12 [00:28<00:00,  2.33s/it]


Epoch 270/300, Loss: 0.9957, Accuracy: 0.6124
Validation Loss: 1.3049, Validation Accuracy: 0.4986


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 271/300, Loss: 1.0249, Accuracy: 0.6180
Validation Loss: 1.3049, Validation Accuracy: 0.4986


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]


Epoch 272/300, Loss: 1.0060, Accuracy: 0.6145
Validation Loss: 1.3061, Validation Accuracy: 0.4930


100%|██████████| 12/12 [00:27<00:00,  2.31s/it]


Epoch 273/300, Loss: 1.0064, Accuracy: 0.6166
Validation Loss: 1.3122, Validation Accuracy: 0.4902


100%|██████████| 12/12 [00:28<00:00,  2.35s/it]


Epoch 274/300, Loss: 1.0244, Accuracy: 0.6131
Validation Loss: 1.3101, Validation Accuracy: 0.4986


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]


Epoch 275/300, Loss: 1.0134, Accuracy: 0.6187
Validation Loss: 1.2966, Validation Accuracy: 0.4902


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 276/300, Loss: 1.0130, Accuracy: 0.6110
Validation Loss: 1.3365, Validation Accuracy: 0.4902


100%|██████████| 12/12 [00:27<00:00,  2.31s/it]


Epoch 277/300, Loss: 0.9919, Accuracy: 0.6222
Validation Loss: 1.3243, Validation Accuracy: 0.4734


100%|██████████| 12/12 [00:27<00:00,  2.32s/it]


Epoch 278/300, Loss: 1.0184, Accuracy: 0.6117
Validation Loss: 1.3090, Validation Accuracy: 0.5098
New best model saved with accuracy: 0.5098


100%|██████████| 12/12 [00:28<00:00,  2.35s/it]


Epoch 279/300, Loss: 1.0215, Accuracy: 0.6067
Validation Loss: 1.2944, Validation Accuracy: 0.4958


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]


Epoch 280/300, Loss: 1.0271, Accuracy: 0.6067
Validation Loss: 1.3001, Validation Accuracy: 0.5126
New best model saved with accuracy: 0.5126


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 281/300, Loss: 1.0284, Accuracy: 0.6018
Validation Loss: 1.2993, Validation Accuracy: 0.5098


100%|██████████| 12/12 [00:27<00:00,  2.32s/it]


Epoch 282/300, Loss: 1.0174, Accuracy: 0.5962
Validation Loss: 1.3728, Validation Accuracy: 0.4650


100%|██████████| 12/12 [00:28<00:00,  2.36s/it]


Epoch 283/300, Loss: 1.0403, Accuracy: 0.5857
Validation Loss: 1.3607, Validation Accuracy: 0.4706


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 284/300, Loss: 1.0311, Accuracy: 0.5801
Validation Loss: 1.3261, Validation Accuracy: 0.4734


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]


Epoch 285/300, Loss: 1.0148, Accuracy: 0.6060
Validation Loss: 1.2961, Validation Accuracy: 0.4846


100%|██████████| 12/12 [00:28<00:00,  2.35s/it]


Epoch 286/300, Loss: 1.0244, Accuracy: 0.6124
Validation Loss: 1.3224, Validation Accuracy: 0.4566


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]


Epoch 287/300, Loss: 1.0048, Accuracy: 0.6103
Validation Loss: 1.3178, Validation Accuracy: 0.4846


100%|██████████| 12/12 [00:27<00:00,  2.31s/it]


Epoch 288/300, Loss: 1.0343, Accuracy: 0.5864
Validation Loss: 1.2792, Validation Accuracy: 0.4846


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 289/300, Loss: 1.0632, Accuracy: 0.5808
Validation Loss: 1.4034, Validation Accuracy: 0.4482


100%|██████████| 12/12 [00:28<00:00,  2.36s/it]


Epoch 290/300, Loss: 1.0643, Accuracy: 0.5737
Validation Loss: 1.2849, Validation Accuracy: 0.5126


100%|██████████| 12/12 [00:27<00:00,  2.33s/it]


Epoch 291/300, Loss: 1.0607, Accuracy: 0.5955
Validation Loss: 1.2930, Validation Accuracy: 0.4734


100%|██████████| 12/12 [00:27<00:00,  2.32s/it]


Epoch 292/300, Loss: 1.0284, Accuracy: 0.5920
Validation Loss: 1.2748, Validation Accuracy: 0.5070


  8%|▊         | 1/12 [00:03<00:39,  3.62s/it]


KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob

# Function to extract the participant ID from the filename
# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'participant' in part:
#             participant_id = part.replace('participant', '')
#             return int(participant_id)  # Convert to integer
#     return None  # If no participant ID found

# # Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

  from .autonotebook import tqdm as notebook_tqdm


In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.3, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTForImageClassification, ViTConfig
from tqdm import tqdm

# Define the number of classes (in this case, the number of participants)
num_classes = len(np.unique(labels))

config = ViTConfig(
    num_hidden_layers=6,  # Reduced number of layers
    hidden_size=512,      # Smaller hidden size
    num_attention_heads=4,  # Fewer attention heads
    patch_size=16,        # Smaller patch size
    intermediate_size=1024,  # Reduced intermediate layer size
    image_size=224,       # Input image size
    num_labels=num_classes,         # Number of output classes
    hidden_dropout_prob=0.3,  # Adding dropout to regularize
    attention_probs_dropout_prob=0.3
)

# Initialize ViT model from scratch with this config
model = ViTForImageClassification(config)

# Define optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 30  # Set the number of epochs

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader):
        # Move data to the same device as the model
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images).logits
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Evaluation after every epoch (optional)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy after epoch {epoch+1}: {accuracy:.2f}%")

# Save the trained model
torch.save(model.state_dict(), 'vit_emg_model.pth')


100%|██████████| 78/78 [01:36<00:00,  1.23s/it]


Epoch [1/100], Loss: 2.9389
Test Accuracy after epoch 1: 5.43%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [2/100], Loss: 2.8926
Test Accuracy after epoch 2: 6.64%


100%|██████████| 78/78 [00:40<00:00,  1.90it/s]


Epoch [3/100], Loss: 2.8774
Test Accuracy after epoch 3: 4.58%


100%|██████████| 78/78 [00:39<00:00,  1.95it/s]


Epoch [4/100], Loss: 2.8714
Test Accuracy after epoch 4: 6.36%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [5/100], Loss: 2.8526
Test Accuracy after epoch 5: 7.30%


100%|██████████| 78/78 [00:40<00:00,  1.91it/s]


Epoch [6/100], Loss: 2.8157
Test Accuracy after epoch 6: 6.17%


100%|██████████| 78/78 [00:40<00:00,  1.90it/s]


Epoch [7/100], Loss: 2.8185
Test Accuracy after epoch 7: 6.83%


100%|██████████| 78/78 [00:40<00:00,  1.93it/s]


Epoch [8/100], Loss: 2.7884
Test Accuracy after epoch 8: 5.61%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [9/100], Loss: 2.7815
Test Accuracy after epoch 9: 5.43%


 87%|████████▋ | 68/78 [00:35<00:05,  1.91it/s]


KeyboardInterrupt: 

In [None]:
print(f"Model output classes: {model.config.num_labels}")
print(f"Unique labels in dataset: {np.unique(train_labels)}")

Model output classes: 10
Unique labels in dataset: [0 1 2 3 4 5 6 7 8 9]


In [None]:
# --- Training loop for 50 epochs ---
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 1, Loss: 2.2416, Accuracy: 0.1794
Test accuracy: 0.2553


Epoch 2/50: 100%|██████████| 90/90 [01:12<00:00,  1.25it/s]


Epoch 2, Loss: 1.8933, Accuracy: 0.3556
Test accuracy: 0.3955


Epoch 3/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 3, Loss: 1.6035, Accuracy: 0.4693
Test accuracy: 0.4320


Epoch 4/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 4, Loss: 1.3069, Accuracy: 0.5781
Test accuracy: 0.4600


Epoch 5/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 5, Loss: 1.0768, Accuracy: 0.6511
Test accuracy: 0.5470


Epoch 6/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 6, Loss: 0.6955, Accuracy: 0.7943
Test accuracy: 0.5330


Epoch 7/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 7, Loss: 0.4704, Accuracy: 0.8635
Test accuracy: 0.5259


Epoch 8/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 8, Loss: 0.3046, Accuracy: 0.9210
Test accuracy: 0.5498


Epoch 9/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 9, Loss: 0.2185, Accuracy: 0.9438
Test accuracy: 0.5512


Epoch 10/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 10, Loss: 0.1618, Accuracy: 0.9582
Test accuracy: 0.4867


Epoch 11/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 11, Loss: 0.1546, Accuracy: 0.9596
Test accuracy: 0.5456


Epoch 12/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 12, Loss: 0.1435, Accuracy: 0.9603
Test accuracy: 0.5652


Epoch 13/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 13, Loss: 0.0731, Accuracy: 0.9849
Test accuracy: 0.4923


Epoch 14/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 14, Loss: 0.1267, Accuracy: 0.9645
Test accuracy: 0.4979


Epoch 15/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 15, Loss: 0.0674, Accuracy: 0.9835
Test accuracy: 0.5302


Epoch 16/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 16, Loss: 0.0767, Accuracy: 0.9796
Test accuracy: 0.5666


Epoch 17/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 17, Loss: 0.0595, Accuracy: 0.9849
Test accuracy: 0.5891


Epoch 18/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 18, Loss: 0.0322, Accuracy: 0.9940
Test accuracy: 0.5428


Epoch 19/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 19, Loss: 0.0548, Accuracy: 0.9839
Test accuracy: 0.5680


Epoch 20/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 20, Loss: 0.0505, Accuracy: 0.9856
Test accuracy: 0.5666


Epoch 21/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 21, Loss: 0.0864, Accuracy: 0.9744
Test accuracy: 0.5722


Epoch 22/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 22, Loss: 0.1016, Accuracy: 0.9698
Test accuracy: 0.4769


Epoch 23/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 23, Loss: 0.0759, Accuracy: 0.9772
Test accuracy: 0.5330


Epoch 24/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 24, Loss: 0.0339, Accuracy: 0.9933
Test accuracy: 0.5498


Epoch 25/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 25, Loss: 0.1114, Accuracy: 0.9705
Test accuracy: 0.5652


Epoch 26/50: 100%|██████████| 90/90 [01:11<00:00,  1.27it/s]


Epoch 26, Loss: 0.0147, Accuracy: 0.9972
Test accuracy: 0.5863


Epoch 27/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 27, Loss: 0.0083, Accuracy: 0.9993
Test accuracy: 0.5806


Epoch 28/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 28, Loss: 0.0066, Accuracy: 0.9993
Test accuracy: 0.5849


Epoch 29/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 29, Loss: 0.0051, Accuracy: 0.9996
Test accuracy: 0.5849


Epoch 30/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 30, Loss: 0.0035, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 31/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 31, Loss: 0.0031, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 32/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 32, Loss: 0.0028, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 33/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 33, Loss: 0.0026, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 34/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 34, Loss: 0.0024, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 35/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 35, Loss: 0.0022, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 36/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 36, Loss: 0.0021, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 37/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 37, Loss: 0.0019, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 38/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 38, Loss: 0.0018, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 39/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 39, Loss: 0.0017, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 40/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 40, Loss: 0.0016, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 41/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 41, Loss: 0.0015, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 42/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 42, Loss: 0.0014, Accuracy: 1.0000
Test accuracy: 0.5933


Epoch 43/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 43, Loss: 0.0013, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 44/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 44, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 45/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 45, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 46/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 46, Loss: 0.0011, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 47/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 47, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 48/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 48, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 49/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 49, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 50/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 50, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5975


In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 45/45 [00:36<00:00,  1.25it/s]


Epoch 1, Loss: 1.4741, Accuracy: 0.3455
Test accuracy: 0.4566


Epoch 2/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 2, Loss: 1.1747, Accuracy: 0.5407
Test accuracy: 0.5770


Epoch 3/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 3, Loss: 0.8848, Accuracy: 0.6713
Test accuracy: 0.6751


Epoch 4/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 4, Loss: 0.6088, Accuracy: 0.8062
Test accuracy: 0.7171


Epoch 5/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 5, Loss: 0.4332, Accuracy: 0.8581
Test accuracy: 0.6303


Epoch 6/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 6, Loss: 0.2952, Accuracy: 0.9094
Test accuracy: 0.7423


Epoch 7/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 7, Loss: 0.1660, Accuracy: 0.9565
Test accuracy: 0.7171


Epoch 8/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 8, Loss: 0.1073, Accuracy: 0.9719
Test accuracy: 0.7003


Epoch 9/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 9, Loss: 0.1327, Accuracy: 0.9621
Test accuracy: 0.7283


Epoch 10/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 10, Loss: 0.0993, Accuracy: 0.9747
Test accuracy: 0.7563


Epoch 11/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 11, Loss: 0.0724, Accuracy: 0.9831
Test accuracy: 0.6947


Epoch 12/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 12, Loss: 0.0463, Accuracy: 0.9902
Test accuracy: 0.7115


Epoch 13/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 13, Loss: 0.0474, Accuracy: 0.9881
Test accuracy: 0.7283


Epoch 14/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 14, Loss: 0.0936, Accuracy: 0.9747
Test accuracy: 0.7115


Epoch 15/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 15, Loss: 0.1236, Accuracy: 0.9621
Test accuracy: 0.6499


Epoch 16/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 16, Loss: 0.1496, Accuracy: 0.9537
Test accuracy: 0.6779


Epoch 17/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 17, Loss: 0.0414, Accuracy: 0.9909
Test accuracy: 0.7619


Epoch 18/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 18, Loss: 0.0188, Accuracy: 0.9965
Test accuracy: 0.7759


Epoch 19/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 19, Loss: 0.0172, Accuracy: 0.9986
Test accuracy: 0.7731


Epoch 20/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 20, Loss: 0.0108, Accuracy: 0.9993
Test accuracy: 0.7675


In [None]:
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data    

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 113/113 [03:05<00:00,  1.64s/it]


Epoch 1, Loss: 1.0156, Accuracy: 0.5933
Test accuracy: 0.6951


Epoch 2/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 2, Loss: 0.6677, Accuracy: 0.7506
Test accuracy: 0.7816


Epoch 3/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 3, Loss: 0.4857, Accuracy: 0.8227
Test accuracy: 0.7627


Epoch 4/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 4, Loss: 0.3735, Accuracy: 0.8713
Test accuracy: 0.8104


Epoch 5/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 5, Loss: 0.2388, Accuracy: 0.9223
Test accuracy: 0.7938


Epoch 6/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 6, Loss: 0.1578, Accuracy: 0.9503
Test accuracy: 0.7993


Epoch 7/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 7, Loss: 0.1318, Accuracy: 0.9603
Test accuracy: 0.7860


Epoch 8/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 8, Loss: 0.0938, Accuracy: 0.9734
Test accuracy: 0.8226


Epoch 9/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 9, Loss: 0.0645, Accuracy: 0.9803
Test accuracy: 0.7672


Epoch 10/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 10, Loss: 0.0746, Accuracy: 0.9761
Test accuracy: 0.7960


Epoch 11/20: 100%|██████████| 113/113 [01:29<00:00,  1.26it/s]


Epoch 11, Loss: 0.0330, Accuracy: 0.9917
Test accuracy: 0.8149


Epoch 12/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 12, Loss: 0.0445, Accuracy: 0.9889
Test accuracy: 0.8104


Epoch 13/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 13, Loss: 0.0661, Accuracy: 0.9792
Test accuracy: 0.8115


Epoch 14/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 14, Loss: 0.0471, Accuracy: 0.9850
Test accuracy: 0.8259


Epoch 15/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 15, Loss: 0.0220, Accuracy: 0.9945
Test accuracy: 0.8082


Epoch 16/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 16, Loss: 0.0149, Accuracy: 0.9969
Test accuracy: 0.8193


Epoch 17/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 17, Loss: 0.0469, Accuracy: 0.9859
Test accuracy: 0.7971


Epoch 18/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 18, Loss: 0.0401, Accuracy: 0.9870
Test accuracy: 0.7716


Epoch 19/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 19, Loss: 0.0242, Accuracy: 0.9931
Test accuracy: 0.8049


Epoch 20/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 20, Loss: 0.0091, Accuracy: 0.9978
Test accuracy: 0.8060


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 226/226 [06:29<00:00,  1.72s/it]


Epoch 1, Loss: 1.6487, Accuracy: 0.4242
Test accuracy: 0.5774


Epoch 2/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 2, Loss: 1.1021, Accuracy: 0.6264
Test accuracy: 0.6051


Epoch 3/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 3, Loss: 0.8969, Accuracy: 0.6896
Test accuracy: 0.6794


Epoch 4/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 4, Loss: 0.6698, Accuracy: 0.7712
Test accuracy: 0.6589


Epoch 5/20: 100%|██████████| 226/226 [02:58<00:00,  1.26it/s]


Epoch 5, Loss: 0.5011, Accuracy: 0.8303
Test accuracy: 0.6744


Epoch 6/20: 100%|██████████| 226/226 [02:58<00:00,  1.27it/s]


Epoch 6, Loss: 0.3302, Accuracy: 0.8935
Test accuracy: 0.6384


Epoch 7/20:  92%|█████████▏| 208/226 [02:45<00:14,  1.26it/s]


KeyboardInterrupt: 

In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import glob

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

In [None]:

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_20part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
all_labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)
print(f"Shape of X_train: {X_train.shape}")

Shape of X_train: (5701, 327680)


In [None]:
pca = PCA(n_components = 2000)
# Step 2: Fit PCA on the dataset
pca.fit(X_train)
# Step 3: Calculate the cumulative variance explained by the components
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
# Step 4: Find the number of components that explain at least 95% variance
n_components_95 = np.argmax(cumulative_variance >= 0.95) + 1

print(f"Number of components to retain 95% of variance: {n_components_95}")

# Print explained variance for each component
explained_variance = pca.explained_variance_ratio_
print(f"Explained variance by each component: {explained_variance}")

# Calculate cumulative explained variance
cumulative_variance = np.cumsum(explained_variance)
print(f"Cumulative explained variance: {cumulative_variance}")

Number of components to retain 95% of variance: 1
Explained variance by each component: [0.02734777 0.02461403 0.00818816 0.0069905  0.00584958 0.00445974
 0.00403911 0.00369412 0.00349787 0.00316976 0.00300401 0.00289899
 0.00279481 0.00275903 0.00264407 0.00261821 0.00259447 0.00257349
 0.00248961 0.00239943 0.0023721  0.00232195 0.00227801 0.00223757
 0.0022159  0.00220089 0.00212229 0.00206722 0.00201354 0.00198255
 0.00194987 0.00192753 0.00191069 0.00189586 0.00187392 0.00184805
 0.00183598 0.00181051 0.00178506 0.00175982 0.0017539  0.00172137
 0.00170902 0.00169332 0.00167088 0.00165842 0.00165174 0.0016362
 0.00163023 0.00160764 0.00159447 0.00155964 0.00155086 0.00152713
 0.00146956 0.00145834 0.00141764 0.00141322 0.00140813 0.00138163
 0.00136974 0.00136623 0.00134497 0.00133781 0.00131482 0.00130781
 0.00129883 0.00128668 0.00127945 0.00126588 0.00125526 0.00123351
 0.00122743 0.00121431 0.00120742 0.00119956 0.00119493 0.00118896
 0.0011854  0.00117166 0.00116383 0.001154

In [None]:


# Initialize PCA for dimensionality reduction
#pca = PCA(n_components=4)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.3065


In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

# Specify session folders
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']

# Target length for EMG signals (adjust based on your data)
target_length = 10240

# Collect all .mat files and their corresponding labels
all_mat_files = []
all_labels = []

for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=100)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.4065


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, GridSearchCV
from tqdm import tqdm

# Function to extract label from filename
def extract_label_from_filename(filename):
    basename = os.path.basename(filename)
    return int(basename.split('_')[2].replace('gesture', ''))

# Function to process a batch of files
def process_batch(batch_files, target_length):
    data_emg_list = []
    labels_list = []
    for mat_file in tqdm(batch_files, desc="Processing Batch"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(mat_file)
            data_emg = mat_data['data_emg']
            label = extract_label_from_filename(mat_file)

            # Handle NaNs and Infinities
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            
            # Pad or truncate the data to the target length
            if data.shape[0] < target_length:
                padded_data = np.pad(data, ((0, target_length - data.shape[0]), (0, 0)), 'constant', constant_values=0)
            else:
                padded_data = data[:target_length]
            
            # Append to lists
            data_emg_list.append(padded_data)
            labels_list.append(label)
    
    # Convert to numpy arrays
    X_batch = np.array(data_emg_list)
    y_batch = np.array(labels_list)
    
    # Reshape the data
    X_batch_flat = X_batch.reshape(X_batch.shape[0], -1)
    
    return X_batch_flat, y_batch

# Path to the data
target_length = 10240  # Desired length for each EMG signal
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']
all_mat_files = []
all_labels = []

# Collect all .mat files and their corresponding labels
for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=4)  # Reduce to 100 components, adjust as needed

# Initialize an SVM with stochastic gradient descent (SGD)
svm_model = SGDClassifier(loss='hinge', penalty='l2', max_iter=1000)

# Process training data in batches
batch_size = 50

for i in range(0, len(train_files), batch_size):
    batch_files = train_files[i:i+batch_size]
    X_batch, y_batch = process_batch(batch_files, target_length)
    
    # Apply PCA
    X_batch_pca = pca.fit_transform(X_batch)
    
    # Incrementally fit the SVM model, passing the complete set of classes every time
    svm_model.partial_fit(X_batch_pca, y_batch, classes=np.unique(train_labels))

# Process and test the model on test data
X_test, y_test = process_batch(test_files, target_length)
X_test_pca = pca.transform(X_test)

# Inference on test data
y_pred = svm_model.predict(X_test_pca)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# --- Hyperparameter Tuning with GridSearchCV ---

# Define the parameter grid for tuning
param_grid = {
    'alpha': [0.0001, 0.001, 0.01],
    'loss': ['hinge', 'log'],
    'penalty': ['l2', 'l1', 'elasticnet'],
    'max_iter': [1000, 2000, 3000]
}

# Initialize GridSearchCV with SGDClassifier
grid_search = GridSearchCV(SGDClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)

# Fit grid search
grid_search.fit(X_test_pca, y_test)

# Print best parameters and score
print(f"Best parameters found: {grid_search.best_params_}")
print(f"Best cross-validation accuracy: {grid_search.best_score_ * 100:.2f}%")


Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.05it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.17it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.18it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.02it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.53it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.70it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.16it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.64it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.37it/s]
Processing

Test Accuracy: 5.94%
Fitting 5 folds for each of 54 candidates, totalling 270 fits
Best parameters found: {'alpha': 0.01, 'loss': 'hinge', 'max_iter': 3000, 'penalty': 'l2'}
Best cross-validation accuracy: 11.03%


135 fits failed out of a total of 270.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
7 fits failed with the following error:
Traceback (most recent call last):
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 888, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 1466, in wrapper
    estimator._validate_params()
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 666, in _validate_params
    validate_parameter_constraints(
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packag

In [None]:
import os

def extract_label_from_filename(filepath):
    # Get the base name of the file (e.g., 'session1_participant1_gesture10_trial1.dat')
    filename = os.path.basename(filepath)
    
    # Split the filename by underscores
    parts = filename.split('_')
    
    # Find the part that contains 'gesture' and extract the gesture number
    for part in parts:
        if 'gesture' in part:
            # Assuming the format is 'gesture<number>', we can remove 'gesture' and get the number
            gesture_number = part.replace('gesture', '')
            return int(gesture_number)  # Convert to integer if needed
    
    return None  # If no gesture label found


In [None]:
import os
import scipy.io as sio
import numpy as np
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm  # Import tqdm for progress bars

# Path where the .mat files are stored
#data_folder = 'Session1'

# Initialize lists to hold data
data_emg_list = []
labels_list = []  # Assuming you have labels corresponding to the gestures

# Loop through each .mat file in the folder with a progress bar
for session_num in range(1, 3):  # Looping through sessions (1 to 3)
    session_folder = os.path.join(f'Session{session_num}_matfiles')
    
    # Get all .mat files in the current session folder
    mat_files = sorted(os.listdir(session_folder))
    
    # Use tqdm to show progress
    for mat_file in tqdm(mat_files, desc=f"Processing Session {session_num}"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(os.path.join(session_folder, mat_file))
            
            # Extract the EMG data (you may need to adjust variable names if different)
            data_emg = mat_data['data_emg']  # This should match the name in your .mat files
            label = extract_label_from_filename(mat_file)  # You would need to define this function
            # Check for NaNs and infinities in the loaded data
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            # Append the data and label to lists
            data_emg_list.append(data)
            labels_list.append(label)


Processing Session 1: 100%|██████████| 5107/5107 [02:12<00:00, 38.46it/s]
Processing Session 2: 100%|██████████| 5109/5109 [02:24<00:00, 35.36it/s]


In [None]:
target_length = 10240  # The desired length for all samples
padded_data_emg_list = [
    np.pad(emg_data, ((0, target_length - emg_data.shape[0]), (0, 0)), 'constant', constant_values=0)
    if emg_data.shape[0] < target_length else emg_data
    for emg_data in data_emg_list
]

In [None]:

# Convert lists to numpy arrays
X = np.array(padded_data_emg_list)  # Data (e.g., EMG signals)
y = np.array(labels_list)  # Labels (e.g., gesture IDs)

# Flatten the data if needed (depends on the structure of your data)
X_flat = X.reshape(X.shape[0], -1)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_flat, y, test_size=0.2, random_state=42)

# Apply PCA
print("Applying PCA...")
pca = PCA(n_components=50)  # Adjust the number of components as needed
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# Train an SVM model with progress tracking
print("Training SVM...")
svm_model = SVC(kernel='linear')

# You can wrap the fitting process with tqdm to monitor the training as well
svm_model.fit(X_train_pca, y_train)

# Make predictions and evaluate the model
y_pred = svm_model.predict(X_test_pca)
accuracy = accuracy_score(y_test, y_pred)
print(f'PCA + SVM Accuracy: {accuracy:.4f}')

# Optionally: Use X_train_pca as input to a Vision Transformer model
# You can now use the X_train_pca for further processing or to feed into a Vision Transformer model.


Applying PCA...


In [None]:
# nan_count = np.isnan(data_emg).sum()
# inf_count = np.isinf(data_emg).sum()
# print(f"NaNs found: {nan_count}, Infinities found: {inf_count}")