In [1]:
import numpy as np

# Apply Gaussian noise
def add_noise(emg_data, noise_factor=0.05):
    noise = np.random.randn(*emg_data.shape) * noise_factor
    augmented_data = emg_data + noise
    return augmented_data

# Time warping: Stretch or compress the time axis slightly
def time_warp(emg_data, time_warp_factor=0.1):
    stretch_factor = np.random.uniform(1 - time_warp_factor, 1 + time_warp_factor)
    indices = np.round(np.linspace(0, emg_data.shape[0] - 1, int(emg_data.shape[0] * stretch_factor))).astype(int)
    augmented_data = emg_data[indices % emg_data.shape[0]]
    return augmented_data

# Signal scaling: Multiply by a random factor
def scale_signal(emg_data, scale_factor=0.1):
    scaling_factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
    return emg_data * scaling_factor

# Random cropping and padding
def random_crop_pad(emg_data, target_length):
    if emg_data.shape[0] < target_length:
        # Pad
        pad_size = target_length - emg_data.shape[0]
        pad_before = np.random.randint(0, pad_size)
        pad_after = pad_size - pad_before
        augmented_data = np.pad(emg_data, ((pad_before, pad_after), (0, 0)), 'constant')
    else:
        # Crop
        crop_start = np.random.randint(0, emg_data.shape[0] - target_length)
        augmented_data = emg_data[crop_start:crop_start + target_length]
    
    return augmented_data

# Random horizontal or vertical flip
def flip(emg_data):
    if np.random.rand() > 0.5:
        return np.flip(emg_data, axis=0)  # Flip along time axis
    return emg_data

# Data augmentation pipeline
def augment_data(emg_data, target_length):
    # Apply augmentations
    emg_data = add_noise(emg_data)
    #emg_data = time_warp(emg_data)
    emg_data = scale_signal(emg_data)
    #emg_data = random_crop_pad(emg_data, target_length)
    #emg_data = flip(emg_data)
    
    return emg_data


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel
import mlflow
import mlflow.pytorch

root_folder = 's1_s2_s3_matfiles_10part'  # Update with your folder path
num_epochs = 500

# Track experiment with MLflow
def start_mlflow_experiment(experiment_name):
    mlflow.set_experiment(experiment_name)
    mlflow.start_run()

def end_mlflow_experiment():
    mlflow.end_run()

# Define the name of the experiment based on the input file or another identifier
def get_experiment_name_from_file(filepath):
    experiment_name = os.path.basename(filepath).split('.')[0]  # Extract filename without extension
    return experiment_name

# Initialize MLflow experiment (This can be placed at the start of your main function)
experiment_name = get_experiment_name_from_file(root_folder)  # Using folder as experiment name
start_mlflow_experiment(experiment_name)

def extract_participant_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    
    for part in parts:
        if 'participant' in part:
            participant_number = part.replace('participant', '')
            return int(participant_number)  # Convert to integer
    return None  # If no participant label found

# def extract_label_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'gesture' in part:
#             return int(part.replace('gesture', ''))
#     return None

class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model_10part.pth'
last_model_path = 'last_model_10part.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    mlflow.log_metric("train_loss", epoch_loss, step=epoch)
    mlflow.log_metric("train_accuracy", epoch_acc, step=epoch)

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    mlflow.log_metric("val_loss", val_loss, step=epoch)
    mlflow.log_metric("val_accuracy", val_acc, step=epoch)

    torch.save(model.state_dict(), last_model_path)
    mlflow.pytorch.log_model(model, "models/last_model")
    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')
        mlflow.pytorch.log_model(model, "models/best_model")

end_mlflow_experiment()        

  from .autonotebook import tqdm as notebook_tqdm
2024/09/21 10:35:54 INFO mlflow.tracking.fluent: Experiment with name 's1_s2_s3_matfiles_10part' does not exist. Creating a new experiment.


cuda


100%|██████████| 12/12 [01:57<00:00,  9.77s/it]


Epoch 1/500, Loss: 2.3236, Accuracy: 0.0951
Validation Loss: 2.3129, Validation Accuracy: 0.0996




New best model saved with accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 2/500, Loss: 2.3173, Accuracy: 0.0997
Validation Loss: 2.3054, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 3/500, Loss: 2.3146, Accuracy: 0.0955
Validation Loss: 2.3038, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 4/500, Loss: 2.3048, Accuracy: 0.1004
Validation Loss: 2.3038, Validation Accuracy: 0.1010




New best model saved with accuracy: 0.1010


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 5/500, Loss: 2.3108, Accuracy: 0.0969
Validation Loss: 2.3027, Validation Accuracy: 0.1052




New best model saved with accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 6/500, Loss: 2.3045, Accuracy: 0.1018
Validation Loss: 2.3025, Validation Accuracy: 0.1066




New best model saved with accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 7/500, Loss: 2.3057, Accuracy: 0.1042
Validation Loss: 2.3029, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 8/500, Loss: 2.3054, Accuracy: 0.1004
Validation Loss: 2.3032, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 9/500, Loss: 2.3089, Accuracy: 0.1085
Validation Loss: 2.3028, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 10/500, Loss: 2.3110, Accuracy: 0.0941
Validation Loss: 2.3040, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 11/500, Loss: 2.3074, Accuracy: 0.1053
Validation Loss: 2.3028, Validation Accuracy: 0.1122




New best model saved with accuracy: 0.1122


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 12/500, Loss: 2.3049, Accuracy: 0.1057
Validation Loss: 2.3022, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 13/500, Loss: 2.3102, Accuracy: 0.0955
Validation Loss: 2.3029, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 14/500, Loss: 2.3073, Accuracy: 0.0948
Validation Loss: 2.3018, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 15/500, Loss: 2.3049, Accuracy: 0.0983
Validation Loss: 2.3017, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 16/500, Loss: 2.3045, Accuracy: 0.1053
Validation Loss: 2.3020, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 17/500, Loss: 2.3030, Accuracy: 0.1155
Validation Loss: 2.3015, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 18/500, Loss: 2.3013, Accuracy: 0.1074
Validation Loss: 2.3011, Validation Accuracy: 0.0982


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 19/500, Loss: 2.3061, Accuracy: 0.0997
Validation Loss: 2.3001, Validation Accuracy: 0.1206




New best model saved with accuracy: 0.1206


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 20/500, Loss: 2.3060, Accuracy: 0.0951
Validation Loss: 2.3010, Validation Accuracy: 0.1234




New best model saved with accuracy: 0.1234


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 21/500, Loss: 2.3059, Accuracy: 0.0976
Validation Loss: 2.3027, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 22/500, Loss: 2.3024, Accuracy: 0.1028
Validation Loss: 2.2988, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 23/500, Loss: 2.3022, Accuracy: 0.1053
Validation Loss: 2.3009, Validation Accuracy: 0.0982


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 24/500, Loss: 2.3033, Accuracy: 0.0944
Validation Loss: 2.2988, Validation Accuracy: 0.1136


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 25/500, Loss: 2.3023, Accuracy: 0.1092
Validation Loss: 2.2992, Validation Accuracy: 0.0968


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 26/500, Loss: 2.3038, Accuracy: 0.1060
Validation Loss: 2.3004, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 27/500, Loss: 2.3019, Accuracy: 0.1053
Validation Loss: 2.2957, Validation Accuracy: 0.1150


100%|██████████| 12/12 [00:54<00:00,  4.52s/it]


Epoch 28/500, Loss: 2.2988, Accuracy: 0.1144
Validation Loss: 2.3036, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 29/500, Loss: 2.2985, Accuracy: 0.1169
Validation Loss: 2.2988, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 30/500, Loss: 2.2972, Accuracy: 0.1113
Validation Loss: 2.3059, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 31/500, Loss: 2.2971, Accuracy: 0.1221
Validation Loss: 2.3002, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 32/500, Loss: 2.2989, Accuracy: 0.1162
Validation Loss: 2.3073, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 33/500, Loss: 2.2948, Accuracy: 0.1316
Validation Loss: 2.3058, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 34/500, Loss: 2.2965, Accuracy: 0.1134
Validation Loss: 2.3097, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 35/500, Loss: 2.2945, Accuracy: 0.1218
Validation Loss: 2.3088, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 36/500, Loss: 2.2962, Accuracy: 0.1172
Validation Loss: 2.3080, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 37/500, Loss: 2.2949, Accuracy: 0.1271
Validation Loss: 2.3064, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 38/500, Loss: 2.2950, Accuracy: 0.1221
Validation Loss: 2.3102, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 39/500, Loss: 2.2968, Accuracy: 0.1137
Validation Loss: 2.3097, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 40/500, Loss: 2.2980, Accuracy: 0.1099
Validation Loss: 2.3105, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 41/500, Loss: 2.2927, Accuracy: 0.1221
Validation Loss: 2.3107, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 42/500, Loss: 2.2965, Accuracy: 0.1250
Validation Loss: 2.3107, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 43/500, Loss: 2.2973, Accuracy: 0.1113
Validation Loss: 2.3107, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 44/500, Loss: 2.2948, Accuracy: 0.1134
Validation Loss: 2.3108, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 45/500, Loss: 2.2960, Accuracy: 0.1257
Validation Loss: 2.3108, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 46/500, Loss: 2.2974, Accuracy: 0.1197
Validation Loss: 2.3100, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 47/500, Loss: 2.2967, Accuracy: 0.1120
Validation Loss: 2.3089, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 48/500, Loss: 2.2943, Accuracy: 0.1309
Validation Loss: 2.3105, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 49/500, Loss: 2.2957, Accuracy: 0.1253
Validation Loss: 2.3105, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 50/500, Loss: 2.2979, Accuracy: 0.1257
Validation Loss: 2.3155, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 51/500, Loss: 2.2975, Accuracy: 0.1246
Validation Loss: 2.3185, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 52/500, Loss: 2.2917, Accuracy: 0.1288
Validation Loss: 2.3037, Validation Accuracy: 0.1150


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 53/500, Loss: 2.2934, Accuracy: 0.1193
Validation Loss: 2.3289, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 54/500, Loss: 2.2940, Accuracy: 0.1214
Validation Loss: 2.3052, Validation Accuracy: 0.1248




New best model saved with accuracy: 0.1248


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 55/500, Loss: 2.2949, Accuracy: 0.1186
Validation Loss: 2.3062, Validation Accuracy: 0.1164


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 56/500, Loss: 2.2918, Accuracy: 0.1165
Validation Loss: 2.3334, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 57/500, Loss: 2.2932, Accuracy: 0.1236
Validation Loss: 2.3012, Validation Accuracy: 0.1234


100%|██████████| 12/12 [00:54<00:00,  4.52s/it]


Epoch 58/500, Loss: 2.2904, Accuracy: 0.1253
Validation Loss: 2.2953, Validation Accuracy: 0.1262




New best model saved with accuracy: 0.1262


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 59/500, Loss: 2.2925, Accuracy: 0.1172
Validation Loss: 2.2723, Validation Accuracy: 0.1515




New best model saved with accuracy: 0.1515


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 60/500, Loss: 2.3043, Accuracy: 0.1123
Validation Loss: 2.3143, Validation Accuracy: 0.1220


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 61/500, Loss: 2.2971, Accuracy: 0.1186
Validation Loss: 2.3197, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 62/500, Loss: 2.2928, Accuracy: 0.1102
Validation Loss: 2.3220, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 63/500, Loss: 2.2862, Accuracy: 0.1281
Validation Loss: 2.3554, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.52s/it]


Epoch 64/500, Loss: 2.2895, Accuracy: 0.1211
Validation Loss: 2.2636, Validation Accuracy: 0.1360


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 65/500, Loss: 2.3045, Accuracy: 0.1144
Validation Loss: 2.2744, Validation Accuracy: 0.1290


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 66/500, Loss: 2.2967, Accuracy: 0.1176
Validation Loss: 2.2808, Validation Accuracy: 0.1388


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 67/500, Loss: 2.2891, Accuracy: 0.1292
Validation Loss: 2.3387, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 68/500, Loss: 2.2829, Accuracy: 0.1323
Validation Loss: 2.3777, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 69/500, Loss: 2.3018, Accuracy: 0.1229
Validation Loss: 2.3356, Validation Accuracy: 0.1178


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 70/500, Loss: 2.2833, Accuracy: 0.1313
Validation Loss: 2.3643, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 71/500, Loss: 2.2771, Accuracy: 0.1176
Validation Loss: 2.2564, Validation Accuracy: 0.1431


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 72/500, Loss: 2.3004, Accuracy: 0.1193
Validation Loss: 2.3555, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 73/500, Loss: 2.2913, Accuracy: 0.1229
Validation Loss: 2.3636, Validation Accuracy: 0.1066


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 74/500, Loss: 2.3098, Accuracy: 0.1162
Validation Loss: 2.2664, Validation Accuracy: 0.1403


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 75/500, Loss: 2.2848, Accuracy: 0.1372
Validation Loss: 2.3832, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 76/500, Loss: 2.3003, Accuracy: 0.1197
Validation Loss: 2.2578, Validation Accuracy: 0.1557




New best model saved with accuracy: 0.1557


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 77/500, Loss: 2.2913, Accuracy: 0.1260
Validation Loss: 2.3451, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 78/500, Loss: 2.2785, Accuracy: 0.1323
Validation Loss: 2.3952, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 79/500, Loss: 2.3186, Accuracy: 0.1162
Validation Loss: 2.4140, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 80/500, Loss: 2.2820, Accuracy: 0.1376
Validation Loss: 2.3151, Validation Accuracy: 0.1164


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 81/500, Loss: 2.2693, Accuracy: 0.1365
Validation Loss: 2.2976, Validation Accuracy: 0.1332


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 82/500, Loss: 2.2506, Accuracy: 0.1506
Validation Loss: 2.2828, Validation Accuracy: 0.1711




New best model saved with accuracy: 0.1711


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 83/500, Loss: 2.2514, Accuracy: 0.1537
Validation Loss: 2.2909, Validation Accuracy: 0.1332


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 84/500, Loss: 2.2449, Accuracy: 0.1580
Validation Loss: 2.2415, Validation Accuracy: 0.1445


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 85/500, Loss: 2.2575, Accuracy: 0.1362
Validation Loss: 2.2635, Validation Accuracy: 0.1346


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 86/500, Loss: 2.2356, Accuracy: 0.1474
Validation Loss: 2.3213, Validation Accuracy: 0.1445


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 87/500, Loss: 2.2575, Accuracy: 0.1358
Validation Loss: 2.3769, Validation Accuracy: 0.1136


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 88/500, Loss: 2.3312, Accuracy: 0.1200
Validation Loss: 2.3578, Validation Accuracy: 0.1178


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 89/500, Loss: 2.2393, Accuracy: 0.1471
Validation Loss: 2.2354, Validation Accuracy: 0.1543


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 90/500, Loss: 2.2626, Accuracy: 0.1351
Validation Loss: 2.3493, Validation Accuracy: 0.1248


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 91/500, Loss: 2.2500, Accuracy: 0.1509
Validation Loss: 2.3605, Validation Accuracy: 0.1262


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 92/500, Loss: 2.2404, Accuracy: 0.1551
Validation Loss: 2.2528, Validation Accuracy: 0.1627


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 93/500, Loss: 2.2488, Accuracy: 0.1537
Validation Loss: 2.2784, Validation Accuracy: 0.1473


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 94/500, Loss: 2.2414, Accuracy: 0.1548
Validation Loss: 2.2322, Validation Accuracy: 0.1781




New best model saved with accuracy: 0.1781


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 95/500, Loss: 2.2542, Accuracy: 0.1495
Validation Loss: 2.2923, Validation Accuracy: 0.1431


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 96/500, Loss: 2.2358, Accuracy: 0.1685
Validation Loss: 2.2379, Validation Accuracy: 0.1641


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 97/500, Loss: 2.2179, Accuracy: 0.1773
Validation Loss: 2.2339, Validation Accuracy: 0.1753


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 98/500, Loss: 2.2339, Accuracy: 0.1646
Validation Loss: 2.2236, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 99/500, Loss: 2.2907, Accuracy: 0.1488
Validation Loss: 2.2254, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 100/500, Loss: 2.2611, Accuracy: 0.1509
Validation Loss: 2.2193, Validation Accuracy: 0.1683


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 101/500, Loss: 2.2473, Accuracy: 0.1555
Validation Loss: 2.2189, Validation Accuracy: 0.1767


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 102/500, Loss: 2.2049, Accuracy: 0.1822
Validation Loss: 2.2788, Validation Accuracy: 0.1487


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 103/500, Loss: 2.2194, Accuracy: 0.1604
Validation Loss: 2.2191, Validation Accuracy: 0.1823




New best model saved with accuracy: 0.1823


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 104/500, Loss: 2.1930, Accuracy: 0.1836
Validation Loss: 2.2109, Validation Accuracy: 0.1753


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 105/500, Loss: 2.1972, Accuracy: 0.1843
Validation Loss: 2.1912, Validation Accuracy: 0.1823


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 106/500, Loss: 2.1991, Accuracy: 0.1906
Validation Loss: 2.2964, Validation Accuracy: 0.1529


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 107/500, Loss: 2.2091, Accuracy: 0.1825
Validation Loss: 2.2191, Validation Accuracy: 0.1781


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 108/500, Loss: 2.1634, Accuracy: 0.2043
Validation Loss: 2.1857, Validation Accuracy: 0.1964




New best model saved with accuracy: 0.1964


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 109/500, Loss: 2.1695, Accuracy: 0.1973
Validation Loss: 2.1861, Validation Accuracy: 0.1767


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 110/500, Loss: 2.1649, Accuracy: 0.2110
Validation Loss: 2.1810, Validation Accuracy: 0.1964


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 111/500, Loss: 2.1626, Accuracy: 0.2025
Validation Loss: 2.1907, Validation Accuracy: 0.2020




New best model saved with accuracy: 0.2020


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 112/500, Loss: 2.1676, Accuracy: 0.2004
Validation Loss: 2.2172, Validation Accuracy: 0.1823


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 113/500, Loss: 2.1330, Accuracy: 0.2092
Validation Loss: 2.1939, Validation Accuracy: 0.1978


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 114/500, Loss: 2.1568, Accuracy: 0.2180
Validation Loss: 2.1930, Validation Accuracy: 0.2034




New best model saved with accuracy: 0.2034


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 115/500, Loss: 2.1389, Accuracy: 0.2215
Validation Loss: 2.1845, Validation Accuracy: 0.2048




New best model saved with accuracy: 0.2048


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 116/500, Loss: 2.1335, Accuracy: 0.2110
Validation Loss: 2.1933, Validation Accuracy: 0.1893


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 117/500, Loss: 2.1267, Accuracy: 0.2232
Validation Loss: 2.1740, Validation Accuracy: 0.2076




New best model saved with accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 118/500, Loss: 2.1208, Accuracy: 0.2303
Validation Loss: 2.1585, Validation Accuracy: 0.2104




New best model saved with accuracy: 0.2104


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 119/500, Loss: 2.1293, Accuracy: 0.2257
Validation Loss: 2.1926, Validation Accuracy: 0.1921


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 120/500, Loss: 2.1288, Accuracy: 0.2218
Validation Loss: 2.1707, Validation Accuracy: 0.2090


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 121/500, Loss: 2.1208, Accuracy: 0.2296
Validation Loss: 2.1772, Validation Accuracy: 0.2062


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 122/500, Loss: 2.1283, Accuracy: 0.2222
Validation Loss: 2.1741, Validation Accuracy: 0.2090


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 123/500, Loss: 2.1121, Accuracy: 0.2341
Validation Loss: 2.1728, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 124/500, Loss: 2.1321, Accuracy: 0.2292
Validation Loss: 2.1730, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 125/500, Loss: 2.1223, Accuracy: 0.2257
Validation Loss: 2.1730, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 126/500, Loss: 2.1099, Accuracy: 0.2306
Validation Loss: 2.1727, Validation Accuracy: 0.2062


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 127/500, Loss: 2.1285, Accuracy: 0.2246
Validation Loss: 2.1719, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 128/500, Loss: 2.1165, Accuracy: 0.2299
Validation Loss: 2.1652, Validation Accuracy: 0.2118




New best model saved with accuracy: 0.2118


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 129/500, Loss: 2.1145, Accuracy: 0.2310
Validation Loss: 2.1709, Validation Accuracy: 0.2062


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 130/500, Loss: 2.1244, Accuracy: 0.2327
Validation Loss: 2.1702, Validation Accuracy: 0.2062


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 131/500, Loss: 2.1318, Accuracy: 0.2243
Validation Loss: 2.1757, Validation Accuracy: 0.2062


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 132/500, Loss: 2.1142, Accuracy: 0.2253
Validation Loss: 2.1932, Validation Accuracy: 0.2006


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 133/500, Loss: 2.1328, Accuracy: 0.2306
Validation Loss: 2.1684, Validation Accuracy: 0.2132




New best model saved with accuracy: 0.2132


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 134/500, Loss: 2.1194, Accuracy: 0.2274
Validation Loss: 2.1935, Validation Accuracy: 0.1907


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 135/500, Loss: 2.1242, Accuracy: 0.2289
Validation Loss: 2.1828, Validation Accuracy: 0.2300




New best model saved with accuracy: 0.2300


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 136/500, Loss: 2.1000, Accuracy: 0.2394
Validation Loss: 2.2257, Validation Accuracy: 0.1907


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 137/500, Loss: 2.1143, Accuracy: 0.2229
Validation Loss: 2.1756, Validation Accuracy: 0.1992


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 138/500, Loss: 2.1159, Accuracy: 0.2390
Validation Loss: 2.1789, Validation Accuracy: 0.2034


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 139/500, Loss: 2.1041, Accuracy: 0.2387
Validation Loss: 2.1283, Validation Accuracy: 0.2286


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 140/500, Loss: 2.1119, Accuracy: 0.2362
Validation Loss: 2.1710, Validation Accuracy: 0.2160


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 141/500, Loss: 2.0999, Accuracy: 0.2429
Validation Loss: 2.1364, Validation Accuracy: 0.2356




New best model saved with accuracy: 0.2356


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 142/500, Loss: 2.0968, Accuracy: 0.2387
Validation Loss: 2.1730, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 143/500, Loss: 2.1100, Accuracy: 0.2387
Validation Loss: 2.2193, Validation Accuracy: 0.1964


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 144/500, Loss: 2.1051, Accuracy: 0.2380
Validation Loss: 2.1473, Validation Accuracy: 0.2328


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 145/500, Loss: 2.1004, Accuracy: 0.2327
Validation Loss: 2.1290, Validation Accuracy: 0.2314


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 146/500, Loss: 2.0973, Accuracy: 0.2446
Validation Loss: 2.1511, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 147/500, Loss: 2.0945, Accuracy: 0.2499
Validation Loss: 2.1401, Validation Accuracy: 0.2216


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 148/500, Loss: 2.0906, Accuracy: 0.2390
Validation Loss: 2.1780, Validation Accuracy: 0.2020


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 149/500, Loss: 2.0987, Accuracy: 0.2478
Validation Loss: 2.1010, Validation Accuracy: 0.2482




New best model saved with accuracy: 0.2482


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 150/500, Loss: 2.0763, Accuracy: 0.2545
Validation Loss: 2.1371, Validation Accuracy: 0.2412


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 151/500, Loss: 2.0686, Accuracy: 0.2580
Validation Loss: 2.1040, Validation Accuracy: 0.2440


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 152/500, Loss: 2.0788, Accuracy: 0.2555
Validation Loss: 2.1372, Validation Accuracy: 0.2174


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 153/500, Loss: 2.0626, Accuracy: 0.2682
Validation Loss: 2.2075, Validation Accuracy: 0.2006


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 154/500, Loss: 2.0774, Accuracy: 0.2468
Validation Loss: 2.1748, Validation Accuracy: 0.2230


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 155/500, Loss: 2.0819, Accuracy: 0.2527
Validation Loss: 2.1509, Validation Accuracy: 0.2216


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 156/500, Loss: 2.0860, Accuracy: 0.2513
Validation Loss: 2.1975, Validation Accuracy: 0.2202


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 157/500, Loss: 2.0986, Accuracy: 0.2345
Validation Loss: 2.1124, Validation Accuracy: 0.2328


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 158/500, Loss: 2.0770, Accuracy: 0.2425
Validation Loss: 2.4012, Validation Accuracy: 0.1445


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 159/500, Loss: 2.1020, Accuracy: 0.2446
Validation Loss: 2.1439, Validation Accuracy: 0.2468


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 160/500, Loss: 2.0670, Accuracy: 0.2636
Validation Loss: 2.2909, Validation Accuracy: 0.1809


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 161/500, Loss: 2.1050, Accuracy: 0.2562
Validation Loss: 2.1116, Validation Accuracy: 0.2426


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 162/500, Loss: 2.0512, Accuracy: 0.2710
Validation Loss: 2.1155, Validation Accuracy: 0.2468


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 163/500, Loss: 2.0753, Accuracy: 0.2475
Validation Loss: 2.1417, Validation Accuracy: 0.2314


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 164/500, Loss: 2.0140, Accuracy: 0.2840
Validation Loss: 2.2378, Validation Accuracy: 0.2006


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 165/500, Loss: 2.0326, Accuracy: 0.2717
Validation Loss: 2.1371, Validation Accuracy: 0.2160


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 166/500, Loss: 2.0728, Accuracy: 0.2590
Validation Loss: 2.1221, Validation Accuracy: 0.2342


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 167/500, Loss: 2.0553, Accuracy: 0.2573
Validation Loss: 2.1376, Validation Accuracy: 0.2286


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 168/500, Loss: 2.0182, Accuracy: 0.2787
Validation Loss: 2.1013, Validation Accuracy: 0.2440


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 169/500, Loss: 2.0185, Accuracy: 0.2790
Validation Loss: 2.2634, Validation Accuracy: 0.1739


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 170/500, Loss: 2.0238, Accuracy: 0.2840
Validation Loss: 2.1452, Validation Accuracy: 0.2244


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 171/500, Loss: 2.0113, Accuracy: 0.2755
Validation Loss: 2.1297, Validation Accuracy: 0.2412


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 172/500, Loss: 2.0254, Accuracy: 0.2829
Validation Loss: 2.1450, Validation Accuracy: 0.2384


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 173/500, Loss: 2.0146, Accuracy: 0.2857
Validation Loss: 2.1738, Validation Accuracy: 0.2202


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 174/500, Loss: 1.9852, Accuracy: 0.2955
Validation Loss: 2.1322, Validation Accuracy: 0.2412


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 175/500, Loss: 1.9567, Accuracy: 0.3050
Validation Loss: 2.1474, Validation Accuracy: 0.2342


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 176/500, Loss: 1.9707, Accuracy: 0.3127
Validation Loss: 2.0973, Validation Accuracy: 0.2553




New best model saved with accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 177/500, Loss: 1.9941, Accuracy: 0.2969
Validation Loss: 2.1002, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 178/500, Loss: 1.9593, Accuracy: 0.3085
Validation Loss: 2.1092, Validation Accuracy: 0.2398


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 179/500, Loss: 1.9768, Accuracy: 0.3159
Validation Loss: 2.2200, Validation Accuracy: 0.2132


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 180/500, Loss: 2.0389, Accuracy: 0.2780
Validation Loss: 2.1340, Validation Accuracy: 0.2174


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 181/500, Loss: 1.9852, Accuracy: 0.2955
Validation Loss: 2.0979, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 182/500, Loss: 1.9388, Accuracy: 0.3163
Validation Loss: 2.1790, Validation Accuracy: 0.2244


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 183/500, Loss: 1.9685, Accuracy: 0.3124
Validation Loss: 2.0534, Validation Accuracy: 0.2637




New best model saved with accuracy: 0.2637


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 184/500, Loss: 1.9668, Accuracy: 0.3127
Validation Loss: 2.1849, Validation Accuracy: 0.2300


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 185/500, Loss: 1.9262, Accuracy: 0.3184
Validation Loss: 2.1451, Validation Accuracy: 0.2454


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 186/500, Loss: 1.9073, Accuracy: 0.3313
Validation Loss: 2.1449, Validation Accuracy: 0.2412


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 187/500, Loss: 1.9308, Accuracy: 0.3278
Validation Loss: 2.0757, Validation Accuracy: 0.2665




New best model saved with accuracy: 0.2665


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 188/500, Loss: 1.9135, Accuracy: 0.3250
Validation Loss: 2.0681, Validation Accuracy: 0.2665


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 189/500, Loss: 1.9025, Accuracy: 0.3454
Validation Loss: 2.1102, Validation Accuracy: 0.2412


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 190/500, Loss: 1.9111, Accuracy: 0.3303
Validation Loss: 2.0502, Validation Accuracy: 0.2637


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 191/500, Loss: 1.9047, Accuracy: 0.3370
Validation Loss: 2.1295, Validation Accuracy: 0.2440


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 192/500, Loss: 1.8955, Accuracy: 0.3320
Validation Loss: 2.0977, Validation Accuracy: 0.2623


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 193/500, Loss: 1.8961, Accuracy: 0.3440
Validation Loss: 2.0753, Validation Accuracy: 0.2665


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 194/500, Loss: 1.8880, Accuracy: 0.3464
Validation Loss: 2.0796, Validation Accuracy: 0.2482


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 195/500, Loss: 1.8761, Accuracy: 0.3542
Validation Loss: 2.0577, Validation Accuracy: 0.2623


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 196/500, Loss: 1.8688, Accuracy: 0.3482
Validation Loss: 2.0708, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 197/500, Loss: 1.8672, Accuracy: 0.3587
Validation Loss: 2.0531, Validation Accuracy: 0.2637


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 198/500, Loss: 1.8471, Accuracy: 0.3647
Validation Loss: 2.0783, Validation Accuracy: 0.2496


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 199/500, Loss: 1.8568, Accuracy: 0.3531
Validation Loss: 2.0583, Validation Accuracy: 0.2609


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 200/500, Loss: 1.8412, Accuracy: 0.3608
Validation Loss: 2.0621, Validation Accuracy: 0.2595


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 201/500, Loss: 1.8368, Accuracy: 0.3703
Validation Loss: 2.0611, Validation Accuracy: 0.2595


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 202/500, Loss: 1.8623, Accuracy: 0.3605
Validation Loss: 2.0711, Validation Accuracy: 0.2623


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 203/500, Loss: 1.8487, Accuracy: 0.3622
Validation Loss: 2.0609, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 204/500, Loss: 1.8547, Accuracy: 0.3640
Validation Loss: 2.0669, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 205/500, Loss: 1.8591, Accuracy: 0.3559
Validation Loss: 2.0646, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 206/500, Loss: 1.8520, Accuracy: 0.3615
Validation Loss: 2.0631, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 207/500, Loss: 1.8453, Accuracy: 0.3563
Validation Loss: 2.0630, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 208/500, Loss: 1.8445, Accuracy: 0.3728
Validation Loss: 2.0624, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 209/500, Loss: 1.8498, Accuracy: 0.3714
Validation Loss: 2.0637, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.53s/it]


Epoch 210/500, Loss: 1.8528, Accuracy: 0.3573
Validation Loss: 2.0632, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 211/500, Loss: 1.8483, Accuracy: 0.3626
Validation Loss: 2.0656, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 212/500, Loss: 1.8428, Accuracy: 0.3647
Validation Loss: 2.0627, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 213/500, Loss: 1.8436, Accuracy: 0.3766
Validation Loss: 2.0532, Validation Accuracy: 0.2651


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 214/500, Loss: 1.8533, Accuracy: 0.3622
Validation Loss: 2.0561, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 215/500, Loss: 1.8233, Accuracy: 0.3731
Validation Loss: 2.0640, Validation Accuracy: 0.2595


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 216/500, Loss: 1.8629, Accuracy: 0.3650
Validation Loss: 2.0684, Validation Accuracy: 0.2637


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 217/500, Loss: 1.8500, Accuracy: 0.3556
Validation Loss: 2.0851, Validation Accuracy: 0.2609


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 218/500, Loss: 1.8479, Accuracy: 0.3724
Validation Loss: 2.0544, Validation Accuracy: 0.2707




New best model saved with accuracy: 0.2707


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 219/500, Loss: 1.8596, Accuracy: 0.3601
Validation Loss: 2.1106, Validation Accuracy: 0.2595


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 220/500, Loss: 1.8711, Accuracy: 0.3471
Validation Loss: 2.0749, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 221/500, Loss: 1.8434, Accuracy: 0.3640
Validation Loss: 2.0154, Validation Accuracy: 0.2847




New best model saved with accuracy: 0.2847


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 222/500, Loss: 1.8607, Accuracy: 0.3570
Validation Loss: 2.0780, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 223/500, Loss: 1.8683, Accuracy: 0.3471
Validation Loss: 2.0430, Validation Accuracy: 0.2637


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 224/500, Loss: 1.8512, Accuracy: 0.3703
Validation Loss: 2.0687, Validation Accuracy: 0.2553


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 225/500, Loss: 1.8319, Accuracy: 0.3714
Validation Loss: 2.0880, Validation Accuracy: 0.2567


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 226/500, Loss: 1.8500, Accuracy: 0.3668
Validation Loss: 2.0835, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 227/500, Loss: 1.8429, Accuracy: 0.3556
Validation Loss: 2.0890, Validation Accuracy: 0.2651


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 228/500, Loss: 1.8706, Accuracy: 0.3478
Validation Loss: 2.0771, Validation Accuracy: 0.2707


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 229/500, Loss: 1.8477, Accuracy: 0.3643
Validation Loss: 2.1122, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 230/500, Loss: 1.8680, Accuracy: 0.3563
Validation Loss: 2.0866, Validation Accuracy: 0.2525


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 231/500, Loss: 1.8396, Accuracy: 0.3570
Validation Loss: 2.0960, Validation Accuracy: 0.2679


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 232/500, Loss: 1.8301, Accuracy: 0.3549
Validation Loss: 1.9789, Validation Accuracy: 0.2945




New best model saved with accuracy: 0.2945


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 233/500, Loss: 1.8690, Accuracy: 0.3521
Validation Loss: 2.0896, Validation Accuracy: 0.2609


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 234/500, Loss: 1.8365, Accuracy: 0.3580
Validation Loss: 2.0412, Validation Accuracy: 0.2651


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 235/500, Loss: 1.8190, Accuracy: 0.3707
Validation Loss: 2.0373, Validation Accuracy: 0.2819


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 236/500, Loss: 1.8465, Accuracy: 0.3650
Validation Loss: 2.1329, Validation Accuracy: 0.2496


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 237/500, Loss: 1.8373, Accuracy: 0.3759
Validation Loss: 1.9963, Validation Accuracy: 0.2959




New best model saved with accuracy: 0.2959


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 238/500, Loss: 1.8397, Accuracy: 0.3735
Validation Loss: 1.9847, Validation Accuracy: 0.2973




New best model saved with accuracy: 0.2973


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 239/500, Loss: 1.8341, Accuracy: 0.3598
Validation Loss: 2.0575, Validation Accuracy: 0.2819


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 240/500, Loss: 1.8158, Accuracy: 0.3822
Validation Loss: 2.0662, Validation Accuracy: 0.2735


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 241/500, Loss: 1.7976, Accuracy: 0.3836
Validation Loss: 2.0320, Validation Accuracy: 0.2833


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 242/500, Loss: 1.7820, Accuracy: 0.3854
Validation Loss: 2.0448, Validation Accuracy: 0.2847


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 243/500, Loss: 1.8154, Accuracy: 0.3872
Validation Loss: 2.0733, Validation Accuracy: 0.2917


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 244/500, Loss: 1.7930, Accuracy: 0.3850
Validation Loss: 2.0386, Validation Accuracy: 0.2805


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 245/500, Loss: 1.8272, Accuracy: 0.3707
Validation Loss: 2.1012, Validation Accuracy: 0.2665


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 246/500, Loss: 1.7884, Accuracy: 0.3836
Validation Loss: 2.0295, Validation Accuracy: 0.2833


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 247/500, Loss: 1.8010, Accuracy: 0.3735
Validation Loss: 2.0723, Validation Accuracy: 0.2693


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 248/500, Loss: 1.8300, Accuracy: 0.3721
Validation Loss: 2.0865, Validation Accuracy: 0.2735


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 249/500, Loss: 1.8074, Accuracy: 0.3857
Validation Loss: 1.9910, Validation Accuracy: 0.2987




New best model saved with accuracy: 0.2987


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 250/500, Loss: 1.8147, Accuracy: 0.3664
Validation Loss: 2.0942, Validation Accuracy: 0.2735


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 251/500, Loss: 1.7924, Accuracy: 0.3696
Validation Loss: 2.0014, Validation Accuracy: 0.2959


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 252/500, Loss: 1.7635, Accuracy: 0.4037
Validation Loss: 2.1096, Validation Accuracy: 0.2875


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 253/500, Loss: 1.8050, Accuracy: 0.3696
Validation Loss: 2.0385, Validation Accuracy: 0.2763


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 254/500, Loss: 1.7982, Accuracy: 0.3770
Validation Loss: 2.1104, Validation Accuracy: 0.2777


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 255/500, Loss: 1.7993, Accuracy: 0.3822
Validation Loss: 2.0836, Validation Accuracy: 0.2749


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 256/500, Loss: 1.7656, Accuracy: 0.3914
Validation Loss: 2.0192, Validation Accuracy: 0.2889


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 257/500, Loss: 1.7483, Accuracy: 0.3973
Validation Loss: 2.0320, Validation Accuracy: 0.2931


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 258/500, Loss: 1.7075, Accuracy: 0.4194
Validation Loss: 1.9797, Validation Accuracy: 0.3100




New best model saved with accuracy: 0.3100


100%|██████████| 12/12 [00:55<00:00,  4.62s/it]


Epoch 259/500, Loss: 1.7515, Accuracy: 0.4128
Validation Loss: 1.9584, Validation Accuracy: 0.3170




New best model saved with accuracy: 0.3170


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 260/500, Loss: 1.7499, Accuracy: 0.4029
Validation Loss: 2.0095, Validation Accuracy: 0.2917


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 261/500, Loss: 1.7223, Accuracy: 0.4065
Validation Loss: 2.0005, Validation Accuracy: 0.3072


100%|██████████| 12/12 [00:54<00:00,  4.52s/it]


Epoch 262/500, Loss: 1.7090, Accuracy: 0.4061
Validation Loss: 2.0227, Validation Accuracy: 0.3015


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 263/500, Loss: 1.7159, Accuracy: 0.4156
Validation Loss: 1.9924, Validation Accuracy: 0.3128


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 264/500, Loss: 1.6967, Accuracy: 0.4226
Validation Loss: 1.9586, Validation Accuracy: 0.3156


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 265/500, Loss: 1.7371, Accuracy: 0.4096
Validation Loss: 2.0090, Validation Accuracy: 0.3001


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 266/500, Loss: 1.7025, Accuracy: 0.4258
Validation Loss: 2.0084, Validation Accuracy: 0.2987


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 267/500, Loss: 1.6733, Accuracy: 0.4310
Validation Loss: 2.0092, Validation Accuracy: 0.2987


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 268/500, Loss: 1.6532, Accuracy: 0.4454
Validation Loss: 1.9717, Validation Accuracy: 0.3226




New best model saved with accuracy: 0.3226


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 269/500, Loss: 1.6582, Accuracy: 0.4412
Validation Loss: 1.9851, Validation Accuracy: 0.3114


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 270/500, Loss: 1.6578, Accuracy: 0.4472
Validation Loss: 1.9854, Validation Accuracy: 0.3156


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 271/500, Loss: 1.6221, Accuracy: 0.4531
Validation Loss: 1.9854, Validation Accuracy: 0.3100


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 272/500, Loss: 1.6220, Accuracy: 0.4479
Validation Loss: 1.9764, Validation Accuracy: 0.3142


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 273/500, Loss: 1.6501, Accuracy: 0.4338
Validation Loss: 1.9891, Validation Accuracy: 0.3086


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 274/500, Loss: 1.6282, Accuracy: 0.4567
Validation Loss: 2.0008, Validation Accuracy: 0.3184


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 275/500, Loss: 1.5990, Accuracy: 0.4598
Validation Loss: 1.9693, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 276/500, Loss: 1.6074, Accuracy: 0.4665
Validation Loss: 1.9727, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 277/500, Loss: 1.6135, Accuracy: 0.4581
Validation Loss: 1.9832, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 278/500, Loss: 1.5967, Accuracy: 0.4616
Validation Loss: 1.9850, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 279/500, Loss: 1.6242, Accuracy: 0.4545
Validation Loss: 1.9754, Validation Accuracy: 0.3128


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 280/500, Loss: 1.5703, Accuracy: 0.4858
Validation Loss: 1.9582, Validation Accuracy: 0.3170


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 281/500, Loss: 1.5915, Accuracy: 0.4742
Validation Loss: 1.9558, Validation Accuracy: 0.3282




New best model saved with accuracy: 0.3282


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 282/500, Loss: 1.5612, Accuracy: 0.4795
Validation Loss: 1.9780, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 283/500, Loss: 1.5944, Accuracy: 0.4746
Validation Loss: 1.9673, Validation Accuracy: 0.3296




New best model saved with accuracy: 0.3296


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 284/500, Loss: 1.5824, Accuracy: 0.4823
Validation Loss: 1.9520, Validation Accuracy: 0.3296


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 285/500, Loss: 1.5717, Accuracy: 0.4777
Validation Loss: 1.9613, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 286/500, Loss: 1.5686, Accuracy: 0.4756
Validation Loss: 1.9641, Validation Accuracy: 0.3254


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 287/500, Loss: 1.5654, Accuracy: 0.4724
Validation Loss: 1.9533, Validation Accuracy: 0.3352




New best model saved with accuracy: 0.3352


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 288/500, Loss: 1.5612, Accuracy: 0.4903
Validation Loss: 1.9546, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 289/500, Loss: 1.5451, Accuracy: 0.4753
Validation Loss: 1.9559, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 290/500, Loss: 1.5482, Accuracy: 0.4910
Validation Loss: 1.9561, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 291/500, Loss: 1.5643, Accuracy: 0.4854
Validation Loss: 1.9579, Validation Accuracy: 0.3338


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 292/500, Loss: 1.5601, Accuracy: 0.4739
Validation Loss: 1.9584, Validation Accuracy: 0.3324


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 293/500, Loss: 1.5571, Accuracy: 0.4900
Validation Loss: 1.9596, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 294/500, Loss: 1.5568, Accuracy: 0.4756
Validation Loss: 1.9569, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 295/500, Loss: 1.5691, Accuracy: 0.4777
Validation Loss: 1.9616, Validation Accuracy: 0.3282


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 296/500, Loss: 1.5852, Accuracy: 0.4918
Validation Loss: 1.9634, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 297/500, Loss: 1.5762, Accuracy: 0.4826
Validation Loss: 1.9652, Validation Accuracy: 0.3282


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 298/500, Loss: 1.5619, Accuracy: 0.4819
Validation Loss: 1.9728, Validation Accuracy: 0.3240


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 299/500, Loss: 1.5435, Accuracy: 0.4844
Validation Loss: 1.9581, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 300/500, Loss: 1.5728, Accuracy: 0.4731
Validation Loss: 1.9873, Validation Accuracy: 0.3114


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 301/500, Loss: 1.5533, Accuracy: 0.4826
Validation Loss: 1.9775, Validation Accuracy: 0.3184


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 302/500, Loss: 1.5917, Accuracy: 0.4816
Validation Loss: 1.9843, Validation Accuracy: 0.3184


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 303/500, Loss: 1.5943, Accuracy: 0.4591
Validation Loss: 1.9796, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 304/500, Loss: 1.6071, Accuracy: 0.4819
Validation Loss: 1.9911, Validation Accuracy: 0.3114


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 305/500, Loss: 1.5629, Accuracy: 0.4710
Validation Loss: 1.9629, Validation Accuracy: 0.3240


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 306/500, Loss: 1.5855, Accuracy: 0.4774
Validation Loss: 1.9670, Validation Accuracy: 0.3296


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 307/500, Loss: 1.5591, Accuracy: 0.4714
Validation Loss: 1.9307, Validation Accuracy: 0.3352


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 308/500, Loss: 1.6090, Accuracy: 0.4514
Validation Loss: 1.9867, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 309/500, Loss: 1.6088, Accuracy: 0.4707
Validation Loss: 1.9900, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 310/500, Loss: 1.5977, Accuracy: 0.4672
Validation Loss: 1.9027, Validation Accuracy: 0.3633




New best model saved with accuracy: 0.3633


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 311/500, Loss: 1.5911, Accuracy: 0.4612
Validation Loss: 1.9826, Validation Accuracy: 0.3226


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 312/500, Loss: 1.5803, Accuracy: 0.4619
Validation Loss: 1.9528, Validation Accuracy: 0.3366


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 313/500, Loss: 1.5782, Accuracy: 0.4619
Validation Loss: 1.9549, Validation Accuracy: 0.3366


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 314/500, Loss: 1.6128, Accuracy: 0.4349
Validation Loss: 1.9615, Validation Accuracy: 0.3324


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 315/500, Loss: 1.5992, Accuracy: 0.4577
Validation Loss: 2.0148, Validation Accuracy: 0.3142


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 316/500, Loss: 1.5812, Accuracy: 0.4619
Validation Loss: 2.1306, Validation Accuracy: 0.2847


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 317/500, Loss: 1.6959, Accuracy: 0.4138
Validation Loss: 2.0415, Validation Accuracy: 0.3142


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 318/500, Loss: 1.5961, Accuracy: 0.4633
Validation Loss: 2.0060, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 319/500, Loss: 1.5917, Accuracy: 0.4595
Validation Loss: 1.9644, Validation Accuracy: 0.3324


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 320/500, Loss: 1.5735, Accuracy: 0.4630
Validation Loss: 2.0149, Validation Accuracy: 0.3156


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 321/500, Loss: 1.6164, Accuracy: 0.4633
Validation Loss: 2.0007, Validation Accuracy: 0.3114


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 322/500, Loss: 1.6028, Accuracy: 0.4395
Validation Loss: 2.0287, Validation Accuracy: 0.3072


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 323/500, Loss: 1.5931, Accuracy: 0.4602
Validation Loss: 2.0483, Validation Accuracy: 0.3114


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 324/500, Loss: 1.5862, Accuracy: 0.4693
Validation Loss: 1.9492, Validation Accuracy: 0.3478


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 325/500, Loss: 1.5417, Accuracy: 0.4795
Validation Loss: 2.0586, Validation Accuracy: 0.3184


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 326/500, Loss: 1.5985, Accuracy: 0.4542
Validation Loss: 2.0628, Validation Accuracy: 0.2959


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 327/500, Loss: 1.6062, Accuracy: 0.4468
Validation Loss: 1.9370, Validation Accuracy: 0.3548


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 328/500, Loss: 1.5742, Accuracy: 0.4710
Validation Loss: 2.0272, Validation Accuracy: 0.3240


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 329/500, Loss: 1.6247, Accuracy: 0.4430
Validation Loss: 2.0576, Validation Accuracy: 0.3254


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 330/500, Loss: 1.5615, Accuracy: 0.4647
Validation Loss: 2.0566, Validation Accuracy: 0.3128


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 331/500, Loss: 1.5492, Accuracy: 0.4795
Validation Loss: 1.9579, Validation Accuracy: 0.3436


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 332/500, Loss: 1.5359, Accuracy: 0.4731
Validation Loss: 1.9886, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 333/500, Loss: 1.5008, Accuracy: 0.4925
Validation Loss: 1.9894, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 334/500, Loss: 1.5650, Accuracy: 0.4696
Validation Loss: 2.0237, Validation Accuracy: 0.3352


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 335/500, Loss: 1.5291, Accuracy: 0.4882
Validation Loss: 2.0629, Validation Accuracy: 0.3226


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 336/500, Loss: 1.5042, Accuracy: 0.4865
Validation Loss: 2.0214, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 337/500, Loss: 1.4812, Accuracy: 0.5047
Validation Loss: 1.9622, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 338/500, Loss: 1.5305, Accuracy: 0.4889
Validation Loss: 2.0703, Validation Accuracy: 0.3142


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 339/500, Loss: 1.5224, Accuracy: 0.4826
Validation Loss: 1.9189, Validation Accuracy: 0.3492


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 340/500, Loss: 1.4965, Accuracy: 0.4942
Validation Loss: 1.9667, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 341/500, Loss: 1.4858, Accuracy: 0.4914
Validation Loss: 2.0265, Validation Accuracy: 0.3394


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 342/500, Loss: 1.5027, Accuracy: 0.4967
Validation Loss: 2.0703, Validation Accuracy: 0.3296


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 343/500, Loss: 1.4910, Accuracy: 0.4942
Validation Loss: 1.9640, Validation Accuracy: 0.3380


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 344/500, Loss: 1.5086, Accuracy: 0.4823
Validation Loss: 2.0215, Validation Accuracy: 0.3338


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 345/500, Loss: 1.4495, Accuracy: 0.5082
Validation Loss: 1.9544, Validation Accuracy: 0.3436


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 346/500, Loss: 1.4702, Accuracy: 0.5033
Validation Loss: 2.0698, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 347/500, Loss: 1.4738, Accuracy: 0.5040
Validation Loss: 2.0599, Validation Accuracy: 0.3156


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 348/500, Loss: 1.4719, Accuracy: 0.5047
Validation Loss: 1.9182, Validation Accuracy: 0.3548


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 349/500, Loss: 1.4379, Accuracy: 0.5156
Validation Loss: 1.9668, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 350/500, Loss: 1.4217, Accuracy: 0.5202
Validation Loss: 1.9762, Validation Accuracy: 0.3450


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 351/500, Loss: 1.3839, Accuracy: 0.5349
Validation Loss: 1.9380, Validation Accuracy: 0.3604


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 352/500, Loss: 1.4212, Accuracy: 0.5349
Validation Loss: 1.9688, Validation Accuracy: 0.3422


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 353/500, Loss: 1.4337, Accuracy: 0.5121
Validation Loss: 1.9439, Validation Accuracy: 0.3675




New best model saved with accuracy: 0.3675


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 354/500, Loss: 1.3571, Accuracy: 0.5476
Validation Loss: 1.9438, Validation Accuracy: 0.3506


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 355/500, Loss: 1.3630, Accuracy: 0.5412
Validation Loss: 1.9668, Validation Accuracy: 0.3380


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 356/500, Loss: 1.3674, Accuracy: 0.5444
Validation Loss: 1.9688, Validation Accuracy: 0.3450


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 357/500, Loss: 1.3837, Accuracy: 0.5360
Validation Loss: 1.9466, Validation Accuracy: 0.3408


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 358/500, Loss: 1.3559, Accuracy: 0.5479
Validation Loss: 1.9011, Validation Accuracy: 0.3619


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 359/500, Loss: 1.3368, Accuracy: 0.5556
Validation Loss: 1.9852, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 360/500, Loss: 1.3729, Accuracy: 0.5556
Validation Loss: 1.9603, Validation Accuracy: 0.3436


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 361/500, Loss: 1.3594, Accuracy: 0.5574
Validation Loss: 1.8908, Validation Accuracy: 0.3604


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 362/500, Loss: 1.3568, Accuracy: 0.5570
Validation Loss: 1.9017, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:55<00:00,  4.63s/it]


Epoch 363/500, Loss: 1.3218, Accuracy: 0.5630
Validation Loss: 1.9261, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 364/500, Loss: 1.3042, Accuracy: 0.5788
Validation Loss: 1.9323, Validation Accuracy: 0.3492


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 365/500, Loss: 1.3322, Accuracy: 0.5560
Validation Loss: 1.9133, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 366/500, Loss: 1.2954, Accuracy: 0.5651
Validation Loss: 1.9365, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 367/500, Loss: 1.3438, Accuracy: 0.5630
Validation Loss: 1.9027, Validation Accuracy: 0.3590


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 368/500, Loss: 1.3004, Accuracy: 0.5658
Validation Loss: 1.9221, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 369/500, Loss: 1.3109, Accuracy: 0.5721
Validation Loss: 1.9154, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 370/500, Loss: 1.2986, Accuracy: 0.5802
Validation Loss: 1.9130, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 371/500, Loss: 1.3131, Accuracy: 0.5784
Validation Loss: 1.9144, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 372/500, Loss: 1.2808, Accuracy: 0.5792
Validation Loss: 1.9143, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 373/500, Loss: 1.3148, Accuracy: 0.5760
Validation Loss: 1.9139, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 374/500, Loss: 1.3094, Accuracy: 0.5676
Validation Loss: 1.9170, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 375/500, Loss: 1.3190, Accuracy: 0.5658
Validation Loss: 1.9118, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 376/500, Loss: 1.2929, Accuracy: 0.5774
Validation Loss: 1.9229, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 377/500, Loss: 1.2998, Accuracy: 0.5704
Validation Loss: 1.9111, Validation Accuracy: 0.3604


100%|██████████| 12/12 [00:55<00:00,  4.63s/it]


Epoch 378/500, Loss: 1.3054, Accuracy: 0.5749
Validation Loss: 1.9100, Validation Accuracy: 0.3548


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 379/500, Loss: 1.2968, Accuracy: 0.5644
Validation Loss: 1.9126, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 380/500, Loss: 1.3041, Accuracy: 0.5676
Validation Loss: 1.9097, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 381/500, Loss: 1.3077, Accuracy: 0.5602
Validation Loss: 1.9253, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 382/500, Loss: 1.2986, Accuracy: 0.5655
Validation Loss: 1.9098, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 383/500, Loss: 1.3406, Accuracy: 0.5588
Validation Loss: 1.9086, Validation Accuracy: 0.3633


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 384/500, Loss: 1.3503, Accuracy: 0.5490
Validation Loss: 1.8976, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 385/500, Loss: 1.3105, Accuracy: 0.5637
Validation Loss: 1.9116, Validation Accuracy: 0.3661


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 386/500, Loss: 1.3247, Accuracy: 0.5714
Validation Loss: 1.9431, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 387/500, Loss: 1.2946, Accuracy: 0.5781
Validation Loss: 1.9162, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 388/500, Loss: 1.2938, Accuracy: 0.5623
Validation Loss: 1.9709, Validation Accuracy: 0.3590


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 389/500, Loss: 1.3022, Accuracy: 0.5574
Validation Loss: 1.8942, Validation Accuracy: 0.3731




New best model saved with accuracy: 0.3731


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 390/500, Loss: 1.3574, Accuracy: 0.5612
Validation Loss: 1.9779, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 391/500, Loss: 1.3673, Accuracy: 0.5444
Validation Loss: 1.9384, Validation Accuracy: 0.3576


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 392/500, Loss: 1.3211, Accuracy: 0.5669
Validation Loss: 1.9445, Validation Accuracy: 0.3534


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 393/500, Loss: 1.3332, Accuracy: 0.5595
Validation Loss: 1.9356, Validation Accuracy: 0.3590


100%|██████████| 12/12 [00:55<00:00,  4.62s/it]


Epoch 394/500, Loss: 1.3490, Accuracy: 0.5535
Validation Loss: 2.0224, Validation Accuracy: 0.3478


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 395/500, Loss: 1.3401, Accuracy: 0.5409
Validation Loss: 1.8967, Validation Accuracy: 0.3647


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 396/500, Loss: 1.3060, Accuracy: 0.5598
Validation Loss: 1.9224, Validation Accuracy: 0.3703


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 397/500, Loss: 1.3017, Accuracy: 0.5707
Validation Loss: 1.8886, Validation Accuracy: 0.3801




New best model saved with accuracy: 0.3801


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 398/500, Loss: 1.3238, Accuracy: 0.5570
Validation Loss: 1.9739, Validation Accuracy: 0.3450


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 399/500, Loss: 1.3153, Accuracy: 0.5595
Validation Loss: 1.9642, Validation Accuracy: 0.3604


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 400/500, Loss: 1.3162, Accuracy: 0.5648
Validation Loss: 1.9587, Validation Accuracy: 0.3408


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 401/500, Loss: 1.3431, Accuracy: 0.5641
Validation Loss: 1.9745, Validation Accuracy: 0.3717


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 402/500, Loss: 1.3517, Accuracy: 0.5497
Validation Loss: 1.9599, Validation Accuracy: 0.3492


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 403/500, Loss: 1.3410, Accuracy: 0.5504
Validation Loss: 1.9755, Validation Accuracy: 0.3619


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 404/500, Loss: 1.3453, Accuracy: 0.5539
Validation Loss: 1.9829, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 405/500, Loss: 1.3486, Accuracy: 0.5398
Validation Loss: 2.0302, Validation Accuracy: 0.3590


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 406/500, Loss: 1.3996, Accuracy: 0.5125
Validation Loss: 1.9794, Validation Accuracy: 0.3464


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 407/500, Loss: 1.3764, Accuracy: 0.5321
Validation Loss: 2.0471, Validation Accuracy: 0.3422


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 408/500, Loss: 1.3895, Accuracy: 0.5381
Validation Loss: 1.9948, Validation Accuracy: 0.3450


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 409/500, Loss: 1.4173, Accuracy: 0.5125
Validation Loss: 1.9679, Validation Accuracy: 0.3647


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 410/500, Loss: 1.3060, Accuracy: 0.5612
Validation Loss: 1.9282, Validation Accuracy: 0.3717


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 411/500, Loss: 1.3532, Accuracy: 0.5416
Validation Loss: 1.9725, Validation Accuracy: 0.3619


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 412/500, Loss: 1.3451, Accuracy: 0.5405
Validation Loss: 2.0778, Validation Accuracy: 0.3408


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 413/500, Loss: 1.3931, Accuracy: 0.5360
Validation Loss: 2.1038, Validation Accuracy: 0.3240


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 414/500, Loss: 1.3086, Accuracy: 0.5448
Validation Loss: 1.9697, Validation Accuracy: 0.3745


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 415/500, Loss: 1.3150, Accuracy: 0.5539
Validation Loss: 2.0120, Validation Accuracy: 0.3548


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 416/500, Loss: 1.3232, Accuracy: 0.5448
Validation Loss: 2.0065, Validation Accuracy: 0.3633


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 417/500, Loss: 1.3088, Accuracy: 0.5665
Validation Loss: 2.0664, Validation Accuracy: 0.3548


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 418/500, Loss: 1.3458, Accuracy: 0.5416
Validation Loss: 1.9792, Validation Accuracy: 0.3675


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 419/500, Loss: 1.3305, Accuracy: 0.5570
Validation Loss: 2.0153, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 420/500, Loss: 1.3187, Accuracy: 0.5567
Validation Loss: 1.9752, Validation Accuracy: 0.3689


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 421/500, Loss: 1.2836, Accuracy: 0.5683
Validation Loss: 1.9158, Validation Accuracy: 0.3759


100%|██████████| 12/12 [00:55<00:00,  4.62s/it]


Epoch 422/500, Loss: 1.3257, Accuracy: 0.5539
Validation Loss: 1.9904, Validation Accuracy: 0.3717


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 423/500, Loss: 1.2713, Accuracy: 0.5742
Validation Loss: 1.9944, Validation Accuracy: 0.3520


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 424/500, Loss: 1.2475, Accuracy: 0.5802
Validation Loss: 1.8945, Validation Accuracy: 0.3899




New best model saved with accuracy: 0.3899


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 425/500, Loss: 1.2334, Accuracy: 0.5858
Validation Loss: 1.9712, Validation Accuracy: 0.3689


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 426/500, Loss: 1.2282, Accuracy: 0.5858
Validation Loss: 1.8669, Validation Accuracy: 0.4095




New best model saved with accuracy: 0.4095


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 427/500, Loss: 1.2817, Accuracy: 0.5806
Validation Loss: 2.0091, Validation Accuracy: 0.3703


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 428/500, Loss: 1.2424, Accuracy: 0.5886
Validation Loss: 2.0098, Validation Accuracy: 0.3619


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 429/500, Loss: 1.2309, Accuracy: 0.5830
Validation Loss: 1.9918, Validation Accuracy: 0.3661


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 430/500, Loss: 1.2133, Accuracy: 0.5918
Validation Loss: 1.9719, Validation Accuracy: 0.3675


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 431/500, Loss: 1.2042, Accuracy: 0.6027
Validation Loss: 1.8854, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 432/500, Loss: 1.2406, Accuracy: 0.5890
Validation Loss: 2.0251, Validation Accuracy: 0.3703


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 433/500, Loss: 1.1951, Accuracy: 0.6146
Validation Loss: 1.9611, Validation Accuracy: 0.3731


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 434/500, Loss: 1.1863, Accuracy: 0.6076
Validation Loss: 1.9207, Validation Accuracy: 0.4025


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 435/500, Loss: 1.1751, Accuracy: 0.6083
Validation Loss: 1.9321, Validation Accuracy: 0.3955


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 436/500, Loss: 1.1557, Accuracy: 0.6107
Validation Loss: 1.9152, Validation Accuracy: 0.3899


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 437/500, Loss: 1.1610, Accuracy: 0.6150
Validation Loss: 1.9343, Validation Accuracy: 0.3969


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 438/500, Loss: 1.1457, Accuracy: 0.6135
Validation Loss: 1.9611, Validation Accuracy: 0.3885


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 439/500, Loss: 1.1428, Accuracy: 0.6107
Validation Loss: 1.9140, Validation Accuracy: 0.4165




New best model saved with accuracy: 0.4165


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 440/500, Loss: 1.1524, Accuracy: 0.6164
Validation Loss: 1.9323, Validation Accuracy: 0.3983


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 441/500, Loss: 1.1478, Accuracy: 0.6216
Validation Loss: 1.9425, Validation Accuracy: 0.3703


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 442/500, Loss: 1.1362, Accuracy: 0.6195
Validation Loss: 1.9217, Validation Accuracy: 0.4067


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 443/500, Loss: 1.1122, Accuracy: 0.6241
Validation Loss: 1.9281, Validation Accuracy: 0.3941


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 444/500, Loss: 1.1037, Accuracy: 0.6367
Validation Loss: 1.9080, Validation Accuracy: 0.3955


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 445/500, Loss: 1.1204, Accuracy: 0.6357
Validation Loss: 1.8917, Validation Accuracy: 0.4025


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 446/500, Loss: 1.0636, Accuracy: 0.6455
Validation Loss: 1.9013, Validation Accuracy: 0.4039


100%|██████████| 12/12 [00:55<00:00,  4.63s/it]


Epoch 447/500, Loss: 1.0727, Accuracy: 0.6444
Validation Loss: 1.9225, Validation Accuracy: 0.4053


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 448/500, Loss: 1.0932, Accuracy: 0.6395
Validation Loss: 1.8811, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 449/500, Loss: 1.0774, Accuracy: 0.6553
Validation Loss: 1.9205, Validation Accuracy: 0.3941


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 450/500, Loss: 1.0736, Accuracy: 0.6504
Validation Loss: 1.8952, Validation Accuracy: 0.4067


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 451/500, Loss: 1.0358, Accuracy: 0.6511
Validation Loss: 1.9029, Validation Accuracy: 0.4025


100%|██████████| 12/12 [00:55<00:00,  4.62s/it]


Epoch 452/500, Loss: 1.0633, Accuracy: 0.6567
Validation Loss: 1.8924, Validation Accuracy: 0.4011


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 453/500, Loss: 1.0768, Accuracy: 0.6472
Validation Loss: 1.8944, Validation Accuracy: 0.3983


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 454/500, Loss: 1.0745, Accuracy: 0.6595
Validation Loss: 1.8943, Validation Accuracy: 0.3983


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 455/500, Loss: 1.0836, Accuracy: 0.6578
Validation Loss: 1.8945, Validation Accuracy: 0.3983


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 456/500, Loss: 1.0634, Accuracy: 0.6479
Validation Loss: 1.8950, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 457/500, Loss: 1.0678, Accuracy: 0.6455
Validation Loss: 1.8965, Validation Accuracy: 0.4039


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 458/500, Loss: 1.0960, Accuracy: 0.6585
Validation Loss: 1.9036, Validation Accuracy: 0.4011


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 459/500, Loss: 1.0705, Accuracy: 0.6609
Validation Loss: 1.8904, Validation Accuracy: 0.4067


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 460/500, Loss: 1.0576, Accuracy: 0.6515
Validation Loss: 1.8850, Validation Accuracy: 0.4053


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 461/500, Loss: 1.0443, Accuracy: 0.6515
Validation Loss: 1.8938, Validation Accuracy: 0.4081


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 462/500, Loss: 1.0578, Accuracy: 0.6423
Validation Loss: 1.8936, Validation Accuracy: 0.4109


100%|██████████| 12/12 [00:55<00:00,  4.62s/it]


Epoch 463/500, Loss: 1.0644, Accuracy: 0.6518
Validation Loss: 1.8733, Validation Accuracy: 0.4151


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 464/500, Loss: 1.0549, Accuracy: 0.6511
Validation Loss: 1.8777, Validation Accuracy: 0.4039


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 465/500, Loss: 1.1234, Accuracy: 0.6413
Validation Loss: 1.9068, Validation Accuracy: 0.3927


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 466/500, Loss: 1.0677, Accuracy: 0.6543
Validation Loss: 1.9077, Validation Accuracy: 0.3913


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 467/500, Loss: 1.0944, Accuracy: 0.6458
Validation Loss: 1.8976, Validation Accuracy: 0.4123


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 468/500, Loss: 1.0831, Accuracy: 0.6346
Validation Loss: 1.9258, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 469/500, Loss: 1.0739, Accuracy: 0.6511
Validation Loss: 1.9790, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 470/500, Loss: 1.0519, Accuracy: 0.6504
Validation Loss: 1.9263, Validation Accuracy: 0.4081


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 471/500, Loss: 1.1115, Accuracy: 0.6525
Validation Loss: 1.9869, Validation Accuracy: 0.3885


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 472/500, Loss: 1.1363, Accuracy: 0.6206
Validation Loss: 1.9591, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 473/500, Loss: 1.0878, Accuracy: 0.6494
Validation Loss: 1.9202, Validation Accuracy: 0.4109


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 474/500, Loss: 1.0802, Accuracy: 0.6620
Validation Loss: 1.9237, Validation Accuracy: 0.4151


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 475/500, Loss: 1.1158, Accuracy: 0.6181
Validation Loss: 1.9696, Validation Accuracy: 0.3717


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 476/500, Loss: 1.1054, Accuracy: 0.6297
Validation Loss: 2.0761, Validation Accuracy: 0.3675


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 477/500, Loss: 1.1669, Accuracy: 0.6206
Validation Loss: 1.9255, Validation Accuracy: 0.4137


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 478/500, Loss: 1.1031, Accuracy: 0.6195
Validation Loss: 1.9686, Validation Accuracy: 0.4011


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 479/500, Loss: 1.1083, Accuracy: 0.6241
Validation Loss: 1.9411, Validation Accuracy: 0.3927


100%|██████████| 12/12 [00:55<00:00,  4.61s/it]


Epoch 480/500, Loss: 1.1508, Accuracy: 0.6125
Validation Loss: 1.9809, Validation Accuracy: 0.3913


100%|██████████| 12/12 [00:54<00:00,  4.54s/it]


Epoch 481/500, Loss: 1.1073, Accuracy: 0.6300
Validation Loss: 1.9122, Validation Accuracy: 0.3857


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 482/500, Loss: 1.1135, Accuracy: 0.6206
Validation Loss: 1.9356, Validation Accuracy: 0.4011


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 483/500, Loss: 1.1368, Accuracy: 0.6223
Validation Loss: 1.9980, Validation Accuracy: 0.3997


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 484/500, Loss: 1.1184, Accuracy: 0.6381
Validation Loss: 1.9776, Validation Accuracy: 0.3927


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 485/500, Loss: 1.1403, Accuracy: 0.6216
Validation Loss: 1.9122, Validation Accuracy: 0.4067


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 486/500, Loss: 1.1620, Accuracy: 0.6195
Validation Loss: 1.9193, Validation Accuracy: 0.4067


100%|██████████| 12/12 [00:55<00:00,  4.58s/it]


Epoch 487/500, Loss: 1.1426, Accuracy: 0.6244
Validation Loss: 1.9416, Validation Accuracy: 0.3885


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 488/500, Loss: 1.2074, Accuracy: 0.5900
Validation Loss: 1.9756, Validation Accuracy: 0.3927


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 489/500, Loss: 1.1656, Accuracy: 0.6013
Validation Loss: 1.9926, Validation Accuracy: 0.3871


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 490/500, Loss: 1.1920, Accuracy: 0.6002
Validation Loss: 2.1535, Validation Accuracy: 0.3562


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 491/500, Loss: 1.1709, Accuracy: 0.6118
Validation Loss: 1.9474, Validation Accuracy: 0.4039


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 492/500, Loss: 1.1681, Accuracy: 0.6146
Validation Loss: 2.0658, Validation Accuracy: 0.3436


100%|██████████| 12/12 [00:54<00:00,  4.56s/it]


Epoch 493/500, Loss: 1.1279, Accuracy: 0.6160
Validation Loss: 1.9940, Validation Accuracy: 0.3773


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 494/500, Loss: 1.1192, Accuracy: 0.6139
Validation Loss: 2.0295, Validation Accuracy: 0.3745


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 495/500, Loss: 1.1113, Accuracy: 0.6265
Validation Loss: 2.0008, Validation Accuracy: 0.3871


100%|██████████| 12/12 [00:54<00:00,  4.55s/it]


Epoch 496/500, Loss: 1.1385, Accuracy: 0.6199
Validation Loss: 1.9878, Validation Accuracy: 0.4039


100%|██████████| 12/12 [00:54<00:00,  4.57s/it]


Epoch 497/500, Loss: 1.1124, Accuracy: 0.6234
Validation Loss: 2.0909, Validation Accuracy: 0.3773


100%|██████████| 12/12 [00:55<00:00,  4.59s/it]


Epoch 498/500, Loss: 1.1134, Accuracy: 0.6174
Validation Loss: 2.1041, Validation Accuracy: 0.3787


100%|██████████| 12/12 [00:55<00:00,  4.60s/it]


Epoch 499/500, Loss: 1.1517, Accuracy: 0.6174
Validation Loss: 1.9261, Validation Accuracy: 0.4053


100%|██████████| 12/12 [00:54<00:00,  4.58s/it]


Epoch 500/500, Loss: 1.1428, Accuracy: 0.6135
Validation Loss: 2.0532, Validation Accuracy: 0.3590




In [3]:
print(best_val_acc)

0.4165497896213184


In [2]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTConfig, AdamW
from tqdm import tqdm
import glob
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import DataParallel

def extract_participant_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    
    for part in parts:
        if 'participant' in part:
            participant_number = part.replace('participant', '')
            return int(participant_number)  # Convert to integer
    return None  # If no participant label found

# def extract_label_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'gesture' in part:
#             return int(part.replace('gesture', ''))
#     return None
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels, augment=False):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        
        # Load and process the EMG data
        emg_image = load_and_process_file(file_path, self.target_length)
        
        # Apply augmentation if enabled
        if self.augment:
            emg_image = augment_data(emg_image, self.target_length)
        
        # Check if the EMG image has the correct shape
        if emg_image.shape != (3, 224, 224):
            raise ValueError(f"EMG data has shape {emg_image.shape} but expected (3, 224, 224)")

        return emg_image, label



# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (3, 224, 224) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10part'  # Update with your folder path
num_epochs = 300
# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42, stratify=labels)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels, augment=True)  # Apply augmentation during training
test_dataset = EMGDataset(test_files, target_length, test_labels, augment=False)    # No augmentation during testing

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Model configuration for ViT
config = ViTConfig(
    hidden_size=128,  # Reduced hidden size for smaller dataset
    num_attention_heads=4,  # Reduced attention heads
    num_hidden_layers=4,  # Reduced number of transformer layers
    image_size=224,
    patch_size=8,
    num_labels=len(unique_labels),
    hidden_dropout_prob=0.3,  # Increased dropout for regularization
    attention_probs_dropout_prob=0.3
)

# Instantiate the ViT model
model = ViTForImageClassification(config)
model = DataParallel(model)
# Use AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

class WarmupCosineAnnealingScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up
            return [(self.last_epoch + 1) / self.warmup_epochs * base_lr for base_lr in self.base_lrs]
        else:
            # Cosine annealing
            cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(self.last_epoch - self.warmup_epochs) * torch.pi / (self.max_epochs - self.warmup_epochs)))
            return [self.eta_min + (base_lr - self.eta_min) * cosine_decay for base_lr in self.base_lrs]

warmup_epochs = 5
scheduler = WarmupCosineAnnealingScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=num_epochs)
# Loss function            
# Cosine Annealing Learning Rate Scheduler
#scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Training loop

clip_value = 1.0  # For gradient clipping
best_val_acc = 0.0  # To track the best validation accuracy
best_model_path = 'best_model.pth'
last_model_path = 'last_model.pth'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    torch.save(model.state_dict(), last_model_path)

    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'New best model saved with accuracy: {best_val_acc:.4f}')


  from .autonotebook import tqdm as notebook_tqdm


cuda


100%|██████████| 12/12 [01:50<00:00,  9.17s/it]


Epoch 1/300, Loss: 2.3162, Accuracy: 0.0958
Validation Loss: 2.3198, Validation Accuracy: 0.0996
New best model saved with accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 2/300, Loss: 2.3221, Accuracy: 0.1042
Validation Loss: 2.3134, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 3/300, Loss: 2.3138, Accuracy: 0.0976
Validation Loss: 2.3055, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 4/300, Loss: 2.3057, Accuracy: 0.0962
Validation Loss: 2.3036, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 5/300, Loss: 2.3089, Accuracy: 0.0972
Validation Loss: 2.3026, Validation Accuracy: 0.1010
New best model saved with accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 6/300, Loss: 2.3044, Accuracy: 0.1078
Validation Loss: 2.3027, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 7/300, Loss: 2.3054, Accuracy: 0.1000
Validation Loss: 2.3028, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 8/300, Loss: 2.3075, Accuracy: 0.0997
Validation Loss: 2.3027, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 9/300, Loss: 2.3072, Accuracy: 0.1000
Validation Loss: 2.3022, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 10/300, Loss: 2.3043, Accuracy: 0.0997
Validation Loss: 2.3024, Validation Accuracy: 0.1038
New best model saved with accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 11/300, Loss: 2.3061, Accuracy: 0.0983
Validation Loss: 2.3019, Validation Accuracy: 0.1052
New best model saved with accuracy: 0.1052


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 12/300, Loss: 2.3067, Accuracy: 0.0846
Validation Loss: 2.3024, Validation Accuracy: 0.1150
New best model saved with accuracy: 0.1150


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 13/300, Loss: 2.3044, Accuracy: 0.1085
Validation Loss: 2.3020, Validation Accuracy: 0.1276
New best model saved with accuracy: 0.1276


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 14/300, Loss: 2.3049, Accuracy: 0.1067
Validation Loss: 2.3019, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 15/300, Loss: 2.3041, Accuracy: 0.1088
Validation Loss: 2.3016, Validation Accuracy: 0.1248


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 16/300, Loss: 2.3041, Accuracy: 0.1081
Validation Loss: 2.3014, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.28s/it]


Epoch 17/300, Loss: 2.3042, Accuracy: 0.0972
Validation Loss: 2.3013, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 18/300, Loss: 2.3034, Accuracy: 0.1074
Validation Loss: 2.3014, Validation Accuracy: 0.1206


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 19/300, Loss: 2.3020, Accuracy: 0.1064
Validation Loss: 2.3014, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 20/300, Loss: 2.3065, Accuracy: 0.1014
Validation Loss: 2.3015, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 21/300, Loss: 2.3059, Accuracy: 0.0993
Validation Loss: 2.3013, Validation Accuracy: 0.1080


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 22/300, Loss: 2.3055, Accuracy: 0.1035
Validation Loss: 2.3012, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 23/300, Loss: 2.3062, Accuracy: 0.0913
Validation Loss: 2.3012, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 24/300, Loss: 2.3023, Accuracy: 0.1088
Validation Loss: 2.3012, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 25/300, Loss: 2.3060, Accuracy: 0.0969
Validation Loss: 2.3012, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 26/300, Loss: 2.3016, Accuracy: 0.1151
Validation Loss: 2.3012, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 27/300, Loss: 2.3035, Accuracy: 0.1025
Validation Loss: 2.3012, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 28/300, Loss: 2.3040, Accuracy: 0.1028
Validation Loss: 2.3012, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 29/300, Loss: 2.3047, Accuracy: 0.1085
Validation Loss: 2.3012, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 30/300, Loss: 2.3065, Accuracy: 0.0979
Validation Loss: 2.3013, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 31/300, Loss: 2.3042, Accuracy: 0.1067
Validation Loss: 2.3012, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.38s/it]


Epoch 32/300, Loss: 2.3029, Accuracy: 0.0990
Validation Loss: 2.3011, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 33/300, Loss: 2.3049, Accuracy: 0.1078
Validation Loss: 2.3010, Validation Accuracy: 0.1136


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 34/300, Loss: 2.3055, Accuracy: 0.0990
Validation Loss: 2.3012, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 35/300, Loss: 2.3019, Accuracy: 0.1092
Validation Loss: 2.3005, Validation Accuracy: 0.1332
New best model saved with accuracy: 0.1332


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 36/300, Loss: 2.3023, Accuracy: 0.1049
Validation Loss: 2.3014, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 37/300, Loss: 2.3012, Accuracy: 0.1053
Validation Loss: 2.3003, Validation Accuracy: 0.1276


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 38/300, Loss: 2.3048, Accuracy: 0.1035
Validation Loss: 2.3002, Validation Accuracy: 0.1122


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 39/300, Loss: 2.3046, Accuracy: 0.1025
Validation Loss: 2.2994, Validation Accuracy: 0.0940


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 40/300, Loss: 2.3057, Accuracy: 0.0941
Validation Loss: 2.3070, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 41/300, Loss: 2.3036, Accuracy: 0.1046
Validation Loss: 2.3037, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 42/300, Loss: 2.3032, Accuracy: 0.1085
Validation Loss: 2.2973, Validation Accuracy: 0.1150


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 43/300, Loss: 2.3015, Accuracy: 0.1049
Validation Loss: 2.2967, Validation Accuracy: 0.1206


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 44/300, Loss: 2.3017, Accuracy: 0.1099
Validation Loss: 2.2934, Validation Accuracy: 0.1388
New best model saved with accuracy: 0.1388


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 45/300, Loss: 2.3034, Accuracy: 0.1071
Validation Loss: 2.2996, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 46/300, Loss: 2.3030, Accuracy: 0.1123
Validation Loss: 2.3029, Validation Accuracy: 0.1080


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 47/300, Loss: 2.3033, Accuracy: 0.1155
Validation Loss: 2.3157, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 48/300, Loss: 2.3040, Accuracy: 0.1060
Validation Loss: 2.2907, Validation Accuracy: 0.1276


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 49/300, Loss: 2.3062, Accuracy: 0.1092
Validation Loss: 2.2975, Validation Accuracy: 0.1234


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 50/300, Loss: 2.3043, Accuracy: 0.1028
Validation Loss: 2.2945, Validation Accuracy: 0.1234


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 51/300, Loss: 2.3023, Accuracy: 0.1071
Validation Loss: 2.3132, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 52/300, Loss: 2.2996, Accuracy: 0.1106
Validation Loss: 2.3243, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 53/300, Loss: 2.3022, Accuracy: 0.1120
Validation Loss: 2.3188, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 54/300, Loss: 2.2950, Accuracy: 0.1250
Validation Loss: 2.3040, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 55/300, Loss: 2.2962, Accuracy: 0.1299
Validation Loss: 2.3099, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 56/300, Loss: 2.2970, Accuracy: 0.1165
Validation Loss: 2.3014, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 57/300, Loss: 2.3006, Accuracy: 0.1127
Validation Loss: 2.3249, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 58/300, Loss: 2.2948, Accuracy: 0.1229
Validation Loss: 2.3499, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 59/300, Loss: 2.3012, Accuracy: 0.1243
Validation Loss: 2.3581, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 60/300, Loss: 2.2980, Accuracy: 0.1116
Validation Loss: 2.3387, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 61/300, Loss: 2.3008, Accuracy: 0.1130
Validation Loss: 2.2886, Validation Accuracy: 0.1206


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 62/300, Loss: 2.3022, Accuracy: 0.1148
Validation Loss: 2.3363, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 63/300, Loss: 2.2955, Accuracy: 0.1155
Validation Loss: 2.3182, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 64/300, Loss: 2.2961, Accuracy: 0.1172
Validation Loss: 2.3427, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 65/300, Loss: 2.2940, Accuracy: 0.1281
Validation Loss: 2.3521, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 66/300, Loss: 2.2942, Accuracy: 0.1176
Validation Loss: 2.3452, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 67/300, Loss: 2.2868, Accuracy: 0.1292
Validation Loss: 2.3621, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 68/300, Loss: 2.2889, Accuracy: 0.1250
Validation Loss: 2.3564, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 69/300, Loss: 2.2943, Accuracy: 0.1186
Validation Loss: 2.3674, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 70/300, Loss: 2.2914, Accuracy: 0.1281
Validation Loss: 2.3680, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 71/300, Loss: 2.2860, Accuracy: 0.1302
Validation Loss: 2.3638, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 72/300, Loss: 2.2938, Accuracy: 0.1179
Validation Loss: 2.3710, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 73/300, Loss: 2.2901, Accuracy: 0.1257
Validation Loss: 2.3708, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 74/300, Loss: 2.2886, Accuracy: 0.1211
Validation Loss: 2.3708, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 75/300, Loss: 2.2839, Accuracy: 0.1257
Validation Loss: 2.3708, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 76/300, Loss: 2.2869, Accuracy: 0.1197
Validation Loss: 2.3712, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 77/300, Loss: 2.2909, Accuracy: 0.1193
Validation Loss: 2.3709, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 78/300, Loss: 2.2876, Accuracy: 0.1204
Validation Loss: 2.3702, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 79/300, Loss: 2.2905, Accuracy: 0.1299
Validation Loss: 2.3723, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 80/300, Loss: 2.2914, Accuracy: 0.1257
Validation Loss: 2.3680, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 81/300, Loss: 2.2878, Accuracy: 0.1246
Validation Loss: 2.3838, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 82/300, Loss: 2.2900, Accuracy: 0.1211
Validation Loss: 2.3678, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 83/300, Loss: 2.2910, Accuracy: 0.1271
Validation Loss: 2.3562, Validation Accuracy: 0.1094


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 84/300, Loss: 2.2911, Accuracy: 0.1257
Validation Loss: 2.3433, Validation Accuracy: 0.1122


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 85/300, Loss: 2.2855, Accuracy: 0.1267
Validation Loss: 2.4147, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 86/300, Loss: 2.2940, Accuracy: 0.1211
Validation Loss: 2.3964, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 87/300, Loss: 2.2878, Accuracy: 0.1299
Validation Loss: 2.3900, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 88/300, Loss: 2.2994, Accuracy: 0.1211
Validation Loss: 2.3729, Validation Accuracy: 0.1052


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 89/300, Loss: 2.2863, Accuracy: 0.1337
Validation Loss: 2.3770, Validation Accuracy: 0.1010


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 90/300, Loss: 2.2849, Accuracy: 0.1344
Validation Loss: 2.2628, Validation Accuracy: 0.1388


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 91/300, Loss: 2.3017, Accuracy: 0.1169
Validation Loss: 2.4245, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 92/300, Loss: 2.2894, Accuracy: 0.1106
Validation Loss: 2.2618, Validation Accuracy: 0.1459
New best model saved with accuracy: 0.1459


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 93/300, Loss: 2.2965, Accuracy: 0.1257
Validation Loss: 2.3816, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 94/300, Loss: 2.2774, Accuracy: 0.1390
Validation Loss: 2.4430, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 95/300, Loss: 2.2785, Accuracy: 0.1379
Validation Loss: 2.4822, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 96/300, Loss: 2.3092, Accuracy: 0.1211
Validation Loss: 2.4183, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 97/300, Loss: 2.2800, Accuracy: 0.1323
Validation Loss: 2.3260, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 98/300, Loss: 2.2669, Accuracy: 0.1485
Validation Loss: 2.4189, Validation Accuracy: 0.1024


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 99/300, Loss: 2.2880, Accuracy: 0.1292
Validation Loss: 2.4172, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 100/300, Loss: 2.2703, Accuracy: 0.1351
Validation Loss: 2.3587, Validation Accuracy: 0.0996


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 101/300, Loss: 2.2714, Accuracy: 0.1281
Validation Loss: 2.3782, Validation Accuracy: 0.1038


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 102/300, Loss: 2.2839, Accuracy: 0.1313
Validation Loss: 2.2461, Validation Accuracy: 0.1529
New best model saved with accuracy: 0.1529


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 103/300, Loss: 2.2560, Accuracy: 0.1422
Validation Loss: 2.2717, Validation Accuracy: 0.1388


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 104/300, Loss: 2.2616, Accuracy: 0.1471
Validation Loss: 2.2768, Validation Accuracy: 0.1543
New best model saved with accuracy: 0.1543


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 105/300, Loss: 2.2492, Accuracy: 0.1481
Validation Loss: 2.2788, Validation Accuracy: 0.1459


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 106/300, Loss: 2.2535, Accuracy: 0.1520
Validation Loss: 2.2993, Validation Accuracy: 0.1318


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 107/300, Loss: 2.2300, Accuracy: 0.1604
Validation Loss: 2.2393, Validation Accuracy: 0.1613
New best model saved with accuracy: 0.1613


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 108/300, Loss: 2.2615, Accuracy: 0.1443
Validation Loss: 2.2825, Validation Accuracy: 0.1515


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 109/300, Loss: 2.2236, Accuracy: 0.1583
Validation Loss: 2.2241, Validation Accuracy: 0.1725
New best model saved with accuracy: 0.1725


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 110/300, Loss: 2.2515, Accuracy: 0.1601
Validation Loss: 2.2230, Validation Accuracy: 0.1711


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 111/300, Loss: 2.2308, Accuracy: 0.1706
Validation Loss: 2.3368, Validation Accuracy: 0.1164


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 112/300, Loss: 2.2217, Accuracy: 0.1583
Validation Loss: 2.3326, Validation Accuracy: 0.1318


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 113/300, Loss: 2.2083, Accuracy: 0.1832
Validation Loss: 2.2197, Validation Accuracy: 0.1795
New best model saved with accuracy: 0.1795


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 114/300, Loss: 2.2105, Accuracy: 0.1748
Validation Loss: 2.3139, Validation Accuracy: 0.1571


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 115/300, Loss: 2.2054, Accuracy: 0.1776
Validation Loss: 2.2731, Validation Accuracy: 0.1571


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 116/300, Loss: 2.2079, Accuracy: 0.1622
Validation Loss: 2.3682, Validation Accuracy: 0.1276


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 117/300, Loss: 2.2350, Accuracy: 0.1611
Validation Loss: 2.2888, Validation Accuracy: 0.1459


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 118/300, Loss: 2.2083, Accuracy: 0.1878
Validation Loss: 2.3189, Validation Accuracy: 0.1403


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 119/300, Loss: 2.2168, Accuracy: 0.1766
Validation Loss: 2.2826, Validation Accuracy: 0.1585


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 120/300, Loss: 2.1919, Accuracy: 0.1864
Validation Loss: 2.2549, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 121/300, Loss: 2.1811, Accuracy: 0.1902
Validation Loss: 2.2423, Validation Accuracy: 0.1753


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 122/300, Loss: 2.1842, Accuracy: 0.1860
Validation Loss: 2.2379, Validation Accuracy: 0.1711


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 123/300, Loss: 2.1864, Accuracy: 0.1966
Validation Loss: 2.2357, Validation Accuracy: 0.1683


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 124/300, Loss: 2.1739, Accuracy: 0.2025
Validation Loss: 2.2352, Validation Accuracy: 0.1683


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 125/300, Loss: 2.1772, Accuracy: 0.1952
Validation Loss: 2.2350, Validation Accuracy: 0.1683


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 126/300, Loss: 2.1813, Accuracy: 0.1931
Validation Loss: 2.2276, Validation Accuracy: 0.1795


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 127/300, Loss: 2.1818, Accuracy: 0.2022
Validation Loss: 2.2390, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 128/300, Loss: 2.1837, Accuracy: 0.1994
Validation Loss: 2.2077, Validation Accuracy: 0.1978
New best model saved with accuracy: 0.1978


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 129/300, Loss: 2.1829, Accuracy: 0.1941
Validation Loss: 2.2284, Validation Accuracy: 0.1950


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 130/300, Loss: 2.1820, Accuracy: 0.1931
Validation Loss: 2.2718, Validation Accuracy: 0.1599


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 131/300, Loss: 2.1745, Accuracy: 0.1878
Validation Loss: 2.2210, Validation Accuracy: 0.1585


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 132/300, Loss: 2.1785, Accuracy: 0.1832
Validation Loss: 2.2567, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 133/300, Loss: 2.1911, Accuracy: 0.1843
Validation Loss: 2.1901, Validation Accuracy: 0.2174
New best model saved with accuracy: 0.2174


100%|██████████| 12/12 [00:51<00:00,  4.28s/it]


Epoch 134/300, Loss: 2.1932, Accuracy: 0.1874
Validation Loss: 2.2924, Validation Accuracy: 0.1473


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 135/300, Loss: 2.2213, Accuracy: 0.1822
Validation Loss: 2.1755, Validation Accuracy: 0.2202
New best model saved with accuracy: 0.2202


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 136/300, Loss: 2.2094, Accuracy: 0.1857
Validation Loss: 2.2106, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 137/300, Loss: 2.1890, Accuracy: 0.1864
Validation Loss: 2.3220, Validation Accuracy: 0.1445


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 138/300, Loss: 2.1974, Accuracy: 0.1878
Validation Loss: 2.1957, Validation Accuracy: 0.2202


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 139/300, Loss: 2.1741, Accuracy: 0.2032
Validation Loss: 2.2538, Validation Accuracy: 0.1753


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 140/300, Loss: 2.2212, Accuracy: 0.1776
Validation Loss: 2.3504, Validation Accuracy: 0.1206


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 141/300, Loss: 2.2101, Accuracy: 0.1759
Validation Loss: 2.2415, Validation Accuracy: 0.1571


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 142/300, Loss: 2.2199, Accuracy: 0.1797
Validation Loss: 2.2344, Validation Accuracy: 0.1669


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 143/300, Loss: 2.2472, Accuracy: 0.1622
Validation Loss: 2.2248, Validation Accuracy: 0.1767


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 144/300, Loss: 2.2290, Accuracy: 0.1622
Validation Loss: 2.2574, Validation Accuracy: 0.1585


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 145/300, Loss: 2.1890, Accuracy: 0.1822
Validation Loss: 2.1784, Validation Accuracy: 0.2132


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 146/300, Loss: 2.1729, Accuracy: 0.2001
Validation Loss: 2.2475, Validation Accuracy: 0.1683


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 147/300, Loss: 2.1824, Accuracy: 0.1955
Validation Loss: 2.1842, Validation Accuracy: 0.2146


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 148/300, Loss: 2.2096, Accuracy: 0.1727
Validation Loss: 2.2998, Validation Accuracy: 0.1403


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 149/300, Loss: 2.1819, Accuracy: 0.1744
Validation Loss: 2.2398, Validation Accuracy: 0.1599


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 150/300, Loss: 2.1637, Accuracy: 0.2067
Validation Loss: 2.2596, Validation Accuracy: 0.1655


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 151/300, Loss: 2.1948, Accuracy: 0.1902
Validation Loss: 2.2298, Validation Accuracy: 0.1697


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 152/300, Loss: 2.2409, Accuracy: 0.1664
Validation Loss: 2.3106, Validation Accuracy: 0.1585


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 153/300, Loss: 2.1955, Accuracy: 0.1769
Validation Loss: 2.3062, Validation Accuracy: 0.1360


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 154/300, Loss: 2.2383, Accuracy: 0.1773
Validation Loss: 2.1711, Validation Accuracy: 0.2146


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 155/300, Loss: 2.1502, Accuracy: 0.2236
Validation Loss: 2.1721, Validation Accuracy: 0.2076


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 156/300, Loss: 2.1391, Accuracy: 0.2134
Validation Loss: 2.1437, Validation Accuracy: 0.2202


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 157/300, Loss: 2.1448, Accuracy: 0.2162
Validation Loss: 2.2175, Validation Accuracy: 0.1907


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 158/300, Loss: 2.1575, Accuracy: 0.2102
Validation Loss: 2.1581, Validation Accuracy: 0.2132


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 159/300, Loss: 2.1279, Accuracy: 0.2120
Validation Loss: 2.1462, Validation Accuracy: 0.2258
New best model saved with accuracy: 0.2258


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 160/300, Loss: 2.1370, Accuracy: 0.2152
Validation Loss: 2.1266, Validation Accuracy: 0.2440
New best model saved with accuracy: 0.2440


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 161/300, Loss: 2.1211, Accuracy: 0.2155
Validation Loss: 2.1361, Validation Accuracy: 0.2230


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 162/300, Loss: 2.0987, Accuracy: 0.2492
Validation Loss: 2.1085, Validation Accuracy: 0.2805
New best model saved with accuracy: 0.2805


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 163/300, Loss: 2.0868, Accuracy: 0.2513
Validation Loss: 2.1180, Validation Accuracy: 0.2342


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 164/300, Loss: 2.0704, Accuracy: 0.2562
Validation Loss: 2.1312, Validation Accuracy: 0.2496


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 165/300, Loss: 2.0828, Accuracy: 0.2569
Validation Loss: 2.1242, Validation Accuracy: 0.2567


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 166/300, Loss: 2.0552, Accuracy: 0.2748
Validation Loss: 2.0837, Validation Accuracy: 0.2721


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 167/300, Loss: 2.0414, Accuracy: 0.2790
Validation Loss: 2.0847, Validation Accuracy: 0.2847
New best model saved with accuracy: 0.2847


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 168/300, Loss: 2.0261, Accuracy: 0.2871
Validation Loss: 2.0772, Validation Accuracy: 0.2861
New best model saved with accuracy: 0.2861


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 169/300, Loss: 2.0225, Accuracy: 0.2783
Validation Loss: 2.0719, Validation Accuracy: 0.2833


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 170/300, Loss: 2.0063, Accuracy: 0.2840
Validation Loss: 2.0622, Validation Accuracy: 0.2889
New best model saved with accuracy: 0.2889


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 171/300, Loss: 2.0209, Accuracy: 0.2780
Validation Loss: 2.0641, Validation Accuracy: 0.2847


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 172/300, Loss: 2.0094, Accuracy: 0.2871
Validation Loss: 2.0625, Validation Accuracy: 0.2861


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 173/300, Loss: 2.0149, Accuracy: 0.2938
Validation Loss: 2.0625, Validation Accuracy: 0.2875


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 174/300, Loss: 2.0227, Accuracy: 0.2920
Validation Loss: 2.0615, Validation Accuracy: 0.2861


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 175/300, Loss: 2.0200, Accuracy: 0.2861
Validation Loss: 2.0659, Validation Accuracy: 0.2875


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 176/300, Loss: 2.0189, Accuracy: 0.2833
Validation Loss: 2.0638, Validation Accuracy: 0.2847


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 177/300, Loss: 2.0245, Accuracy: 0.2976
Validation Loss: 2.0898, Validation Accuracy: 0.2693


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 178/300, Loss: 2.0095, Accuracy: 0.2819
Validation Loss: 2.1055, Validation Accuracy: 0.2581


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 179/300, Loss: 2.0129, Accuracy: 0.2920
Validation Loss: 2.0561, Validation Accuracy: 0.3043
New best model saved with accuracy: 0.3043


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 180/300, Loss: 2.0359, Accuracy: 0.2868
Validation Loss: 2.0596, Validation Accuracy: 0.2875


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 181/300, Loss: 2.0070, Accuracy: 0.2882
Validation Loss: 2.1103, Validation Accuracy: 0.2511


100%|██████████| 12/12 [00:52<00:00,  4.37s/it]


Epoch 182/300, Loss: 2.0288, Accuracy: 0.2731
Validation Loss: 2.1058, Validation Accuracy: 0.2693


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 183/300, Loss: 2.0291, Accuracy: 0.2776
Validation Loss: 2.0409, Validation Accuracy: 0.3086
New best model saved with accuracy: 0.3086


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 184/300, Loss: 2.0413, Accuracy: 0.2794
Validation Loss: 2.0828, Validation Accuracy: 0.2609


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 185/300, Loss: 2.0318, Accuracy: 0.2713
Validation Loss: 2.0631, Validation Accuracy: 0.2791


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 186/300, Loss: 2.0322, Accuracy: 0.2829
Validation Loss: 2.0634, Validation Accuracy: 0.2721


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 187/300, Loss: 2.0195, Accuracy: 0.2664
Validation Loss: 2.0619, Validation Accuracy: 0.2777


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 188/300, Loss: 2.0136, Accuracy: 0.2920
Validation Loss: 2.0498, Validation Accuracy: 0.2805


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 189/300, Loss: 1.9968, Accuracy: 0.2850
Validation Loss: 2.0332, Validation Accuracy: 0.2889


100%|██████████| 12/12 [00:52<00:00,  4.33s/it]


Epoch 190/300, Loss: 2.0270, Accuracy: 0.2678
Validation Loss: 2.0610, Validation Accuracy: 0.2791


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 191/300, Loss: 1.9895, Accuracy: 0.2991
Validation Loss: 2.0605, Validation Accuracy: 0.2679


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 192/300, Loss: 2.0055, Accuracy: 0.2794
Validation Loss: 2.0311, Validation Accuracy: 0.2819


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 193/300, Loss: 2.0066, Accuracy: 0.2878
Validation Loss: 2.1229, Validation Accuracy: 0.2539


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 194/300, Loss: 2.0123, Accuracy: 0.2910
Validation Loss: 2.0403, Validation Accuracy: 0.2889


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 195/300, Loss: 2.0576, Accuracy: 0.2745
Validation Loss: 2.0629, Validation Accuracy: 0.2721


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 196/300, Loss: 1.9966, Accuracy: 0.2959
Validation Loss: 2.0455, Validation Accuracy: 0.2609


100%|██████████| 12/12 [00:51<00:00,  4.29s/it]


Epoch 197/300, Loss: 1.9844, Accuracy: 0.3026
Validation Loss: 2.0518, Validation Accuracy: 0.2875


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 198/300, Loss: 2.0352, Accuracy: 0.2625
Validation Loss: 2.0763, Validation Accuracy: 0.2637


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 199/300, Loss: 1.9785, Accuracy: 0.3012
Validation Loss: 2.0928, Validation Accuracy: 0.2440


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 200/300, Loss: 2.0220, Accuracy: 0.2938
Validation Loss: 1.9786, Validation Accuracy: 0.2945


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 201/300, Loss: 1.9854, Accuracy: 0.3033
Validation Loss: 2.0321, Validation Accuracy: 0.2931


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 202/300, Loss: 1.9687, Accuracy: 0.3163
Validation Loss: 2.0383, Validation Accuracy: 0.2777


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 203/300, Loss: 1.9648, Accuracy: 0.3026
Validation Loss: 1.9922, Validation Accuracy: 0.2721


100%|██████████| 12/12 [00:51<00:00,  4.28s/it]


Epoch 204/300, Loss: 1.9739, Accuracy: 0.3008
Validation Loss: 2.0255, Validation Accuracy: 0.2749


100%|██████████| 12/12 [00:51<00:00,  4.27s/it]


Epoch 205/300, Loss: 1.9928, Accuracy: 0.2913
Validation Loss: 2.0271, Validation Accuracy: 0.2889


100%|██████████| 12/12 [00:51<00:00,  4.30s/it]


Epoch 206/300, Loss: 1.9746, Accuracy: 0.2987
Validation Loss: 2.1161, Validation Accuracy: 0.2398


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 207/300, Loss: 1.9906, Accuracy: 0.2934
Validation Loss: 2.0078, Validation Accuracy: 0.3072


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 208/300, Loss: 1.9886, Accuracy: 0.2998
Validation Loss: 2.0174, Validation Accuracy: 0.3015


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 209/300, Loss: 1.9234, Accuracy: 0.3335
Validation Loss: 1.9704, Validation Accuracy: 0.2973


100%|██████████| 12/12 [00:52<00:00,  4.38s/it]


Epoch 210/300, Loss: 1.9100, Accuracy: 0.3380
Validation Loss: 2.0275, Validation Accuracy: 0.2987


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 211/300, Loss: 1.9242, Accuracy: 0.3384
Validation Loss: 1.9454, Validation Accuracy: 0.3072


100%|██████████| 12/12 [00:56<00:00,  4.68s/it]


Epoch 212/300, Loss: 1.9049, Accuracy: 0.3299
Validation Loss: 1.9366, Validation Accuracy: 0.3338
New best model saved with accuracy: 0.3338


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 213/300, Loss: 1.9277, Accuracy: 0.3380
Validation Loss: 1.9426, Validation Accuracy: 0.3296


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 214/300, Loss: 1.8782, Accuracy: 0.3542
Validation Loss: 1.9361, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 215/300, Loss: 1.8673, Accuracy: 0.3573
Validation Loss: 1.9315, Validation Accuracy: 0.3254


100%|██████████| 12/12 [00:52<00:00,  4.38s/it]


Epoch 216/300, Loss: 1.8614, Accuracy: 0.3556
Validation Loss: 1.9458, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 217/300, Loss: 1.8599, Accuracy: 0.3643
Validation Loss: 1.9410, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Epoch 218/300, Loss: 1.8635, Accuracy: 0.3636
Validation Loss: 1.9297, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 219/300, Loss: 1.8418, Accuracy: 0.3763
Validation Loss: 1.9319, Validation Accuracy: 0.3282


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 220/300, Loss: 1.8374, Accuracy: 0.3657
Validation Loss: 1.9310, Validation Accuracy: 0.3394
New best model saved with accuracy: 0.3394


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 221/300, Loss: 1.8404, Accuracy: 0.3693
Validation Loss: 1.9320, Validation Accuracy: 0.3352


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 222/300, Loss: 1.8393, Accuracy: 0.3735
Validation Loss: 1.9321, Validation Accuracy: 0.3338


100%|██████████| 12/12 [00:51<00:00,  4.31s/it]


Epoch 223/300, Loss: 1.8239, Accuracy: 0.3819
Validation Loss: 1.9311, Validation Accuracy: 0.3366


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 224/300, Loss: 1.8414, Accuracy: 0.3707
Validation Loss: 1.9311, Validation Accuracy: 0.3352


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 225/300, Loss: 1.8501, Accuracy: 0.3787
Validation Loss: 1.9287, Validation Accuracy: 0.3338


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 226/300, Loss: 1.8468, Accuracy: 0.3756
Validation Loss: 1.9321, Validation Accuracy: 0.3268


100%|██████████| 12/12 [00:52<00:00,  4.37s/it]


Epoch 227/300, Loss: 1.8191, Accuracy: 0.3826
Validation Loss: 1.9410, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 228/300, Loss: 1.8334, Accuracy: 0.3749
Validation Loss: 1.9334, Validation Accuracy: 0.3282


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 229/300, Loss: 1.8485, Accuracy: 0.3640
Validation Loss: 1.9278, Validation Accuracy: 0.3212


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 230/300, Loss: 1.8432, Accuracy: 0.3703
Validation Loss: 1.9606, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 231/300, Loss: 1.8574, Accuracy: 0.3622
Validation Loss: 1.9160, Validation Accuracy: 0.3464
New best model saved with accuracy: 0.3464


100%|██████████| 12/12 [00:52<00:00,  4.37s/it]


Epoch 232/300, Loss: 1.8522, Accuracy: 0.3601
Validation Loss: 1.9776, Validation Accuracy: 0.3310


100%|██████████| 12/12 [00:52<00:00,  4.38s/it]


Epoch 233/300, Loss: 1.8717, Accuracy: 0.3524
Validation Loss: 2.0013, Validation Accuracy: 0.2917


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 234/300, Loss: 1.8660, Accuracy: 0.3566
Validation Loss: 1.9905, Validation Accuracy: 0.3043


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 235/300, Loss: 1.9015, Accuracy: 0.3426
Validation Loss: 1.9435, Validation Accuracy: 0.3198


100%|██████████| 12/12 [00:52<00:00,  4.35s/it]


Epoch 236/300, Loss: 1.8661, Accuracy: 0.3499
Validation Loss: 1.9241, Validation Accuracy: 0.3338


100%|██████████| 12/12 [00:52<00:00,  4.34s/it]


Epoch 237/300, Loss: 1.8442, Accuracy: 0.3700
Validation Loss: 2.0501, Validation Accuracy: 0.2693


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 238/300, Loss: 1.8906, Accuracy: 0.3345
Validation Loss: 2.0625, Validation Accuracy: 0.2889


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 239/300, Loss: 1.9171, Accuracy: 0.3310
Validation Loss: 1.9170, Validation Accuracy: 0.3352


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 240/300, Loss: 1.8835, Accuracy: 0.3492
Validation Loss: 1.9739, Validation Accuracy: 0.3029


100%|██████████| 12/12 [00:52<00:00,  4.37s/it]


Epoch 241/300, Loss: 1.8742, Accuracy: 0.3570
Validation Loss: 1.9035, Validation Accuracy: 0.3422


100%|██████████| 12/12 [00:51<00:00,  4.33s/it]


Epoch 242/300, Loss: 1.8779, Accuracy: 0.3675
Validation Loss: 1.9533, Validation Accuracy: 0.3128


100%|██████████| 12/12 [00:52<00:00,  4.36s/it]


Epoch 243/300, Loss: 1.8496, Accuracy: 0.3570
Validation Loss: 2.0039, Validation Accuracy: 0.2973


 58%|█████▊    | 7/12 [00:34<00:24,  4.90s/it]


KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob

# Function to extract the participant ID from the filename
# def extract_participant_from_filename(filepath):
#     filename = os.path.basename(filepath)
#     parts = filename.split('_')
#     for part in parts:
#         if 'participant' in part:
#             participant_id = part.replace('participant', '')
#             return int(participant_id)  # Convert to integer
#     return None  # If no participant ID found

# # Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data

  from .autonotebook import tqdm as notebook_tqdm


In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.3, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTForImageClassification, ViTConfig
from tqdm import tqdm

# Define the number of classes (in this case, the number of participants)
num_classes = len(np.unique(labels))

config = ViTConfig(
    num_hidden_layers=6,  # Reduced number of layers
    hidden_size=512,      # Smaller hidden size
    num_attention_heads=4,  # Fewer attention heads
    patch_size=16,        # Smaller patch size
    intermediate_size=1024,  # Reduced intermediate layer size
    image_size=224,       # Input image size
    num_labels=num_classes,         # Number of output classes
    hidden_dropout_prob=0.3,  # Adding dropout to regularize
    attention_probs_dropout_prob=0.3
)

# Initialize ViT model from scratch with this config
model = ViTForImageClassification(config)

# Define optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 30  # Set the number of epochs

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader):
        # Move data to the same device as the model
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images).logits
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Evaluation after every epoch (optional)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy after epoch {epoch+1}: {accuracy:.2f}%")

# Save the trained model
torch.save(model.state_dict(), 'vit_emg_model.pth')


100%|██████████| 78/78 [01:36<00:00,  1.23s/it]


Epoch [1/100], Loss: 2.9389
Test Accuracy after epoch 1: 5.43%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [2/100], Loss: 2.8926
Test Accuracy after epoch 2: 6.64%


100%|██████████| 78/78 [00:40<00:00,  1.90it/s]


Epoch [3/100], Loss: 2.8774
Test Accuracy after epoch 3: 4.58%


100%|██████████| 78/78 [00:39<00:00,  1.95it/s]


Epoch [4/100], Loss: 2.8714
Test Accuracy after epoch 4: 6.36%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [5/100], Loss: 2.8526
Test Accuracy after epoch 5: 7.30%


100%|██████████| 78/78 [00:40<00:00,  1.91it/s]


Epoch [6/100], Loss: 2.8157
Test Accuracy after epoch 6: 6.17%


100%|██████████| 78/78 [00:40<00:00,  1.90it/s]


Epoch [7/100], Loss: 2.8185
Test Accuracy after epoch 7: 6.83%


100%|██████████| 78/78 [00:40<00:00,  1.93it/s]


Epoch [8/100], Loss: 2.7884
Test Accuracy after epoch 8: 5.61%


100%|██████████| 78/78 [00:40<00:00,  1.92it/s]


Epoch [9/100], Loss: 2.7815
Test Accuracy after epoch 9: 5.43%


 87%|████████▋ | 68/78 [00:35<00:05,  1.91it/s]


KeyboardInterrupt: 

In [None]:
print(f"Model output classes: {model.config.num_labels}")
print(f"Unique labels in dataset: {np.unique(train_labels)}")

Model output classes: 10
Unique labels in dataset: [0 1 2 3 4 5 6 7 8 9]


In [None]:
# --- Training loop for 50 epochs ---
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 1, Loss: 2.2416, Accuracy: 0.1794
Test accuracy: 0.2553


Epoch 2/50: 100%|██████████| 90/90 [01:12<00:00,  1.25it/s]


Epoch 2, Loss: 1.8933, Accuracy: 0.3556
Test accuracy: 0.3955


Epoch 3/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 3, Loss: 1.6035, Accuracy: 0.4693
Test accuracy: 0.4320


Epoch 4/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 4, Loss: 1.3069, Accuracy: 0.5781
Test accuracy: 0.4600


Epoch 5/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 5, Loss: 1.0768, Accuracy: 0.6511
Test accuracy: 0.5470


Epoch 6/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 6, Loss: 0.6955, Accuracy: 0.7943
Test accuracy: 0.5330


Epoch 7/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 7, Loss: 0.4704, Accuracy: 0.8635
Test accuracy: 0.5259


Epoch 8/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 8, Loss: 0.3046, Accuracy: 0.9210
Test accuracy: 0.5498


Epoch 9/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 9, Loss: 0.2185, Accuracy: 0.9438
Test accuracy: 0.5512


Epoch 10/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 10, Loss: 0.1618, Accuracy: 0.9582
Test accuracy: 0.4867


Epoch 11/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 11, Loss: 0.1546, Accuracy: 0.9596
Test accuracy: 0.5456


Epoch 12/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 12, Loss: 0.1435, Accuracy: 0.9603
Test accuracy: 0.5652


Epoch 13/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 13, Loss: 0.0731, Accuracy: 0.9849
Test accuracy: 0.4923


Epoch 14/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 14, Loss: 0.1267, Accuracy: 0.9645
Test accuracy: 0.4979


Epoch 15/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 15, Loss: 0.0674, Accuracy: 0.9835
Test accuracy: 0.5302


Epoch 16/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 16, Loss: 0.0767, Accuracy: 0.9796
Test accuracy: 0.5666


Epoch 17/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 17, Loss: 0.0595, Accuracy: 0.9849
Test accuracy: 0.5891


Epoch 18/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 18, Loss: 0.0322, Accuracy: 0.9940
Test accuracy: 0.5428


Epoch 19/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 19, Loss: 0.0548, Accuracy: 0.9839
Test accuracy: 0.5680


Epoch 20/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 20, Loss: 0.0505, Accuracy: 0.9856
Test accuracy: 0.5666


Epoch 21/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 21, Loss: 0.0864, Accuracy: 0.9744
Test accuracy: 0.5722


Epoch 22/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 22, Loss: 0.1016, Accuracy: 0.9698
Test accuracy: 0.4769


Epoch 23/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 23, Loss: 0.0759, Accuracy: 0.9772
Test accuracy: 0.5330


Epoch 24/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 24, Loss: 0.0339, Accuracy: 0.9933
Test accuracy: 0.5498


Epoch 25/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 25, Loss: 0.1114, Accuracy: 0.9705
Test accuracy: 0.5652


Epoch 26/50: 100%|██████████| 90/90 [01:11<00:00,  1.27it/s]


Epoch 26, Loss: 0.0147, Accuracy: 0.9972
Test accuracy: 0.5863


Epoch 27/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 27, Loss: 0.0083, Accuracy: 0.9993
Test accuracy: 0.5806


Epoch 28/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 28, Loss: 0.0066, Accuracy: 0.9993
Test accuracy: 0.5849


Epoch 29/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 29, Loss: 0.0051, Accuracy: 0.9996
Test accuracy: 0.5849


Epoch 30/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 30, Loss: 0.0035, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 31/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 31, Loss: 0.0031, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 32/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 32, Loss: 0.0028, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 33/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 33, Loss: 0.0026, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 34/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 34, Loss: 0.0024, Accuracy: 1.0000
Test accuracy: 0.5877


Epoch 35/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 35, Loss: 0.0022, Accuracy: 1.0000
Test accuracy: 0.5891


Epoch 36/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 36, Loss: 0.0021, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 37/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 37, Loss: 0.0019, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 38/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 38, Loss: 0.0018, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 39/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 39, Loss: 0.0017, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 40/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 40, Loss: 0.0016, Accuracy: 1.0000
Test accuracy: 0.5905


Epoch 41/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 41, Loss: 0.0015, Accuracy: 1.0000
Test accuracy: 0.5919


Epoch 42/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 42, Loss: 0.0014, Accuracy: 1.0000
Test accuracy: 0.5933


Epoch 43/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 43, Loss: 0.0013, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 44/50: 100%|██████████| 90/90 [01:11<00:00,  1.25it/s]


Epoch 44, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 45/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 45, Loss: 0.0012, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 46/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 46, Loss: 0.0011, Accuracy: 1.0000
Test accuracy: 0.5947


Epoch 47/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 47, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 48/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 48, Loss: 0.0010, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 49/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 49, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5961


Epoch 50/50: 100%|██████████| 90/90 [01:11<00:00,  1.26it/s]


Epoch 50, Loss: 0.0009, Accuracy: 1.0000
Test accuracy: 0.5975


In [None]:


# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_participant_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 45/45 [00:36<00:00,  1.25it/s]


Epoch 1, Loss: 1.4741, Accuracy: 0.3455
Test accuracy: 0.4566


Epoch 2/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 2, Loss: 1.1747, Accuracy: 0.5407
Test accuracy: 0.5770


Epoch 3/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 3, Loss: 0.8848, Accuracy: 0.6713
Test accuracy: 0.6751


Epoch 4/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 4, Loss: 0.6088, Accuracy: 0.8062
Test accuracy: 0.7171


Epoch 5/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 5, Loss: 0.4332, Accuracy: 0.8581
Test accuracy: 0.6303


Epoch 6/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 6, Loss: 0.2952, Accuracy: 0.9094
Test accuracy: 0.7423


Epoch 7/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 7, Loss: 0.1660, Accuracy: 0.9565
Test accuracy: 0.7171


Epoch 8/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 8, Loss: 0.1073, Accuracy: 0.9719
Test accuracy: 0.7003


Epoch 9/20: 100%|██████████| 45/45 [00:35<00:00,  1.28it/s]


Epoch 9, Loss: 0.1327, Accuracy: 0.9621
Test accuracy: 0.7283


Epoch 10/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 10, Loss: 0.0993, Accuracy: 0.9747
Test accuracy: 0.7563


Epoch 11/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 11, Loss: 0.0724, Accuracy: 0.9831
Test accuracy: 0.6947


Epoch 12/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 12, Loss: 0.0463, Accuracy: 0.9902
Test accuracy: 0.7115


Epoch 13/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 13, Loss: 0.0474, Accuracy: 0.9881
Test accuracy: 0.7283


Epoch 14/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 14, Loss: 0.0936, Accuracy: 0.9747
Test accuracy: 0.7115


Epoch 15/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 15, Loss: 0.1236, Accuracy: 0.9621
Test accuracy: 0.6499


Epoch 16/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 16, Loss: 0.1496, Accuracy: 0.9537
Test accuracy: 0.6779


Epoch 17/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 17, Loss: 0.0414, Accuracy: 0.9909
Test accuracy: 0.7619


Epoch 18/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 18, Loss: 0.0188, Accuracy: 0.9965
Test accuracy: 0.7759


Epoch 19/20: 100%|██████████| 45/45 [00:35<00:00,  1.26it/s]


Epoch 19, Loss: 0.0172, Accuracy: 0.9986
Test accuracy: 0.7731


Epoch 20/20: 100%|██████████| 45/45 [00:35<00:00,  1.27it/s]


Epoch 20, Loss: 0.0108, Accuracy: 0.9993
Test accuracy: 0.7675


In [None]:
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]

    # Reshape to 224x224x3 required by ViT (ensure the size is consistent)
    padded_data = padded_data.flatten()  # Flatten before reshaping
    num_pixels = 3 * 224 * 224 
    
    # If we don't have enough data, we can pad with zeros; otherwise, truncate
    if padded_data.size < num_pixels:
        reshaped_data = np.pad(padded_data, (0, num_pixels - padded_data.size), 'constant', constant_values=0)
    else:
        reshaped_data = padded_data[:num_pixels]
    
    # Reshape to (224, 224, 3) as required by ViT
    reshaped_data = np.reshape(reshaped_data, (3, 224, 224))
    
    # Normalize data to range [0, 1]
    reshaped_data = (reshaped_data - reshaped_data.min()) / (reshaped_data.max() - reshaped_data.min())

    return reshaped_data    

In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_5gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=5)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 113/113 [03:05<00:00,  1.64s/it]


Epoch 1, Loss: 1.0156, Accuracy: 0.5933
Test accuracy: 0.6951


Epoch 2/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 2, Loss: 0.6677, Accuracy: 0.7506
Test accuracy: 0.7816


Epoch 3/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 3, Loss: 0.4857, Accuracy: 0.8227
Test accuracy: 0.7627


Epoch 4/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 4, Loss: 0.3735, Accuracy: 0.8713
Test accuracy: 0.8104


Epoch 5/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 5, Loss: 0.2388, Accuracy: 0.9223
Test accuracy: 0.7938


Epoch 6/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 6, Loss: 0.1578, Accuracy: 0.9503
Test accuracy: 0.7993


Epoch 7/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 7, Loss: 0.1318, Accuracy: 0.9603
Test accuracy: 0.7860


Epoch 8/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 8, Loss: 0.0938, Accuracy: 0.9734
Test accuracy: 0.8226


Epoch 9/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 9, Loss: 0.0645, Accuracy: 0.9803
Test accuracy: 0.7672


Epoch 10/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 10, Loss: 0.0746, Accuracy: 0.9761
Test accuracy: 0.7960


Epoch 11/20: 100%|██████████| 113/113 [01:29<00:00,  1.26it/s]


Epoch 11, Loss: 0.0330, Accuracy: 0.9917
Test accuracy: 0.8149


Epoch 12/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 12, Loss: 0.0445, Accuracy: 0.9889
Test accuracy: 0.8104


Epoch 13/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 13, Loss: 0.0661, Accuracy: 0.9792
Test accuracy: 0.8115


Epoch 14/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 14, Loss: 0.0471, Accuracy: 0.9850
Test accuracy: 0.8259


Epoch 15/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 15, Loss: 0.0220, Accuracy: 0.9945
Test accuracy: 0.8082


Epoch 16/20: 100%|██████████| 113/113 [01:28<00:00,  1.27it/s]


Epoch 16, Loss: 0.0149, Accuracy: 0.9969
Test accuracy: 0.8193


Epoch 17/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 17, Loss: 0.0469, Accuracy: 0.9859
Test accuracy: 0.7971


Epoch 18/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 18, Loss: 0.0401, Accuracy: 0.9870
Test accuracy: 0.7716


Epoch 19/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 19, Loss: 0.0242, Accuracy: 0.9931
Test accuracy: 0.8049


Epoch 20/20: 100%|██████████| 113/113 [01:29<00:00,  1.27it/s]


Epoch 20, Loss: 0.0091, Accuracy: 0.9978
Test accuracy: 0.8060


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
import glob
# Custom dataset for the EMG data
class EMGDataset(Dataset):
    def __init__(self, file_list, target_length, labels):
        self.file_list = file_list
        self.target_length = target_length
        self.labels = labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        emg_image = load_and_process_file(file_path, self.target_length)
        return emg_image, label

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_10gest'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Original unique labels
unique_labels = np.unique(labels)

# Create a mapping from the original labels to the new range 0-4
label_mapping = {original_label: new_label for new_label, original_label in enumerate(unique_labels)}

# Convert the original labels to the new range using the mapping
converted_labels = np.array([label_mapping[label] for label in labels])

labels = converted_labels
# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, labels, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = EMGDataset(train_files, target_length, train_labels)
test_dataset = EMGDataset(test_files, target_length, test_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Load a pre-trained ViT model from Hugging Face
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)  # Adjust num_labels as needed

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# --- Training loop for 20 epochs ---
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for emg_data, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # emg_data = torch.tensor(emg_data).float().to(device)  # Convert to PyTorch tensor and move to device
        # labels = torch.tensor(labels).long().to(device)  # Convert labels to tensor and move to device
        emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
        labels = labels.to(device)  # Convert labels to tensor and move to device
        # Prepare input for ViT by treating EMG data as image-like input
        inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)

        
        # Forward pass
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Compute accuracy
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/total:.4f}')

    # --- Testing loop ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for emg_data, labels in test_loader:
            emg_data = emg_data.to(device)  # Convert to PyTorch tensor and move to device
            labels = labels.to(device)  # Convert labels to tensor and move to device

            inputs = feature_extractor(emg_data, return_tensors="pt", do_rescale=False)['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            _, predicted = torch.max(outputs.logits, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        test_accuracy = correct / total
        print(f'Test accuracy: {test_accuracy:.4f}')


Epoch 1/20: 100%|██████████| 226/226 [06:29<00:00,  1.72s/it]


Epoch 1, Loss: 1.6487, Accuracy: 0.4242
Test accuracy: 0.5774


Epoch 2/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 2, Loss: 1.1021, Accuracy: 0.6264
Test accuracy: 0.6051


Epoch 3/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 3, Loss: 0.8969, Accuracy: 0.6896
Test accuracy: 0.6794


Epoch 4/20: 100%|██████████| 226/226 [02:59<00:00,  1.26it/s]


Epoch 4, Loss: 0.6698, Accuracy: 0.7712
Test accuracy: 0.6589


Epoch 5/20: 100%|██████████| 226/226 [02:58<00:00,  1.26it/s]


Epoch 5, Loss: 0.5011, Accuracy: 0.8303
Test accuracy: 0.6744


Epoch 6/20: 100%|██████████| 226/226 [02:58<00:00,  1.27it/s]


Epoch 6, Loss: 0.3302, Accuracy: 0.8935
Test accuracy: 0.6384


Epoch 7/20:  92%|█████████▏| 208/226 [02:45<00:14,  1.26it/s]


KeyboardInterrupt: 

In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import glob

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

In [None]:

# Path to the root folder where all .mat files are stored
root_folder = 's1_s2_s3_matfiles_20part'  # Update with your folder path

# Find all .mat files in the folder
all_mat_files = glob.glob(os.path.join(root_folder, '*.mat'))

# Collect all corresponding labels
all_labels = [extract_label_from_filename(mat_file) for mat_file in all_mat_files]

# Target length for EMG signals (adjust based on your data)
target_length = 10240  # Modify this as needed

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)
print(f"Shape of X_train: {X_train.shape}")

Shape of X_train: (5701, 327680)


In [None]:
pca = PCA(n_components = 2000)
# Step 2: Fit PCA on the dataset
pca.fit(X_train)
# Step 3: Calculate the cumulative variance explained by the components
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
# Step 4: Find the number of components that explain at least 95% variance
n_components_95 = np.argmax(cumulative_variance >= 0.95) + 1

print(f"Number of components to retain 95% of variance: {n_components_95}")

# Print explained variance for each component
explained_variance = pca.explained_variance_ratio_
print(f"Explained variance by each component: {explained_variance}")

# Calculate cumulative explained variance
cumulative_variance = np.cumsum(explained_variance)
print(f"Cumulative explained variance: {cumulative_variance}")

Number of components to retain 95% of variance: 1
Explained variance by each component: [0.02734777 0.02461403 0.00818816 0.0069905  0.00584958 0.00445974
 0.00403911 0.00369412 0.00349787 0.00316976 0.00300401 0.00289899
 0.00279481 0.00275903 0.00264407 0.00261821 0.00259447 0.00257349
 0.00248961 0.00239943 0.0023721  0.00232195 0.00227801 0.00223757
 0.0022159  0.00220089 0.00212229 0.00206722 0.00201354 0.00198255
 0.00194987 0.00192753 0.00191069 0.00189586 0.00187392 0.00184805
 0.00183598 0.00181051 0.00178506 0.00175982 0.0017539  0.00172137
 0.00170902 0.00169332 0.00167088 0.00165842 0.00165174 0.0016362
 0.00163023 0.00160764 0.00159447 0.00155964 0.00155086 0.00152713
 0.00146956 0.00145834 0.00141764 0.00141322 0.00140813 0.00138163
 0.00136974 0.00136623 0.00134497 0.00133781 0.00131482 0.00130781
 0.00129883 0.00128668 0.00127945 0.00126588 0.00125526 0.00123351
 0.00122743 0.00121431 0.00120742 0.00119956 0.00119493 0.00118896
 0.0011854  0.00117166 0.00116383 0.001154

In [None]:


# Initialize PCA for dimensionality reduction
#pca = PCA(n_components=4)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.3065


In [None]:
import os
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

# Function to extract gesture label from the filename
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    parts = filename.split('_')
    for part in parts:
        if 'gesture' in part:
            return int(part.replace('gesture', ''))
    return None

# Function to load and process a single .mat file
def load_and_process_file(filepath, target_length):
    mat_data = sio.loadmat(filepath)
    data_emg = mat_data['data_emg']
    
    # Handle NaNs and Infinities
    data_emg = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
    
    # Pad or truncate the data to the target length
    if data_emg.shape[0] < target_length:
        padded_data = np.pad(data_emg, ((0, target_length - data_emg.shape[0]), (0, 0)), 'constant', constant_values=0)
    else:
        padded_data = data_emg[:target_length]
    
    # Flatten the data for SVM input
    return padded_data.flatten()

# Process a list of files and return processed data and labels
def process_files(file_list, target_length):
    X = []
    y = []
    
    for filepath in file_list:
        X.append(load_and_process_file(filepath, target_length))
        y.append(extract_label_from_filename(filepath))
    
    return np.array(X), np.array(y)

# Specify session folders
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']

# Target length for EMG signals (adjust based on your data)
target_length = 10240

# Collect all .mat files and their corresponding labels
all_mat_files = []
all_labels = []

for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Process training and test data
X_train, y_train = process_files(train_files, target_length)
X_test, y_test = process_files(test_files, target_length)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=100)  # Adjust the number of components as needed

# Apply PCA to the training and test sets
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# --- Initialize the SVC model ---
svc_model = SVC()

# Fit the model on the training set
svc_model.fit(X_train_pca, y_train)

# Predict on the test set
y_pred = svc_model.predict(X_test_pca)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {accuracy:.4f}")


Test accuracy: 0.4065


In [None]:
import os
import numpy as np
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, GridSearchCV
from tqdm import tqdm

# Function to extract label from filename
def extract_label_from_filename(filename):
    basename = os.path.basename(filename)
    return int(basename.split('_')[2].replace('gesture', ''))

# Function to process a batch of files
def process_batch(batch_files, target_length):
    data_emg_list = []
    labels_list = []
    for mat_file in tqdm(batch_files, desc="Processing Batch"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(mat_file)
            data_emg = mat_data['data_emg']
            label = extract_label_from_filename(mat_file)

            # Handle NaNs and Infinities
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            
            # Pad or truncate the data to the target length
            if data.shape[0] < target_length:
                padded_data = np.pad(data, ((0, target_length - data.shape[0]), (0, 0)), 'constant', constant_values=0)
            else:
                padded_data = data[:target_length]
            
            # Append to lists
            data_emg_list.append(padded_data)
            labels_list.append(label)
    
    # Convert to numpy arrays
    X_batch = np.array(data_emg_list)
    y_batch = np.array(labels_list)
    
    # Reshape the data
    X_batch_flat = X_batch.reshape(X_batch.shape[0], -1)
    
    return X_batch_flat, y_batch

# Path to the data
target_length = 10240  # Desired length for each EMG signal
session_folders = ['Session1_matfiles', 'Session2_matfiles', 'Session3_matfiles']
all_mat_files = []
all_labels = []

# Collect all .mat files and their corresponding labels
for session_folder in session_folders:
    mat_files = sorted(os.listdir(session_folder))
    all_mat_files += [os.path.join(session_folder, mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]
    all_labels += [extract_label_from_filename(mat_file) for mat_file in mat_files if mat_file.endswith('.mat')]

# Split the dataset into train and test sets
train_files, test_files, train_labels, test_labels = train_test_split(all_mat_files, all_labels, test_size=0.2, random_state=42)

# Initialize PCA for dimensionality reduction
pca = PCA(n_components=4)  # Reduce to 100 components, adjust as needed

# Initialize an SVM with stochastic gradient descent (SGD)
svm_model = SGDClassifier(loss='hinge', penalty='l2', max_iter=1000)

# Process training data in batches
batch_size = 50

for i in range(0, len(train_files), batch_size):
    batch_files = train_files[i:i+batch_size]
    X_batch, y_batch = process_batch(batch_files, target_length)
    
    # Apply PCA
    X_batch_pca = pca.fit_transform(X_batch)
    
    # Incrementally fit the SVM model, passing the complete set of classes every time
    svm_model.partial_fit(X_batch_pca, y_batch, classes=np.unique(train_labels))

# Process and test the model on test data
X_test, y_test = process_batch(test_files, target_length)
X_test_pca = pca.transform(X_test)

# Inference on test data
y_pred = svm_model.predict(X_test_pca)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# --- Hyperparameter Tuning with GridSearchCV ---

# Define the parameter grid for tuning
param_grid = {
    'alpha': [0.0001, 0.001, 0.01],
    'loss': ['hinge', 'log'],
    'penalty': ['l2', 'l1', 'elasticnet'],
    'max_iter': [1000, 2000, 3000]
}

# Initialize GridSearchCV with SGDClassifier
grid_search = GridSearchCV(SGDClassifier(), param_grid, cv=5, n_jobs=-1, verbose=1)

# Fit grid search
grid_search.fit(X_test_pca, y_test)

# Print best parameters and score
print(f"Best parameters found: {grid_search.best_params_}")
print(f"Best cross-validation accuracy: {grid_search.best_score_ * 100:.2f}%")


Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.05it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.33it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 73.20it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.17it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.18it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 70.02it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.53it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 68.70it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 72.16it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.64it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 69.94it/s]
Processing Batch: 100%|██████████| 50/50 [00:00<00:00, 71.37it/s]
Processing

Test Accuracy: 5.94%
Fitting 5 folds for each of 54 candidates, totalling 270 fits
Best parameters found: {'alpha': 0.01, 'loss': 'hinge', 'max_iter': 3000, 'penalty': 'l2'}
Best cross-validation accuracy: 11.03%


135 fits failed out of a total of 270.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
7 fits failed with the following error:
Traceback (most recent call last):
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 888, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 1466, in wrapper
    estimator._validate_params()
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packages/sklearn/base.py", line 666, in _validate_params
    validate_parameter_constraints(
  File "/data1/George Pap/grabmyo/ptixiaki/.venv/lib/python3.10/site-packag

In [None]:
import os

def extract_label_from_filename(filepath):
    # Get the base name of the file (e.g., 'session1_participant1_gesture10_trial1.dat')
    filename = os.path.basename(filepath)
    
    # Split the filename by underscores
    parts = filename.split('_')
    
    # Find the part that contains 'gesture' and extract the gesture number
    for part in parts:
        if 'gesture' in part:
            # Assuming the format is 'gesture<number>', we can remove 'gesture' and get the number
            gesture_number = part.replace('gesture', '')
            return int(gesture_number)  # Convert to integer if needed
    
    return None  # If no gesture label found


In [None]:
import os
import scipy.io as sio
import numpy as np
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm  # Import tqdm for progress bars

# Path where the .mat files are stored
#data_folder = 'Session1'

# Initialize lists to hold data
data_emg_list = []
labels_list = []  # Assuming you have labels corresponding to the gestures

# Loop through each .mat file in the folder with a progress bar
for session_num in range(1, 3):  # Looping through sessions (1 to 3)
    session_folder = os.path.join(f'Session{session_num}_matfiles')
    
    # Get all .mat files in the current session folder
    mat_files = sorted(os.listdir(session_folder))
    
    # Use tqdm to show progress
    for mat_file in tqdm(mat_files, desc=f"Processing Session {session_num}"):
        if mat_file.endswith('.mat'):
            # Load the .mat file
            mat_data = sio.loadmat(os.path.join(session_folder, mat_file))
            
            # Extract the EMG data (you may need to adjust variable names if different)
            data_emg = mat_data['data_emg']  # This should match the name in your .mat files
            label = extract_label_from_filename(mat_file)  # You would need to define this function
            # Check for NaNs and infinities in the loaded data
            data = np.nan_to_num(data_emg, nan=0.0, posinf=0.0, neginf=0.0)
            # Append the data and label to lists
            data_emg_list.append(data)
            labels_list.append(label)


Processing Session 1: 100%|██████████| 5107/5107 [02:12<00:00, 38.46it/s]
Processing Session 2: 100%|██████████| 5109/5109 [02:24<00:00, 35.36it/s]


In [None]:
target_length = 10240  # The desired length for all samples
padded_data_emg_list = [
    np.pad(emg_data, ((0, target_length - emg_data.shape[0]), (0, 0)), 'constant', constant_values=0)
    if emg_data.shape[0] < target_length else emg_data
    for emg_data in data_emg_list
]

In [None]:

# Convert lists to numpy arrays
X = np.array(padded_data_emg_list)  # Data (e.g., EMG signals)
y = np.array(labels_list)  # Labels (e.g., gesture IDs)

# Flatten the data if needed (depends on the structure of your data)
X_flat = X.reshape(X.shape[0], -1)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_flat, y, test_size=0.2, random_state=42)

# Apply PCA
print("Applying PCA...")
pca = PCA(n_components=50)  # Adjust the number of components as needed
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# Train an SVM model with progress tracking
print("Training SVM...")
svm_model = SVC(kernel='linear')

# You can wrap the fitting process with tqdm to monitor the training as well
svm_model.fit(X_train_pca, y_train)

# Make predictions and evaluate the model
y_pred = svm_model.predict(X_test_pca)
accuracy = accuracy_score(y_test, y_pred)
print(f'PCA + SVM Accuracy: {accuracy:.4f}')

# Optionally: Use X_train_pca as input to a Vision Transformer model
# You can now use the X_train_pca for further processing or to feed into a Vision Transformer model.


Applying PCA...


In [None]:
# nan_count = np.isnan(data_emg).sum()
# inf_count = np.isinf(data_emg).sum()
# print(f"NaNs found: {nan_count}, Infinities found: {inf_count}")