👨‍💻 Copyright (C) $2024 Omer Tariq KAIST. - All Rights Reserved

Paper: EmoHEAL: A Fusion-Based Framework for Emotion Recognition Using Wearable Sensors

Authors: Yookyung Oh, Omer Tariq

In [3]:
import pickle
import numpy as np
import scipy.interpolate as intp

'''Notice: Place K-Emocon dataset in ./dataset folder'''
PATH = '../dataset/K-EmoCon.CS592.pkl'


DATASET = pickle.load(open(PATH, mode='rb'))

X = DATASET['X']
y = DATASET['y']
pids = DATASET['pids']
baseline = DATASET['baseline']
label_desc = DATASET['labels']

'''Use binarized external arousal for labeling'''
y_ea = np.ravel(
    y[:, label_desc == 'external.arousal']
)

def binarize_label(x: np.ndarray) -> np.ndarray:
    return np.where(x <= 2, 0, 1)

y_simple = binarize_label(x=y_ea)
    
'''Resampling'''
def resampling(x: np.ndarray, target_N: int) -> np.ndarray:
    N = np.arange(0, len(x))

    f = intp.interp1d(N, x, kind='linear')

    N_new = np.linspace(0, len(x) - 1, target_N)

    return f(N_new)

X_res = []

for x, pid in zip(X, pids):
    row = []

    acc_mag = np.sqrt(x['e4.acc.x']**2 + x['e4.acc.y']**2 + x['e4.acc.z']**2)
    
    for dtype in ['e4.hr', 'e4.eda', 'e4.temp', 'e4.bvp']:
        sig = resampling(x = x[dtype], target_N=20)
        row.append(sig)
        
    sig_acc_mag = resampling(x = acc_mag, target_N=20)
    row.append(sig_acc_mag)
    
    X_res.append(np.column_stack(row))

X_res = np.asarray(X_res)

PID = 4: (array([0, 1]), array([56, 64]))
PID = 5: (array([0, 1]), array([ 10, 110]))
PID = 8: (array([0, 1]), array([86, 34]))
PID = 9: (array([0, 1]), array([94, 26]))
PID = 10: (array([0, 1]), array([85, 35]))
PID = 13: (array([0, 1]), array([70, 50]))
PID = 15: (array([0, 1]), array([46, 74]))
PID = 19: (array([0, 1]), array([  8, 112]))
PID = 21: (array([0, 1]), array([93, 26]))
PID = 22: (array([0, 1]), array([102,  18]))
PID = 23: (array([0, 1]), array([76, 44]))
PID = 25: (array([0, 1]), array([  1, 119]))
PID = 26: (array([0, 1]), array([  3, 117]))
PID = 29: (array([0, 1]), array([40, 80]))
PID = 30: (array([0, 1]), array([ 20, 100]))
PID = 31: (array([1]), array([120]))


In [4]:
GROUP_K_FOLD = [[4,5,8,9], [10,13,15,25], [19,21,23,30], [22,26,29,31]]

I_TRAINS, I_TESTS = [], []

for g in GROUP_K_FOLD:
    M_test = np.isin(pids, g)
    M_train = ~M_test

    I_test = np.flatnonzero(M_test)
    I_train = np.flatnonzero(M_train)

    I_TRAINS.append(I_train)
    I_TESTS.append(I_test)

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.tensorboard import SummaryWriter
import torchprofile

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
# Data Augmentation Function
def augment_data(X):
    augmented_data = []
    for sample in X:
        noise = np.random.normal(0, 0.01, sample.shape)
        augmented_data.append(sample + noise)
    return np.array(augmented_data)

## Architecture 1: TCN+CA-SA+GRU Fusion

In [9]:
# Channel Attention
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        self.fc = nn.Sequential(
            nn.Conv1d(in_planes, in_planes // ratio, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv1d(in_planes // ratio, in_planes, kernel_size=1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

# Spatial Attention
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv1d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

# TCN Block Definition with Channel and Spatial Attention
class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()

        self.ca = ChannelAttention(n_outputs)
        self.sa = SpatialAttention()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        out = self.ca(out) * out
        out = self.sa(out) * out
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Define the TCN block for each sensor
class TCNBlock(nn.Module):
    def __init__(self, input_size, num_channels):
        super(TCNBlock, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels)

    def forward(self, x):
        return self.tcn(x)

# Define the fusion and GRU model
class FusionGRUModel(nn.Module):
    def __init__(self, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes=2):
        super(FusionGRUModel, self).__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])

        tcn_output_dim = tcn_channels[-1]  # The output dimension of the last TCN layer
        self.fc1 = nn.Linear(tcn_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, tcn_output_dim * input_dim)

        self.gru = nn.GRU(tcn_output_dim * input_dim, gru_hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(gru_hidden_dim * 2, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined, _ = self.gru(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last time step output from the GRU

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionGRUModel(input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Learning rate scheduler

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Step the learning rate scheduler based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Calculate FLOPs
    input_size = (1, input_dim, X_train.shape[2])  # Using batch size 1 for FLOPs calculation
    flops = torchprofile.profile_macs(model, torch.randn(input_size).to(device))
    print(f"FLOPs: {flops}")

    return model, epoch_loss

def evaluate_model(models, X_test, y_test, batch_size=32):
    # Aggregate predictions from all models
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # TCN channels
gru_hidden_dim = 64
num_layers = 2
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")




Epoch 1/150, Loss: 0.6679, Accuracy: 0.6108, Val Loss: 0.6433, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6637, Accuracy: 0.6108, Val Loss: 0.6531, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6363, Accuracy: 0.6212, Val Loss: 0.5604, Val Accuracy: 0.6910
Epoch 4/150, Loss: 0.5303, Accuracy: 0.7163, Val Loss: 0.5639, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.4961, Accuracy: 0.7550, Val Loss: 0.5384, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.4854, Accuracy: 0.7641, Val Loss: 0.5729, Val Accuracy: 0.6910
Epoch 7/150, Loss: 0.4867, Accuracy: 0.7602, Val Loss: 0.5424, Val Accuracy: 0.7118
Epoch 8/150, Loss: 0.4744, Accuracy: 0.7789, Val Loss: 0.5250, Val Accuracy: 0.7188
Epoch 9/150, Loss: 0.4710, Accuracy: 0.7698, Val Loss: 0.5311, Val Accuracy: 0.7083
Epoch 10/150, Loss: 0.4840, Accuracy: 0.7628, Val Loss: 0.5471, Val Accuracy: 0.6944
Epoch 11/150, Loss: 0.4712, Accuracy: 0.7728, Val Loss: 0.4972, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4605, Accuracy: 0.7845, Val Loss: 0.5101, Val Accura



Epoch 1/150, Loss: 0.6702, Accuracy: 0.6103, Val Loss: 0.6431, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6636, Accuracy: 0.6108, Val Loss: 0.6309, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6520, Accuracy: 0.6108, Val Loss: 0.5896, Val Accuracy: 0.6701
Epoch 4/150, Loss: 0.5354, Accuracy: 0.7198, Val Loss: 0.5177, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.4898, Accuracy: 0.7628, Val Loss: 0.5923, Val Accuracy: 0.6910
Epoch 6/150, Loss: 0.4834, Accuracy: 0.7654, Val Loss: 0.5039, Val Accuracy: 0.7500
Epoch 7/150, Loss: 0.4806, Accuracy: 0.7715, Val Loss: 0.5294, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.4716, Accuracy: 0.7767, Val Loss: 0.5004, Val Accuracy: 0.7326
Epoch 9/150, Loss: 0.4745, Accuracy: 0.7763, Val Loss: 0.5008, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.4704, Accuracy: 0.7763, Val Loss: 0.5533, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4693, Accuracy: 0.7815, Val Loss: 0.4971, Val Accuracy: 0.7292
Epoch 12/150, Loss: 0.4661, Accuracy: 0.7832, Val Loss: 0.5203, Val Accura



Epoch 1/150, Loss: 0.6707, Accuracy: 0.6021, Val Loss: 0.6411, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6665, Accuracy: 0.6108, Val Loss: 0.6382, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6514, Accuracy: 0.6203, Val Loss: 0.6277, Val Accuracy: 0.7257
Epoch 4/150, Loss: 0.5437, Accuracy: 0.7155, Val Loss: 0.5665, Val Accuracy: 0.7049
Epoch 5/150, Loss: 0.4958, Accuracy: 0.7619, Val Loss: 0.5187, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4929, Accuracy: 0.7654, Val Loss: 0.5192, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4837, Accuracy: 0.7654, Val Loss: 0.5306, Val Accuracy: 0.7153
Epoch 8/150, Loss: 0.4765, Accuracy: 0.7724, Val Loss: 0.5004, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4731, Accuracy: 0.7754, Val Loss: 0.5006, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.4735, Accuracy: 0.7785, Val Loss: 0.5275, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4704, Accuracy: 0.7754, Val Loss: 0.5078, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4725, Accuracy: 0.7758, Val Loss: 0.4992, Val Accura



Epoch 1/150, Loss: 0.6677, Accuracy: 0.6108, Val Loss: 0.6336, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6643, Accuracy: 0.6108, Val Loss: 0.6324, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6147, Accuracy: 0.6447, Val Loss: 0.5566, Val Accuracy: 0.6840
Epoch 4/150, Loss: 0.5229, Accuracy: 0.7285, Val Loss: 0.5160, Val Accuracy: 0.7153
Epoch 5/150, Loss: 0.4925, Accuracy: 0.7624, Val Loss: 0.5168, Val Accuracy: 0.7083
Epoch 6/150, Loss: 0.4921, Accuracy: 0.7615, Val Loss: 0.5717, Val Accuracy: 0.6875
Epoch 7/150, Loss: 0.4786, Accuracy: 0.7758, Val Loss: 0.5096, Val Accuracy: 0.7049
Epoch 8/150, Loss: 0.4813, Accuracy: 0.7663, Val Loss: 0.5151, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4806, Accuracy: 0.7732, Val Loss: 0.5154, Val Accuracy: 0.7188
Epoch 10/150, Loss: 0.4698, Accuracy: 0.7763, Val Loss: 0.4986, Val Accuracy: 0.7361
Epoch 11/150, Loss: 0.4667, Accuracy: 0.7750, Val Loss: 0.4996, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.4720, Accuracy: 0.7754, Val Loss: 0.4968, Val Accura



FLOPs: 278631
Fold 1, Model Loss: 0.4433
Epoch 1/150, Loss: 0.6671, Accuracy: 0.6056, Val Loss: 0.6459, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6636, Accuracy: 0.6108, Val Loss: 0.6373, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6481, Accuracy: 0.6182, Val Loss: 0.6033, Val Accuracy: 0.6910
Epoch 4/150, Loss: 0.5786, Accuracy: 0.6851, Val Loss: 0.5688, Val Accuracy: 0.6806
Epoch 5/150, Loss: 0.5025, Accuracy: 0.7498, Val Loss: 0.5155, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.4816, Accuracy: 0.7702, Val Loss: 0.5247, Val Accuracy: 0.7153
Epoch 7/150, Loss: 0.4779, Accuracy: 0.7676, Val Loss: 0.5149, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4803, Accuracy: 0.7654, Val Loss: 0.5072, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.4697, Accuracy: 0.7767, Val Loss: 0.5120, Val Accuracy: 0.7326
Epoch 10/150, Loss: 0.4748, Accuracy: 0.7685, Val Loss: 0.5289, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.4750, Accuracy: 0.7715, Val Loss: 0.5028, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.4674, Accur



Epoch 1/150, Loss: 0.6769, Accuracy: 0.5838, Val Loss: 0.6765, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6662, Accuracy: 0.5908, Val Loss: 0.6668, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6479, Accuracy: 0.6186, Val Loss: 0.6232, Val Accuracy: 0.7083
Epoch 4/150, Loss: 0.5803, Accuracy: 0.6929, Val Loss: 0.5480, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.5302, Accuracy: 0.7311, Val Loss: 0.5173, Val Accuracy: 0.7396
Epoch 6/150, Loss: 0.5301, Accuracy: 0.7381, Val Loss: 0.5097, Val Accuracy: 0.7604
Epoch 7/150, Loss: 0.5138, Accuracy: 0.7498, Val Loss: 0.5032, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.5076, Accuracy: 0.7637, Val Loss: 0.4933, Val Accuracy: 0.7500
Epoch 9/150, Loss: 0.5112, Accuracy: 0.7580, Val Loss: 0.4991, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.5088, Accuracy: 0.7619, Val Loss: 0.4883, Val Accuracy: 0.7569
Epoch 11/150, Loss: 0.5051, Accuracy: 0.7624, Val Loss: 0.4926, Val Accuracy: 0.7708
Epoch 12/150, Loss: 0.5017, Accuracy: 0.7576, Val Loss: 0.4788, Val Accura



Epoch 1/150, Loss: 0.6741, Accuracy: 0.5834, Val Loss: 0.6710, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6662, Accuracy: 0.5934, Val Loss: 0.6670, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6424, Accuracy: 0.6056, Val Loss: 0.6050, Val Accuracy: 0.7222
Epoch 4/150, Loss: 0.5705, Accuracy: 0.7076, Val Loss: 0.5480, Val Accuracy: 0.7361
Epoch 5/150, Loss: 0.5493, Accuracy: 0.7120, Val Loss: 0.5261, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.5203, Accuracy: 0.7428, Val Loss: 0.5253, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.5149, Accuracy: 0.7424, Val Loss: 0.4992, Val Accuracy: 0.7778
Epoch 8/150, Loss: 0.5143, Accuracy: 0.7459, Val Loss: 0.5015, Val Accuracy: 0.7743
Epoch 9/150, Loss: 0.5041, Accuracy: 0.7550, Val Loss: 0.4976, Val Accuracy: 0.7535
Epoch 10/150, Loss: 0.5088, Accuracy: 0.7511, Val Loss: 0.4970, Val Accuracy: 0.7569
Epoch 11/150, Loss: 0.5010, Accuracy: 0.7663, Val Loss: 0.4890, Val Accuracy: 0.7708
Epoch 12/150, Loss: 0.4997, Accuracy: 0.7689, Val Loss: 0.4921, Val Accura



Epoch 1/150, Loss: 0.6709, Accuracy: 0.5908, Val Loss: 0.6751, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6643, Accuracy: 0.5917, Val Loss: 0.6727, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6490, Accuracy: 0.6034, Val Loss: 0.6112, Val Accuracy: 0.6493
Epoch 4/150, Loss: 0.5599, Accuracy: 0.7046, Val Loss: 0.5559, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.5401, Accuracy: 0.7246, Val Loss: 0.5106, Val Accuracy: 0.7639
Epoch 6/150, Loss: 0.5300, Accuracy: 0.7315, Val Loss: 0.5241, Val Accuracy: 0.7396
Epoch 7/150, Loss: 0.5124, Accuracy: 0.7485, Val Loss: 0.5119, Val Accuracy: 0.7708
Epoch 8/150, Loss: 0.5098, Accuracy: 0.7576, Val Loss: 0.5084, Val Accuracy: 0.7639
Epoch 9/150, Loss: 0.5056, Accuracy: 0.7593, Val Loss: 0.4930, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.5043, Accuracy: 0.7628, Val Loss: 0.5126, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.5052, Accuracy: 0.7567, Val Loss: 0.5064, Val Accuracy: 0.7639
Epoch 12/150, Loss: 0.5034, Accuracy: 0.7628, Val Loss: 0.4963, Val Accura



Epoch 1/150, Loss: 0.6735, Accuracy: 0.5899, Val Loss: 0.6778, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6644, Accuracy: 0.5908, Val Loss: 0.6624, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6511, Accuracy: 0.6056, Val Loss: 0.6277, Val Accuracy: 0.7118
Epoch 4/150, Loss: 0.5751, Accuracy: 0.6898, Val Loss: 0.5752, Val Accuracy: 0.6875
Epoch 5/150, Loss: 0.5361, Accuracy: 0.7263, Val Loss: 0.5349, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.5249, Accuracy: 0.7363, Val Loss: 0.5030, Val Accuracy: 0.7569
Epoch 7/150, Loss: 0.5131, Accuracy: 0.7537, Val Loss: 0.5040, Val Accuracy: 0.7639
Epoch 8/150, Loss: 0.5104, Accuracy: 0.7472, Val Loss: 0.4995, Val Accuracy: 0.7674
Epoch 9/150, Loss: 0.5056, Accuracy: 0.7606, Val Loss: 0.4958, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.4994, Accuracy: 0.7602, Val Loss: 0.4909, Val Accuracy: 0.7812
Epoch 11/150, Loss: 0.5023, Accuracy: 0.7589, Val Loss: 0.4991, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.5032, Accuracy: 0.7593, Val Loss: 0.4972, Val Accura



FLOPs: 278631
Fold 2, Model Loss: 0.4548
Epoch 1/150, Loss: 0.6752, Accuracy: 0.5917, Val Loss: 0.6755, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6665, Accuracy: 0.5899, Val Loss: 0.6597, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6392, Accuracy: 0.6190, Val Loss: 0.6224, Val Accuracy: 0.6875
Epoch 4/150, Loss: 0.5712, Accuracy: 0.6924, Val Loss: 0.5434, Val Accuracy: 0.7153
Epoch 5/150, Loss: 0.5280, Accuracy: 0.7359, Val Loss: 0.4920, Val Accuracy: 0.7778
Epoch 6/150, Loss: 0.5213, Accuracy: 0.7407, Val Loss: 0.5136, Val Accuracy: 0.7431
Epoch 7/150, Loss: 0.5105, Accuracy: 0.7467, Val Loss: 0.5039, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.5135, Accuracy: 0.7502, Val Loss: 0.5013, Val Accuracy: 0.7569
Epoch 9/150, Loss: 0.5064, Accuracy: 0.7637, Val Loss: 0.5137, Val Accuracy: 0.7326
Epoch 10/150, Loss: 0.5001, Accuracy: 0.7641, Val Loss: 0.4979, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.5020, Accuracy: 0.7624, Val Loss: 0.4954, Val Accuracy: 0.7743
Epoch 12/150, Loss: 0.4952, Accur



Epoch 1/150, Loss: 0.6785, Accuracy: 0.5846, Val Loss: 0.6923, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6719, Accuracy: 0.5955, Val Loss: 0.7320, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6709, Accuracy: 0.5955, Val Loss: 0.6966, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6692, Accuracy: 0.5955, Val Loss: 0.6886, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.6505, Accuracy: 0.6029, Val Loss: 0.6407, Val Accuracy: 0.5938
Epoch 6/150, Loss: 0.5370, Accuracy: 0.7010, Val Loss: 0.5061, Val Accuracy: 0.7083
Epoch 7/150, Loss: 0.5178, Accuracy: 0.7292, Val Loss: 0.5424, Val Accuracy: 0.6562
Epoch 8/150, Loss: 0.5127, Accuracy: 0.7209, Val Loss: 0.4953, Val Accuracy: 0.7465
Epoch 9/150, Loss: 0.4977, Accuracy: 0.7361, Val Loss: 0.5394, Val Accuracy: 0.7188
Epoch 10/150, Loss: 0.5043, Accuracy: 0.7352, Val Loss: 0.4990, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4990, Accuracy: 0.7348, Val Loss: 0.4869, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.4880, Accuracy: 0.7413, Val Loss: 0.5148, Val Accura



Epoch 1/150, Loss: 0.6743, Accuracy: 0.5942, Val Loss: 0.6969, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6741, Accuracy: 0.5955, Val Loss: 0.6930, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6714, Accuracy: 0.5955, Val Loss: 0.6919, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6500, Accuracy: 0.6020, Val Loss: 0.5726, Val Accuracy: 0.7118
Epoch 5/150, Loss: 0.5318, Accuracy: 0.7079, Val Loss: 0.5104, Val Accuracy: 0.7500
Epoch 6/150, Loss: 0.5055, Accuracy: 0.7405, Val Loss: 0.5051, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5065, Accuracy: 0.7361, Val Loss: 0.4896, Val Accuracy: 0.7674
Epoch 8/150, Loss: 0.5013, Accuracy: 0.7391, Val Loss: 0.5152, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4955, Accuracy: 0.7483, Val Loss: 0.4915, Val Accuracy: 0.7639
Epoch 10/150, Loss: 0.4969, Accuracy: 0.7483, Val Loss: 0.4999, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4950, Accuracy: 0.7413, Val Loss: 0.4974, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4937, Accuracy: 0.7461, Val Loss: 0.4936, Val Accura



Epoch 1/150, Loss: 0.6764, Accuracy: 0.5955, Val Loss: 0.6948, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6727, Accuracy: 0.5955, Val Loss: 0.6913, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6731, Accuracy: 0.5955, Val Loss: 0.6950, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6684, Accuracy: 0.5955, Val Loss: 0.6838, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5963, Accuracy: 0.6523, Val Loss: 0.4973, Val Accuracy: 0.7674
Epoch 6/150, Loss: 0.5160, Accuracy: 0.7240, Val Loss: 0.4990, Val Accuracy: 0.7431
Epoch 7/150, Loss: 0.5162, Accuracy: 0.7253, Val Loss: 0.5087, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.5089, Accuracy: 0.7405, Val Loss: 0.4964, Val Accuracy: 0.7674
Epoch 9/150, Loss: 0.4948, Accuracy: 0.7452, Val Loss: 0.4971, Val Accuracy: 0.7535
Epoch 10/150, Loss: 0.5000, Accuracy: 0.7352, Val Loss: 0.5045, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.4941, Accuracy: 0.7439, Val Loss: 0.5016, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4901, Accuracy: 0.7470, Val Loss: 0.5013, Val Accura



FLOPs: 278631
Fold 3, Model Loss: 0.4577
Epoch 1/150, Loss: 0.6765, Accuracy: 0.5881, Val Loss: 0.6991, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6740, Accuracy: 0.5955, Val Loss: 0.6903, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6721, Accuracy: 0.5955, Val Loss: 0.6959, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6503, Accuracy: 0.5994, Val Loss: 0.5828, Val Accuracy: 0.6389
Epoch 5/150, Loss: 0.5221, Accuracy: 0.7105, Val Loss: 0.4818, Val Accuracy: 0.7604
Epoch 6/150, Loss: 0.5037, Accuracy: 0.7370, Val Loss: 0.5061, Val Accuracy: 0.7188
Epoch 7/150, Loss: 0.4948, Accuracy: 0.7418, Val Loss: 0.4918, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.4969, Accuracy: 0.7418, Val Loss: 0.4947, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.5006, Accuracy: 0.7418, Val Loss: 0.4899, Val Accuracy: 0.7569
Epoch 10/150, Loss: 0.4953, Accuracy: 0.7465, Val Loss: 0.5162, Val Accuracy: 0.7188
Epoch 11/150, Loss: 0.4935, Accuracy: 0.7483, Val Loss: 0.4841, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4918, Accur



Epoch 1/150, Loss: 0.6755, Accuracy: 0.5946, Val Loss: 0.6909, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6719, Accuracy: 0.5955, Val Loss: 0.6932, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6736, Accuracy: 0.5955, Val Loss: 0.6903, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6646, Accuracy: 0.5955, Val Loss: 0.6714, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5733, Accuracy: 0.6684, Val Loss: 0.5155, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.5039, Accuracy: 0.7474, Val Loss: 0.5014, Val Accuracy: 0.7361
Epoch 7/150, Loss: 0.4987, Accuracy: 0.7378, Val Loss: 0.4900, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.5004, Accuracy: 0.7457, Val Loss: 0.4975, Val Accuracy: 0.7500
Epoch 9/150, Loss: 0.4934, Accuracy: 0.7474, Val Loss: 0.4952, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.4924, Accuracy: 0.7474, Val Loss: 0.4871, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4918, Accuracy: 0.7522, Val Loss: 0.4853, Val Accuracy: 0.7708
Epoch 12/150, Loss: 0.4844, Accuracy: 0.7500, Val Loss: 0.4912, Val Accura



Epoch 1/150, Loss: 0.6838, Accuracy: 0.5604, Val Loss: 0.6869, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6809, Accuracy: 0.5591, Val Loss: 0.7003, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6771, Accuracy: 0.5630, Val Loss: 0.6772, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6292, Accuracy: 0.6116, Val Loss: 0.5765, Val Accuracy: 0.6667
Epoch 5/150, Loss: 0.5857, Accuracy: 0.6633, Val Loss: 0.5798, Val Accuracy: 0.6771
Epoch 6/150, Loss: 0.5796, Accuracy: 0.6633, Val Loss: 0.5834, Val Accuracy: 0.6840
Epoch 7/150, Loss: 0.5700, Accuracy: 0.6764, Val Loss: 0.5924, Val Accuracy: 0.6875
Epoch 8/150, Loss: 0.5740, Accuracy: 0.6785, Val Loss: 0.5833, Val Accuracy: 0.6806
Epoch 9/150, Loss: 0.5652, Accuracy: 0.6933, Val Loss: 0.5808, Val Accuracy: 0.6979
Epoch 10/150, Loss: 0.5650, Accuracy: 0.6807, Val Loss: 0.5999, Val Accuracy: 0.6354
Epoch 11/150, Loss: 0.5617, Accuracy: 0.6898, Val Loss: 0.5781, Val Accuracy: 0.6875
Epoch 12/150, Loss: 0.5598, Accuracy: 0.6898, Val Loss: 0.5893, Val Accura



Epoch 1/150, Loss: 0.6839, Accuracy: 0.5595, Val Loss: 0.7058, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6839, Accuracy: 0.5517, Val Loss: 0.6994, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6781, Accuracy: 0.5604, Val Loss: 0.7078, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6625, Accuracy: 0.5782, Val Loss: 0.6212, Val Accuracy: 0.6042
Epoch 5/150, Loss: 0.5872, Accuracy: 0.6607, Val Loss: 0.5804, Val Accuracy: 0.6910
Epoch 6/150, Loss: 0.5743, Accuracy: 0.6790, Val Loss: 0.5769, Val Accuracy: 0.6944
Epoch 7/150, Loss: 0.5716, Accuracy: 0.6759, Val Loss: 0.5825, Val Accuracy: 0.6736
Epoch 8/150, Loss: 0.5691, Accuracy: 0.6964, Val Loss: 0.5818, Val Accuracy: 0.6806
Epoch 9/150, Loss: 0.5658, Accuracy: 0.6855, Val Loss: 0.5745, Val Accuracy: 0.6806
Epoch 10/150, Loss: 0.5682, Accuracy: 0.6890, Val Loss: 0.5772, Val Accuracy: 0.7049
Epoch 11/150, Loss: 0.5648, Accuracy: 0.6890, Val Loss: 0.5763, Val Accuracy: 0.7118
Epoch 12/150, Loss: 0.5658, Accuracy: 0.6777, Val Loss: 0.5727, Val Accura



Epoch 1/150, Loss: 0.6857, Accuracy: 0.5556, Val Loss: 0.6935, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6809, Accuracy: 0.5604, Val Loss: 0.6944, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6767, Accuracy: 0.5626, Val Loss: 0.6825, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6448, Accuracy: 0.5917, Val Loss: 0.6082, Val Accuracy: 0.6528
Epoch 5/150, Loss: 0.5832, Accuracy: 0.6620, Val Loss: 0.5810, Val Accuracy: 0.6736
Epoch 6/150, Loss: 0.5728, Accuracy: 0.6725, Val Loss: 0.5785, Val Accuracy: 0.7049
Epoch 7/150, Loss: 0.5703, Accuracy: 0.6777, Val Loss: 0.5766, Val Accuracy: 0.7014
Epoch 8/150, Loss: 0.5669, Accuracy: 0.6872, Val Loss: 0.5784, Val Accuracy: 0.6840
Epoch 9/150, Loss: 0.5635, Accuracy: 0.6911, Val Loss: 0.5760, Val Accuracy: 0.6875
Epoch 10/150, Loss: 0.5600, Accuracy: 0.6929, Val Loss: 0.5804, Val Accuracy: 0.6840
Epoch 11/150, Loss: 0.5634, Accuracy: 0.6937, Val Loss: 0.5809, Val Accuracy: 0.7014
Epoch 12/150, Loss: 0.5571, Accuracy: 0.6981, Val Loss: 0.5796, Val Accura



FLOPs: 278631
Fold 4, Model Loss: 0.5237
Epoch 1/150, Loss: 0.6856, Accuracy: 0.5534, Val Loss: 0.6960, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6815, Accuracy: 0.5604, Val Loss: 0.6875, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6806, Accuracy: 0.5604, Val Loss: 0.6875, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6482, Accuracy: 0.5877, Val Loss: 0.5931, Val Accuracy: 0.6771
Epoch 5/150, Loss: 0.5815, Accuracy: 0.6633, Val Loss: 0.5872, Val Accuracy: 0.6944
Epoch 6/150, Loss: 0.5664, Accuracy: 0.6855, Val Loss: 0.5736, Val Accuracy: 0.6979
Epoch 7/150, Loss: 0.5623, Accuracy: 0.6920, Val Loss: 0.5787, Val Accuracy: 0.6910
Epoch 8/150, Loss: 0.5648, Accuracy: 0.6855, Val Loss: 0.5768, Val Accuracy: 0.7014
Epoch 9/150, Loss: 0.5600, Accuracy: 0.7055, Val Loss: 0.5728, Val Accuracy: 0.7049
Epoch 10/150, Loss: 0.5619, Accuracy: 0.6968, Val Loss: 0.5792, Val Accuracy: 0.6736
Epoch 11/150, Loss: 0.5599, Accuracy: 0.6964, Val Loss: 0.5774, Val Accuracy: 0.6910
Epoch 12/150, Loss: 0.5581, Accur



Epoch 1/150, Loss: 0.6836, Accuracy: 0.5604, Val Loss: 0.7032, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6787, Accuracy: 0.5639, Val Loss: 0.6858, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6745, Accuracy: 0.5747, Val Loss: 0.6694, Val Accuracy: 0.6250
Epoch 4/150, Loss: 0.6275, Accuracy: 0.6160, Val Loss: 0.5847, Val Accuracy: 0.6806
Epoch 5/150, Loss: 0.5871, Accuracy: 0.6625, Val Loss: 0.5931, Val Accuracy: 0.6667
Epoch 6/150, Loss: 0.5762, Accuracy: 0.6716, Val Loss: 0.5757, Val Accuracy: 0.7049
Epoch 7/150, Loss: 0.5704, Accuracy: 0.6890, Val Loss: 0.5868, Val Accuracy: 0.6701
Epoch 8/150, Loss: 0.5681, Accuracy: 0.6881, Val Loss: 0.5742, Val Accuracy: 0.6979
Epoch 9/150, Loss: 0.5682, Accuracy: 0.6846, Val Loss: 0.5727, Val Accuracy: 0.7014
Epoch 10/150, Loss: 0.5682, Accuracy: 0.6894, Val Loss: 0.5758, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.5602, Accuracy: 0.7007, Val Loss: 0.5872, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.5561, Accuracy: 0.7063, Val Loss: 0.5907, Val Accura



Mean Accuracy: 70.82%, (SD=0.09592958150268692)
Mean Loss: 0.5899, (SD=0.1325660585196748)
Mean Precision: 0.7335, (SD=0.08157382243375308)
Mean Recall: 0.7082, (SD=0.09592958150268692)
Mean F1 Score: 0.6961, (SD=0.10327482422241009)


## Architecture 2: TCN+GRU Fusion

In [17]:
# TCN Block Definition
class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Define the TCN block for each sensor
class TCNBlock(nn.Module):
    def __init__(self, input_size, num_channels):
        super(TCNBlock, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels)

    def forward(self, x):
        return self.tcn(x)

# Define the fusion and GRU model
class FusionGRUModel(nn.Module):
    def __init__(self, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes=2):
        super(FusionGRUModel, self).__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])

        tcn_output_dim = tcn_channels[-1]  # The output dimension of the last TCN layer
        self.fc1 = nn.Linear(tcn_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, tcn_output_dim * input_dim)

        self.gru = nn.GRU(tcn_output_dim * input_dim, gru_hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(gru_hidden_dim * 2, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined, _ = self.gru(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last time step output from the GRU

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionGRUModel(input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Learning rate scheduler

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Step the learning rate scheduler based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Calculate FLOPs
    input_size = (1, input_dim, X_train.shape[2])  # Using batch size 1 for FLOPs calculation
    flops = torchprofile.profile_macs(model, torch.randn(input_size).to(device))
    print(f"FLOPs: {flops}")

    return model, epoch_loss

def evaluate_model(models, X_test, y_test, batch_size=32):
    # Aggregate predictions from all models
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # TCN channels
gru_hidden_dim = 64
num_layers = 2
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")



Epoch 1/150, Loss: 0.6735, Accuracy: 0.6017, Val Loss: 0.6436, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6697, Accuracy: 0.6108, Val Loss: 0.6429, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6540, Accuracy: 0.6125, Val Loss: 0.5991, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.5373, Accuracy: 0.7220, Val Loss: 0.5111, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.4883, Accuracy: 0.7641, Val Loss: 0.5177, Val Accuracy: 0.7292
Epoch 6/150, Loss: 0.4868, Accuracy: 0.7641, Val Loss: 0.5076, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.4751, Accuracy: 0.7706, Val Loss: 0.5046, Val Accuracy: 0.7222
Epoch 8/150, Loss: 0.4681, Accuracy: 0.7754, Val Loss: 0.5201, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4716, Accuracy: 0.7767, Val Loss: 0.5187, Val Accuracy: 0.7257
Epoch 10/150, Loss: 0.4762, Accuracy: 0.7706, Val Loss: 0.4920, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.4676, Accuracy: 0.7737, Val Loss: 0.4956, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4622, Accuracy: 0.7811, Val Loss: 0.5026, Val Accura



Epoch 1/150, Loss: 0.6718, Accuracy: 0.6099, Val Loss: 0.6441, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6660, Accuracy: 0.6108, Val Loss: 0.6323, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6282, Accuracy: 0.6273, Val Loss: 0.5335, Val Accuracy: 0.7326
Epoch 4/150, Loss: 0.5075, Accuracy: 0.7363, Val Loss: 0.5127, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.4890, Accuracy: 0.7646, Val Loss: 0.5098, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.4820, Accuracy: 0.7589, Val Loss: 0.5048, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.4786, Accuracy: 0.7706, Val Loss: 0.5052, Val Accuracy: 0.7431
Epoch 8/150, Loss: 0.4757, Accuracy: 0.7754, Val Loss: 0.4973, Val Accuracy: 0.7396
Epoch 9/150, Loss: 0.4757, Accuracy: 0.7702, Val Loss: 0.5260, Val Accuracy: 0.7049
Epoch 10/150, Loss: 0.4761, Accuracy: 0.7711, Val Loss: 0.4942, Val Accuracy: 0.7188
Epoch 11/150, Loss: 0.4707, Accuracy: 0.7728, Val Loss: 0.4946, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.4682, Accuracy: 0.7732, Val Loss: 0.5144, Val Accura



Epoch 1/150, Loss: 0.6695, Accuracy: 0.6082, Val Loss: 0.6367, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6609, Accuracy: 0.6108, Val Loss: 0.6227, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.5802, Accuracy: 0.6846, Val Loss: 0.5160, Val Accuracy: 0.7431
Epoch 4/150, Loss: 0.4965, Accuracy: 0.7567, Val Loss: 0.5969, Val Accuracy: 0.6944
Epoch 5/150, Loss: 0.5003, Accuracy: 0.7537, Val Loss: 0.5026, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.4800, Accuracy: 0.7706, Val Loss: 0.5001, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.4768, Accuracy: 0.7685, Val Loss: 0.5243, Val Accuracy: 0.7153
Epoch 8/150, Loss: 0.4760, Accuracy: 0.7706, Val Loss: 0.5067, Val Accuracy: 0.7153
Epoch 9/150, Loss: 0.4688, Accuracy: 0.7767, Val Loss: 0.5045, Val Accuracy: 0.7326
Epoch 10/150, Loss: 0.4767, Accuracy: 0.7711, Val Loss: 0.5007, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.4763, Accuracy: 0.7711, Val Loss: 0.4930, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4714, Accuracy: 0.7732, Val Loss: 0.4990, Val Accura



Epoch 1/150, Loss: 0.6706, Accuracy: 0.6064, Val Loss: 0.6425, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6635, Accuracy: 0.6108, Val Loss: 0.6279, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6109, Accuracy: 0.6455, Val Loss: 0.5560, Val Accuracy: 0.7049
Epoch 4/150, Loss: 0.4999, Accuracy: 0.7502, Val Loss: 0.5015, Val Accuracy: 0.7083
Epoch 5/150, Loss: 0.5020, Accuracy: 0.7528, Val Loss: 0.5461, Val Accuracy: 0.7049
Epoch 6/150, Loss: 0.4805, Accuracy: 0.7732, Val Loss: 0.5025, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.4724, Accuracy: 0.7706, Val Loss: 0.5062, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.4734, Accuracy: 0.7750, Val Loss: 0.5034, Val Accuracy: 0.7326
Epoch 9/150, Loss: 0.4769, Accuracy: 0.7698, Val Loss: 0.4954, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4714, Accuracy: 0.7745, Val Loss: 0.5022, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4717, Accuracy: 0.7763, Val Loss: 0.5162, Val Accuracy: 0.7222
Epoch 12/150, Loss: 0.4660, Accuracy: 0.7754, Val Loss: 0.4901, Val Accura



Epoch 1/150, Loss: 0.6732, Accuracy: 0.5990, Val Loss: 0.6428, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6648, Accuracy: 0.6108, Val Loss: 0.6338, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6309, Accuracy: 0.6238, Val Loss: 0.5228, Val Accuracy: 0.7153
Epoch 4/150, Loss: 0.4970, Accuracy: 0.7554, Val Loss: 0.5023, Val Accuracy: 0.7153
Epoch 5/150, Loss: 0.4862, Accuracy: 0.7563, Val Loss: 0.5140, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4731, Accuracy: 0.7672, Val Loss: 0.5047, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.4837, Accuracy: 0.7659, Val Loss: 0.4972, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4785, Accuracy: 0.7576, Val Loss: 0.4963, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4769, Accuracy: 0.7728, Val Loss: 0.4959, Val Accuracy: 0.7222
Epoch 10/150, Loss: 0.4679, Accuracy: 0.7724, Val Loss: 0.5151, Val Accuracy: 0.7361
Epoch 11/150, Loss: 0.4665, Accuracy: 0.7767, Val Loss: 0.5011, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4610, Accuracy: 0.7854, Val Loss: 0.5563, Val Accura



Epoch 1/150, Loss: 0.6723, Accuracy: 0.5904, Val Loss: 0.6742, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6685, Accuracy: 0.5851, Val Loss: 0.6758, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6606, Accuracy: 0.5960, Val Loss: 0.6613, Val Accuracy: 0.6806
Epoch 4/150, Loss: 0.6068, Accuracy: 0.6712, Val Loss: 0.5437, Val Accuracy: 0.6944
Epoch 5/150, Loss: 0.5364, Accuracy: 0.7324, Val Loss: 0.5036, Val Accuracy: 0.7674
Epoch 6/150, Loss: 0.5226, Accuracy: 0.7359, Val Loss: 0.5409, Val Accuracy: 0.7188
Epoch 7/150, Loss: 0.5180, Accuracy: 0.7441, Val Loss: 0.5024, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.5103, Accuracy: 0.7563, Val Loss: 0.5278, Val Accuracy: 0.7188
Epoch 9/150, Loss: 0.5094, Accuracy: 0.7559, Val Loss: 0.5095, Val Accuracy: 0.7604
Epoch 10/150, Loss: 0.5113, Accuracy: 0.7602, Val Loss: 0.5112, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.5061, Accuracy: 0.7554, Val Loss: 0.5015, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.5059, Accuracy: 0.7632, Val Loss: 0.5092, Val Accura



Epoch 1/150, Loss: 0.6764, Accuracy: 0.5904, Val Loss: 0.6797, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6679, Accuracy: 0.5908, Val Loss: 0.6777, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6493, Accuracy: 0.6086, Val Loss: 0.6135, Val Accuracy: 0.7361
Epoch 4/150, Loss: 0.5614, Accuracy: 0.7037, Val Loss: 0.5032, Val Accuracy: 0.7569
Epoch 5/150, Loss: 0.5302, Accuracy: 0.7394, Val Loss: 0.4988, Val Accuracy: 0.7917
Epoch 6/150, Loss: 0.5189, Accuracy: 0.7476, Val Loss: 0.4973, Val Accuracy: 0.7743
Epoch 7/150, Loss: 0.5280, Accuracy: 0.7381, Val Loss: 0.5100, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.5093, Accuracy: 0.7589, Val Loss: 0.4940, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.5083, Accuracy: 0.7550, Val Loss: 0.4915, Val Accuracy: 0.7778
Epoch 10/150, Loss: 0.5052, Accuracy: 0.7611, Val Loss: 0.4973, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.4992, Accuracy: 0.7706, Val Loss: 0.4899, Val Accuracy: 0.7674
Epoch 12/150, Loss: 0.5031, Accuracy: 0.7615, Val Loss: 0.5042, Val Accura



Epoch 1/150, Loss: 0.6713, Accuracy: 0.5830, Val Loss: 0.6748, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6671, Accuracy: 0.5960, Val Loss: 0.6656, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6534, Accuracy: 0.6086, Val Loss: 0.6187, Val Accuracy: 0.7014
Epoch 4/150, Loss: 0.5527, Accuracy: 0.7215, Val Loss: 0.5104, Val Accuracy: 0.7500
Epoch 5/150, Loss: 0.5204, Accuracy: 0.7415, Val Loss: 0.4925, Val Accuracy: 0.7708
Epoch 6/150, Loss: 0.5163, Accuracy: 0.7511, Val Loss: 0.4970, Val Accuracy: 0.7708
Epoch 7/150, Loss: 0.5116, Accuracy: 0.7533, Val Loss: 0.4883, Val Accuracy: 0.7778
Epoch 8/150, Loss: 0.5056, Accuracy: 0.7606, Val Loss: 0.4991, Val Accuracy: 0.7674
Epoch 9/150, Loss: 0.4996, Accuracy: 0.7598, Val Loss: 0.4963, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.5037, Accuracy: 0.7572, Val Loss: 0.5021, Val Accuracy: 0.7743
Epoch 11/150, Loss: 0.4992, Accuracy: 0.7598, Val Loss: 0.4983, Val Accuracy: 0.7674
Epoch 12/150, Loss: 0.5002, Accuracy: 0.7637, Val Loss: 0.4945, Val Accura



Epoch 1/150, Loss: 0.6735, Accuracy: 0.5847, Val Loss: 0.6776, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6652, Accuracy: 0.5917, Val Loss: 0.6615, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.5964, Accuracy: 0.6568, Val Loss: 0.5765, Val Accuracy: 0.6944
Epoch 4/150, Loss: 0.5302, Accuracy: 0.7285, Val Loss: 0.5003, Val Accuracy: 0.7604
Epoch 5/150, Loss: 0.5176, Accuracy: 0.7480, Val Loss: 0.5069, Val Accuracy: 0.7708
Epoch 6/150, Loss: 0.5059, Accuracy: 0.7593, Val Loss: 0.4904, Val Accuracy: 0.7847
Epoch 7/150, Loss: 0.5185, Accuracy: 0.7350, Val Loss: 0.4956, Val Accuracy: 0.7674
Epoch 8/150, Loss: 0.5044, Accuracy: 0.7619, Val Loss: 0.5000, Val Accuracy: 0.7326
Epoch 9/150, Loss: 0.4983, Accuracy: 0.7706, Val Loss: 0.5179, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4960, Accuracy: 0.7711, Val Loss: 0.4895, Val Accuracy: 0.7708
Epoch 11/150, Loss: 0.4910, Accuracy: 0.7767, Val Loss: 0.4836, Val Accuracy: 0.7847
Epoch 12/150, Loss: 0.4920, Accuracy: 0.7706, Val Loss: 0.4995, Val Accura



Epoch 1/150, Loss: 0.6764, Accuracy: 0.5838, Val Loss: 0.6698, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6646, Accuracy: 0.5908, Val Loss: 0.6536, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6350, Accuracy: 0.6473, Val Loss: 0.5696, Val Accuracy: 0.7188
Epoch 4/150, Loss: 0.5364, Accuracy: 0.7198, Val Loss: 0.5014, Val Accuracy: 0.7812
Epoch 5/150, Loss: 0.5162, Accuracy: 0.7493, Val Loss: 0.5087, Val Accuracy: 0.7431
Epoch 6/150, Loss: 0.5155, Accuracy: 0.7598, Val Loss: 0.5007, Val Accuracy: 0.7639
Epoch 7/150, Loss: 0.5133, Accuracy: 0.7493, Val Loss: 0.5005, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.5030, Accuracy: 0.7663, Val Loss: 0.4962, Val Accuracy: 0.7743
Epoch 9/150, Loss: 0.5035, Accuracy: 0.7663, Val Loss: 0.4994, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.4946, Accuracy: 0.7737, Val Loss: 0.5053, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4937, Accuracy: 0.7758, Val Loss: 0.4861, Val Accuracy: 0.7847
Epoch 12/150, Loss: 0.4922, Accuracy: 0.7750, Val Loss: 0.4964, Val Accura



Epoch 1/150, Loss: 0.6779, Accuracy: 0.5872, Val Loss: 0.6898, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6722, Accuracy: 0.5955, Val Loss: 0.6902, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6744, Accuracy: 0.5955, Val Loss: 0.6829, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6144, Accuracy: 0.6289, Val Loss: 0.5464, Val Accuracy: 0.6771
Epoch 5/150, Loss: 0.5157, Accuracy: 0.7240, Val Loss: 0.4878, Val Accuracy: 0.7569
Epoch 6/150, Loss: 0.5111, Accuracy: 0.7331, Val Loss: 0.4948, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5106, Accuracy: 0.7266, Val Loss: 0.4879, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.5030, Accuracy: 0.7426, Val Loss: 0.4997, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.4953, Accuracy: 0.7439, Val Loss: 0.4871, Val Accuracy: 0.7639
Epoch 10/150, Loss: 0.4948, Accuracy: 0.7522, Val Loss: 0.4894, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.4881, Accuracy: 0.7630, Val Loss: 0.5099, Val Accuracy: 0.7257
Epoch 12/150, Loss: 0.4857, Accuracy: 0.7609, Val Loss: 0.4850, Val Accura



Epoch 1/150, Loss: 0.6770, Accuracy: 0.5911, Val Loss: 0.6986, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6743, Accuracy: 0.5955, Val Loss: 0.6923, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6681, Accuracy: 0.5955, Val Loss: 0.6860, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.5459, Accuracy: 0.6866, Val Loss: 0.5119, Val Accuracy: 0.7292
Epoch 5/150, Loss: 0.5148, Accuracy: 0.7231, Val Loss: 0.4921, Val Accuracy: 0.7535
Epoch 6/150, Loss: 0.5020, Accuracy: 0.7452, Val Loss: 0.5200, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.4966, Accuracy: 0.7478, Val Loss: 0.5074, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.5016, Accuracy: 0.7352, Val Loss: 0.4927, Val Accuracy: 0.7604
Epoch 9/150, Loss: 0.4980, Accuracy: 0.7422, Val Loss: 0.5128, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.4920, Accuracy: 0.7504, Val Loss: 0.4920, Val Accuracy: 0.7431
Epoch 11/150, Loss: 0.4881, Accuracy: 0.7500, Val Loss: 0.4917, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4878, Accuracy: 0.7513, Val Loss: 0.5092, Val Accura



Epoch 1/150, Loss: 0.6775, Accuracy: 0.5933, Val Loss: 0.6891, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6743, Accuracy: 0.5955, Val Loss: 0.6925, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6718, Accuracy: 0.5955, Val Loss: 0.6895, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6160, Accuracy: 0.6350, Val Loss: 0.4860, Val Accuracy: 0.7569
Epoch 5/150, Loss: 0.5231, Accuracy: 0.7214, Val Loss: 0.4952, Val Accuracy: 0.7535
Epoch 6/150, Loss: 0.5043, Accuracy: 0.7387, Val Loss: 0.4999, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.4998, Accuracy: 0.7431, Val Loss: 0.5064, Val Accuracy: 0.7361
Epoch 8/150, Loss: 0.4969, Accuracy: 0.7487, Val Loss: 0.4890, Val Accuracy: 0.7465
Epoch 9/150, Loss: 0.4926, Accuracy: 0.7500, Val Loss: 0.4870, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.4881, Accuracy: 0.7552, Val Loss: 0.4880, Val Accuracy: 0.7431
Epoch 11/150, Loss: 0.4871, Accuracy: 0.7535, Val Loss: 0.4897, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4895, Accuracy: 0.7591, Val Loss: 0.4878, Val Accura



Epoch 1/150, Loss: 0.6767, Accuracy: 0.5933, Val Loss: 0.6927, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6735, Accuracy: 0.5955, Val Loss: 0.7125, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6722, Accuracy: 0.5955, Val Loss: 0.6942, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.5982, Accuracy: 0.6471, Val Loss: 0.5251, Val Accuracy: 0.7049
Epoch 5/150, Loss: 0.5138, Accuracy: 0.7261, Val Loss: 0.5331, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.5075, Accuracy: 0.7305, Val Loss: 0.5125, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.4980, Accuracy: 0.7452, Val Loss: 0.4943, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.4966, Accuracy: 0.7457, Val Loss: 0.5311, Val Accuracy: 0.7083
Epoch 9/150, Loss: 0.4927, Accuracy: 0.7465, Val Loss: 0.4905, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.4953, Accuracy: 0.7478, Val Loss: 0.5109, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.4955, Accuracy: 0.7504, Val Loss: 0.4867, Val Accuracy: 0.7604
Epoch 12/150, Loss: 0.4928, Accuracy: 0.7530, Val Loss: 0.5257, Val Accura



Epoch 1/150, Loss: 0.6785, Accuracy: 0.5955, Val Loss: 0.6894, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6741, Accuracy: 0.5955, Val Loss: 0.6903, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6708, Accuracy: 0.5955, Val Loss: 0.7056, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6441, Accuracy: 0.6124, Val Loss: 0.5835, Val Accuracy: 0.6944
Epoch 5/150, Loss: 0.5267, Accuracy: 0.7114, Val Loss: 0.4942, Val Accuracy: 0.7500
Epoch 6/150, Loss: 0.5154, Accuracy: 0.7188, Val Loss: 0.5018, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.5066, Accuracy: 0.7418, Val Loss: 0.5096, Val Accuracy: 0.7153
Epoch 8/150, Loss: 0.5054, Accuracy: 0.7279, Val Loss: 0.4883, Val Accuracy: 0.7778
Epoch 9/150, Loss: 0.4988, Accuracy: 0.7378, Val Loss: 0.5005, Val Accuracy: 0.7604
Epoch 10/150, Loss: 0.4996, Accuracy: 0.7370, Val Loss: 0.5202, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.4921, Accuracy: 0.7383, Val Loss: 0.4835, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4925, Accuracy: 0.7461, Val Loss: 0.4978, Val Accura



Epoch 1/150, Loss: 0.6865, Accuracy: 0.5595, Val Loss: 0.6907, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6803, Accuracy: 0.5595, Val Loss: 0.6862, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6786, Accuracy: 0.5599, Val Loss: 0.6901, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6609, Accuracy: 0.5717, Val Loss: 0.6552, Val Accuracy: 0.5174
Epoch 5/150, Loss: 0.5936, Accuracy: 0.6603, Val Loss: 0.5904, Val Accuracy: 0.6458
Epoch 6/150, Loss: 0.5720, Accuracy: 0.6785, Val Loss: 0.5766, Val Accuracy: 0.6840
Epoch 7/150, Loss: 0.5797, Accuracy: 0.6681, Val Loss: 0.5863, Val Accuracy: 0.7049
Epoch 8/150, Loss: 0.5593, Accuracy: 0.6937, Val Loss: 0.5697, Val Accuracy: 0.7118
Epoch 9/150, Loss: 0.5611, Accuracy: 0.6907, Val Loss: 0.5753, Val Accuracy: 0.6979
Epoch 10/150, Loss: 0.5591, Accuracy: 0.6933, Val Loss: 0.5782, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.5584, Accuracy: 0.6894, Val Loss: 0.5721, Val Accuracy: 0.6979
Epoch 12/150, Loss: 0.5620, Accuracy: 0.6964, Val Loss: 0.5729, Val Accura



Epoch 1/150, Loss: 0.6856, Accuracy: 0.5604, Val Loss: 0.6891, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6811, Accuracy: 0.5604, Val Loss: 0.6921, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6794, Accuracy: 0.5491, Val Loss: 0.6910, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6356, Accuracy: 0.6199, Val Loss: 0.6214, Val Accuracy: 0.6285
Epoch 5/150, Loss: 0.5714, Accuracy: 0.6690, Val Loss: 0.5832, Val Accuracy: 0.6806
Epoch 6/150, Loss: 0.5753, Accuracy: 0.6742, Val Loss: 0.5893, Val Accuracy: 0.6632
Epoch 7/150, Loss: 0.5695, Accuracy: 0.6707, Val Loss: 0.5929, Val Accuracy: 0.6806
Epoch 8/150, Loss: 0.5629, Accuracy: 0.6937, Val Loss: 0.5919, Val Accuracy: 0.6493
Epoch 9/150, Loss: 0.5635, Accuracy: 0.6894, Val Loss: 0.5770, Val Accuracy: 0.7153
Epoch 10/150, Loss: 0.5560, Accuracy: 0.6937, Val Loss: 0.5737, Val Accuracy: 0.7118
Epoch 11/150, Loss: 0.5532, Accuracy: 0.7059, Val Loss: 0.5747, Val Accuracy: 0.7153
Epoch 12/150, Loss: 0.5510, Accuracy: 0.7107, Val Loss: 0.5790, Val Accura



Epoch 1/150, Loss: 0.6838, Accuracy: 0.5613, Val Loss: 0.6894, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6788, Accuracy: 0.5595, Val Loss: 0.6938, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6746, Accuracy: 0.5669, Val Loss: 0.6633, Val Accuracy: 0.6285
Epoch 4/150, Loss: 0.6079, Accuracy: 0.6412, Val Loss: 0.5837, Val Accuracy: 0.6979
Epoch 5/150, Loss: 0.5834, Accuracy: 0.6651, Val Loss: 0.5941, Val Accuracy: 0.6840
Epoch 6/150, Loss: 0.5765, Accuracy: 0.6868, Val Loss: 0.5875, Val Accuracy: 0.6667
Epoch 7/150, Loss: 0.5666, Accuracy: 0.6885, Val Loss: 0.5762, Val Accuracy: 0.6910
Epoch 8/150, Loss: 0.5693, Accuracy: 0.6816, Val Loss: 0.5890, Val Accuracy: 0.7014
Epoch 9/150, Loss: 0.5645, Accuracy: 0.6903, Val Loss: 0.5814, Val Accuracy: 0.6910
Epoch 10/150, Loss: 0.5605, Accuracy: 0.6816, Val Loss: 0.5765, Val Accuracy: 0.6979
Epoch 11/150, Loss: 0.5514, Accuracy: 0.7142, Val Loss: 0.5705, Val Accuracy: 0.6840
Epoch 12/150, Loss: 0.5539, Accuracy: 0.6959, Val Loss: 0.6078, Val Accura



Epoch 1/150, Loss: 0.6863, Accuracy: 0.5604, Val Loss: 0.7002, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6826, Accuracy: 0.5604, Val Loss: 0.6904, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6784, Accuracy: 0.5595, Val Loss: 0.6945, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6399, Accuracy: 0.6103, Val Loss: 0.6059, Val Accuracy: 0.6354
Epoch 5/150, Loss: 0.5759, Accuracy: 0.6738, Val Loss: 0.5784, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.5707, Accuracy: 0.6851, Val Loss: 0.5921, Val Accuracy: 0.6458
Epoch 7/150, Loss: 0.5797, Accuracy: 0.6685, Val Loss: 0.5788, Val Accuracy: 0.6910
Epoch 8/150, Loss: 0.5611, Accuracy: 0.6933, Val Loss: 0.5895, Val Accuracy: 0.6771
Epoch 9/150, Loss: 0.5623, Accuracy: 0.6851, Val Loss: 0.5902, Val Accuracy: 0.6944
Epoch 10/150, Loss: 0.5633, Accuracy: 0.6964, Val Loss: 0.5834, Val Accuracy: 0.6910
Epoch 11/150, Loss: 0.5521, Accuracy: 0.7103, Val Loss: 0.5804, Val Accuracy: 0.6910
Epoch 12/150, Loss: 0.5502, Accuracy: 0.7059, Val Loss: 0.5805, Val Accura



Epoch 1/150, Loss: 0.6862, Accuracy: 0.5543, Val Loss: 0.6911, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6796, Accuracy: 0.5604, Val Loss: 0.6918, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6760, Accuracy: 0.5647, Val Loss: 0.6665, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6031, Accuracy: 0.6490, Val Loss: 0.5867, Val Accuracy: 0.6771
Epoch 5/150, Loss: 0.5787, Accuracy: 0.6681, Val Loss: 0.5812, Val Accuracy: 0.6944
Epoch 6/150, Loss: 0.5682, Accuracy: 0.6890, Val Loss: 0.5911, Val Accuracy: 0.6771
Epoch 7/150, Loss: 0.5665, Accuracy: 0.6846, Val Loss: 0.5828, Val Accuracy: 0.7014
Epoch 8/150, Loss: 0.5659, Accuracy: 0.6877, Val Loss: 0.5822, Val Accuracy: 0.6771
Epoch 9/150, Loss: 0.5630, Accuracy: 0.6968, Val Loss: 0.6209, Val Accuracy: 0.6319
Epoch 10/150, Loss: 0.5611, Accuracy: 0.6942, Val Loss: 0.5784, Val Accuracy: 0.6806
Epoch 11/150, Loss: 0.5542, Accuracy: 0.7007, Val Loss: 0.5876, Val Accuracy: 0.6806
Epoch 12/150, Loss: 0.5512, Accuracy: 0.7089, Val Loss: 0.5739, Val Accura



Mean Accuracy: 70.87%, (SD=0.10056346214837565)
Mean Loss: 0.5970, (SD=0.1281347566262459)
Mean Precision: 0.7285, (SD=0.09851215293408158)
Mean Recall: 0.7087, (SD=0.10056346214837565)
Mean F1 Score: 0.7026, (SD=0.10293642515266566)


## Architecture 3: TCN+xLSTM

In [11]:
# TCN Block Definition
class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Define the TCN block for each sensor
class TCNBlock(nn.Module):
    def __init__(self, input_size, num_channels):
        super(TCNBlock, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels)

    def forward(self, x):
        return self.tcn(x)

# Define the xLSTM block
class xLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, bidirectional=True):
        super(xLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)

    def forward(self, x):
        output, (hn, cn) = self.lstm(x)
        return output

# Define the fusion and xLSTM model
class FusionxLSTMModel(nn.Module):
    def __init__(self, input_dim, tcn_channels, lstm_hidden_dim, num_layers, num_classes=2):
        super(FusionxLSTMModel, self).__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])

        tcn_output_dim = tcn_channels[-1]  # The output dimension of the last TCN layer
        self.fc1 = nn.Linear(tcn_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, tcn_output_dim * input_dim)

        self.lstm = xLSTM(tcn_output_dim * input_dim, lstm_hidden_dim, num_layers=num_layers, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(lstm_hidden_dim * 2, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined = self.lstm(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last time step output from the xLSTM

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, lstm_hidden_dim, num_layers, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionxLSTMModel(input_dim, tcn_channels, lstm_hidden_dim, num_layers, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Learning rate scheduler

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Step the learning rate scheduler based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Calculate FLOPs
    input_size = (1, input_dim, X_train.shape[2])  # Using batch size 1 for FLOPs calculation
    flops = torchprofile.profile_macs(model, torch.randn(input_size).to(device))
    print(f"FLOPs: {flops}")

    return model, epoch_loss

def evaluate_model(models, X_test, y_test, batch_size=32):
    # Aggregate predictions from all models
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # TCN channels
lstm_hidden_dim = 64
num_layers = 2
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, lstm_hidden_dim, num_layers, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")


Epoch 1/150, Loss: 0.6737, Accuracy: 0.5951, Val Loss: 0.6422, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6648, Accuracy: 0.6108, Val Loss: 0.6328, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6506, Accuracy: 0.6108, Val Loss: 0.5615, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.5172, Accuracy: 0.7398, Val Loss: 0.5025, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.4929, Accuracy: 0.7615, Val Loss: 0.5202, Val Accuracy: 0.7222
Epoch 6/150, Loss: 0.4867, Accuracy: 0.7663, Val Loss: 0.5002, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.4804, Accuracy: 0.7715, Val Loss: 0.5102, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4801, Accuracy: 0.7659, Val Loss: 0.4912, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4743, Accuracy: 0.7737, Val Loss: 0.5238, Val Accuracy: 0.7222
Epoch 10/150, Loss: 0.4738, Accuracy: 0.7780, Val Loss: 0.5212, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.4720, Accuracy: 0.7741, Val Loss: 0.5200, Val Accuracy: 0.7222
Epoch 12/150, Loss: 0.4710, Accuracy: 0.7724, Val Loss: 0.4909, Val Accura



Epoch 1/150, Loss: 0.6694, Accuracy: 0.6108, Val Loss: 0.6451, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6662, Accuracy: 0.6108, Val Loss: 0.6409, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6243, Accuracy: 0.6303, Val Loss: 0.5157, Val Accuracy: 0.7153
Epoch 4/150, Loss: 0.5167, Accuracy: 0.7389, Val Loss: 0.5180, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.4885, Accuracy: 0.7585, Val Loss: 0.5346, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.4860, Accuracy: 0.7650, Val Loss: 0.5046, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4806, Accuracy: 0.7680, Val Loss: 0.5019, Val Accuracy: 0.7222
Epoch 8/150, Loss: 0.4751, Accuracy: 0.7732, Val Loss: 0.5015, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.4711, Accuracy: 0.7750, Val Loss: 0.5267, Val Accuracy: 0.7257
Epoch 10/150, Loss: 0.4744, Accuracy: 0.7724, Val Loss: 0.5111, Val Accuracy: 0.7326
Epoch 11/150, Loss: 0.4679, Accuracy: 0.7811, Val Loss: 0.4995, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4669, Accuracy: 0.7776, Val Loss: 0.5426, Val Accura



Epoch 1/150, Loss: 0.6670, Accuracy: 0.6108, Val Loss: 0.6339, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6654, Accuracy: 0.6108, Val Loss: 0.6318, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6579, Accuracy: 0.6108, Val Loss: 0.6291, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.5496, Accuracy: 0.6942, Val Loss: 0.5023, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.4914, Accuracy: 0.7624, Val Loss: 0.5071, Val Accuracy: 0.7292
Epoch 6/150, Loss: 0.4857, Accuracy: 0.7598, Val Loss: 0.5004, Val Accuracy: 0.7396
Epoch 7/150, Loss: 0.4764, Accuracy: 0.7772, Val Loss: 0.5069, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.4813, Accuracy: 0.7680, Val Loss: 0.5031, Val Accuracy: 0.7257
Epoch 9/150, Loss: 0.4739, Accuracy: 0.7715, Val Loss: 0.5038, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4719, Accuracy: 0.7789, Val Loss: 0.4986, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4666, Accuracy: 0.7828, Val Loss: 0.5060, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4684, Accuracy: 0.7702, Val Loss: 0.5128, Val Accura



Epoch 1/150, Loss: 0.6706, Accuracy: 0.6108, Val Loss: 0.6368, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6651, Accuracy: 0.6108, Val Loss: 0.6470, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6066, Accuracy: 0.6451, Val Loss: 0.5254, Val Accuracy: 0.7118
Epoch 4/150, Loss: 0.4950, Accuracy: 0.7654, Val Loss: 0.5126, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.4867, Accuracy: 0.7598, Val Loss: 0.4981, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4894, Accuracy: 0.7706, Val Loss: 0.5089, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4792, Accuracy: 0.7659, Val Loss: 0.5254, Val Accuracy: 0.7118
Epoch 8/150, Loss: 0.4817, Accuracy: 0.7667, Val Loss: 0.5420, Val Accuracy: 0.7083
Epoch 9/150, Loss: 0.4735, Accuracy: 0.7758, Val Loss: 0.4966, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4746, Accuracy: 0.7750, Val Loss: 0.5101, Val Accuracy: 0.7326
Epoch 11/150, Loss: 0.4659, Accuracy: 0.7772, Val Loss: 0.5070, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4670, Accuracy: 0.7841, Val Loss: 0.5330, Val Accura



Epoch 1/150, Loss: 0.6709, Accuracy: 0.6108, Val Loss: 0.6383, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6654, Accuracy: 0.6108, Val Loss: 0.6514, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6466, Accuracy: 0.6112, Val Loss: 0.5779, Val Accuracy: 0.7465
Epoch 4/150, Loss: 0.5206, Accuracy: 0.7311, Val Loss: 0.5102, Val Accuracy: 0.7292
Epoch 5/150, Loss: 0.4931, Accuracy: 0.7537, Val Loss: 0.4937, Val Accuracy: 0.7431
Epoch 6/150, Loss: 0.4839, Accuracy: 0.7615, Val Loss: 0.4980, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4908, Accuracy: 0.7576, Val Loss: 0.5040, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.4781, Accuracy: 0.7698, Val Loss: 0.5130, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4778, Accuracy: 0.7676, Val Loss: 0.5024, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4768, Accuracy: 0.7702, Val Loss: 0.5262, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.4710, Accuracy: 0.7772, Val Loss: 0.5136, Val Accuracy: 0.7292
Epoch 12/150, Loss: 0.4664, Accuracy: 0.7793, Val Loss: 0.4933, Val Accura



Epoch 1/150, Loss: 0.6773, Accuracy: 0.5912, Val Loss: 0.6759, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6668, Accuracy: 0.5908, Val Loss: 0.6700, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6545, Accuracy: 0.6051, Val Loss: 0.6394, Val Accuracy: 0.6319
Epoch 4/150, Loss: 0.5647, Accuracy: 0.7050, Val Loss: 0.5214, Val Accuracy: 0.7396
Epoch 5/150, Loss: 0.5281, Accuracy: 0.7337, Val Loss: 0.4952, Val Accuracy: 0.7674
Epoch 6/150, Loss: 0.5178, Accuracy: 0.7441, Val Loss: 0.4991, Val Accuracy: 0.7639
Epoch 7/150, Loss: 0.5136, Accuracy: 0.7515, Val Loss: 0.5071, Val Accuracy: 0.7708
Epoch 8/150, Loss: 0.5146, Accuracy: 0.7437, Val Loss: 0.5055, Val Accuracy: 0.7500
Epoch 9/150, Loss: 0.5057, Accuracy: 0.7559, Val Loss: 0.5085, Val Accuracy: 0.7500
Epoch 10/150, Loss: 0.5021, Accuracy: 0.7598, Val Loss: 0.4963, Val Accuracy: 0.7812
Epoch 11/150, Loss: 0.4976, Accuracy: 0.7715, Val Loss: 0.4918, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.5016, Accuracy: 0.7615, Val Loss: 0.4876, Val Accura



Epoch 1/150, Loss: 0.6770, Accuracy: 0.5912, Val Loss: 0.6773, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6675, Accuracy: 0.5908, Val Loss: 0.6710, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6622, Accuracy: 0.5921, Val Loss: 0.6638, Val Accuracy: 0.6701
Epoch 4/150, Loss: 0.6038, Accuracy: 0.6772, Val Loss: 0.5634, Val Accuracy: 0.7118
Epoch 5/150, Loss: 0.5311, Accuracy: 0.7381, Val Loss: 0.5298, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.5217, Accuracy: 0.7420, Val Loss: 0.4966, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5135, Accuracy: 0.7537, Val Loss: 0.5156, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.5048, Accuracy: 0.7637, Val Loss: 0.4962, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.5016, Accuracy: 0.7654, Val Loss: 0.4925, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.5125, Accuracy: 0.7493, Val Loss: 0.4935, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.5001, Accuracy: 0.7606, Val Loss: 0.5024, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4894, Accuracy: 0.7737, Val Loss: 0.5000, Val Accura



Epoch 1/150, Loss: 0.6757, Accuracy: 0.5795, Val Loss: 0.6741, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6651, Accuracy: 0.5908, Val Loss: 0.6720, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6482, Accuracy: 0.6229, Val Loss: 0.6450, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.5658, Accuracy: 0.7089, Val Loss: 0.5225, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.5238, Accuracy: 0.7381, Val Loss: 0.4979, Val Accuracy: 0.7743
Epoch 6/150, Loss: 0.5142, Accuracy: 0.7498, Val Loss: 0.4899, Val Accuracy: 0.7847
Epoch 7/150, Loss: 0.5078, Accuracy: 0.7507, Val Loss: 0.5066, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.5143, Accuracy: 0.7476, Val Loss: 0.5056, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.5003, Accuracy: 0.7598, Val Loss: 0.5060, Val Accuracy: 0.7535
Epoch 10/150, Loss: 0.4960, Accuracy: 0.7680, Val Loss: 0.5003, Val Accuracy: 0.7361
Epoch 11/150, Loss: 0.5028, Accuracy: 0.7619, Val Loss: 0.4956, Val Accuracy: 0.7812
Epoch 12/150, Loss: 0.4940, Accuracy: 0.7706, Val Loss: 0.4966, Val Accura



Epoch 1/150, Loss: 0.6729, Accuracy: 0.5908, Val Loss: 0.6773, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6666, Accuracy: 0.5908, Val Loss: 0.6710, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6411, Accuracy: 0.6199, Val Loss: 0.6919, Val Accuracy: 0.5417
Epoch 4/150, Loss: 0.5482, Accuracy: 0.7124, Val Loss: 0.4863, Val Accuracy: 0.7812
Epoch 5/150, Loss: 0.5190, Accuracy: 0.7493, Val Loss: 0.4979, Val Accuracy: 0.7743
Epoch 6/150, Loss: 0.5093, Accuracy: 0.7619, Val Loss: 0.4995, Val Accuracy: 0.7674
Epoch 7/150, Loss: 0.5141, Accuracy: 0.7446, Val Loss: 0.5044, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.5143, Accuracy: 0.7507, Val Loss: 0.4914, Val Accuracy: 0.7778
Epoch 9/150, Loss: 0.5046, Accuracy: 0.7611, Val Loss: 0.5021, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.5080, Accuracy: 0.7567, Val Loss: 0.4991, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.4947, Accuracy: 0.7758, Val Loss: 0.4868, Val Accuracy: 0.7708
Epoch 12/150, Loss: 0.4879, Accuracy: 0.7780, Val Loss: 0.4884, Val Accura



Epoch 1/150, Loss: 0.6783, Accuracy: 0.5830, Val Loss: 0.6742, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6651, Accuracy: 0.5908, Val Loss: 0.6686, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6500, Accuracy: 0.6134, Val Loss: 0.6232, Val Accuracy: 0.7153
Epoch 4/150, Loss: 0.5533, Accuracy: 0.7181, Val Loss: 0.5255, Val Accuracy: 0.7396
Epoch 5/150, Loss: 0.5248, Accuracy: 0.7350, Val Loss: 0.5049, Val Accuracy: 0.7604
Epoch 6/150, Loss: 0.5132, Accuracy: 0.7563, Val Loss: 0.4951, Val Accuracy: 0.7778
Epoch 7/150, Loss: 0.5128, Accuracy: 0.7507, Val Loss: 0.4965, Val Accuracy: 0.7778
Epoch 8/150, Loss: 0.5112, Accuracy: 0.7576, Val Loss: 0.5055, Val Accuracy: 0.7500
Epoch 9/150, Loss: 0.5023, Accuracy: 0.7659, Val Loss: 0.5313, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4933, Accuracy: 0.7750, Val Loss: 0.5029, Val Accuracy: 0.7569
Epoch 11/150, Loss: 0.4994, Accuracy: 0.7546, Val Loss: 0.5063, Val Accuracy: 0.7500
Epoch 12/150, Loss: 0.4920, Accuracy: 0.7706, Val Loss: 0.4941, Val Accura



Epoch 1/150, Loss: 0.6783, Accuracy: 0.5881, Val Loss: 0.6891, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6739, Accuracy: 0.5955, Val Loss: 0.6903, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6717, Accuracy: 0.5955, Val Loss: 0.7085, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6436, Accuracy: 0.6115, Val Loss: 0.5863, Val Accuracy: 0.6146
Epoch 5/150, Loss: 0.5174, Accuracy: 0.7240, Val Loss: 0.4973, Val Accuracy: 0.7535
Epoch 6/150, Loss: 0.5008, Accuracy: 0.7418, Val Loss: 0.5193, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.5096, Accuracy: 0.7214, Val Loss: 0.5091, Val Accuracy: 0.7361
Epoch 8/150, Loss: 0.4964, Accuracy: 0.7487, Val Loss: 0.4994, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.4991, Accuracy: 0.7504, Val Loss: 0.4922, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.4903, Accuracy: 0.7461, Val Loss: 0.4879, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4878, Accuracy: 0.7513, Val Loss: 0.4947, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.4967, Accuracy: 0.7405, Val Loss: 0.5005, Val Accura



Epoch 1/150, Loss: 0.6773, Accuracy: 0.5951, Val Loss: 0.6883, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6743, Accuracy: 0.5955, Val Loss: 0.6972, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6720, Accuracy: 0.5955, Val Loss: 0.6931, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6628, Accuracy: 0.5955, Val Loss: 0.6560, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5510, Accuracy: 0.6797, Val Loss: 0.5050, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5126, Accuracy: 0.7396, Val Loss: 0.4891, Val Accuracy: 0.7604
Epoch 7/150, Loss: 0.5024, Accuracy: 0.7391, Val Loss: 0.4960, Val Accuracy: 0.7604
Epoch 8/150, Loss: 0.5016, Accuracy: 0.7418, Val Loss: 0.5032, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.5050, Accuracy: 0.7400, Val Loss: 0.4968, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.4920, Accuracy: 0.7461, Val Loss: 0.4961, Val Accuracy: 0.7569
Epoch 11/150, Loss: 0.4850, Accuracy: 0.7552, Val Loss: 0.4885, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4898, Accuracy: 0.7595, Val Loss: 0.4979, Val Accura



Epoch 1/150, Loss: 0.6780, Accuracy: 0.5855, Val Loss: 0.6993, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6749, Accuracy: 0.5955, Val Loss: 0.6905, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6711, Accuracy: 0.5955, Val Loss: 0.7037, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6546, Accuracy: 0.6003, Val Loss: 0.6425, Val Accuracy: 0.6979
Epoch 5/150, Loss: 0.5522, Accuracy: 0.6823, Val Loss: 0.4951, Val Accuracy: 0.7569
Epoch 6/150, Loss: 0.5157, Accuracy: 0.7279, Val Loss: 0.5169, Val Accuracy: 0.7118
Epoch 7/150, Loss: 0.5145, Accuracy: 0.7270, Val Loss: 0.4984, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.5009, Accuracy: 0.7448, Val Loss: 0.4988, Val Accuracy: 0.7396
Epoch 9/150, Loss: 0.5002, Accuracy: 0.7435, Val Loss: 0.4841, Val Accuracy: 0.7500
Epoch 10/150, Loss: 0.5097, Accuracy: 0.7235, Val Loss: 0.4927, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.4986, Accuracy: 0.7422, Val Loss: 0.4887, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4918, Accuracy: 0.7457, Val Loss: 0.5045, Val Accura



Epoch 1/150, Loss: 0.6796, Accuracy: 0.5846, Val Loss: 0.6921, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6732, Accuracy: 0.5955, Val Loss: 0.6923, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6696, Accuracy: 0.5955, Val Loss: 0.6905, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6167, Accuracy: 0.6198, Val Loss: 0.5565, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.5249, Accuracy: 0.7101, Val Loss: 0.5184, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.5149, Accuracy: 0.7313, Val Loss: 0.5060, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.5070, Accuracy: 0.7374, Val Loss: 0.4920, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.4968, Accuracy: 0.7487, Val Loss: 0.4912, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.5008, Accuracy: 0.7461, Val Loss: 0.4959, Val Accuracy: 0.7500
Epoch 10/150, Loss: 0.4876, Accuracy: 0.7578, Val Loss: 0.4849, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.4912, Accuracy: 0.7422, Val Loss: 0.4925, Val Accuracy: 0.7639
Epoch 12/150, Loss: 0.4912, Accuracy: 0.7465, Val Loss: 0.4913, Val Accura



Epoch 1/150, Loss: 0.6780, Accuracy: 0.5959, Val Loss: 0.7024, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6748, Accuracy: 0.5955, Val Loss: 0.6915, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6722, Accuracy: 0.5955, Val Loss: 0.6888, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6630, Accuracy: 0.5955, Val Loss: 0.6635, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5652, Accuracy: 0.6476, Val Loss: 0.5237, Val Accuracy: 0.7292
Epoch 6/150, Loss: 0.5163, Accuracy: 0.7283, Val Loss: 0.5008, Val Accuracy: 0.7465
Epoch 7/150, Loss: 0.5027, Accuracy: 0.7339, Val Loss: 0.5107, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.5025, Accuracy: 0.7339, Val Loss: 0.4984, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.4941, Accuracy: 0.7452, Val Loss: 0.5025, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.4936, Accuracy: 0.7400, Val Loss: 0.4999, Val Accuracy: 0.7326
Epoch 11/150, Loss: 0.4935, Accuracy: 0.7478, Val Loss: 0.4953, Val Accuracy: 0.7431
Epoch 12/150, Loss: 0.4882, Accuracy: 0.7530, Val Loss: 0.4862, Val Accura



Epoch 1/150, Loss: 0.6850, Accuracy: 0.5556, Val Loss: 0.6966, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6825, Accuracy: 0.5604, Val Loss: 0.6885, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6812, Accuracy: 0.5608, Val Loss: 0.6889, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6488, Accuracy: 0.5886, Val Loss: 0.5965, Val Accuracy: 0.6597
Epoch 5/150, Loss: 0.5835, Accuracy: 0.6681, Val Loss: 0.5898, Val Accuracy: 0.6840
Epoch 6/150, Loss: 0.5738, Accuracy: 0.6746, Val Loss: 0.5839, Val Accuracy: 0.6771
Epoch 7/150, Loss: 0.5715, Accuracy: 0.6855, Val Loss: 0.5911, Val Accuracy: 0.6632
Epoch 8/150, Loss: 0.5643, Accuracy: 0.6972, Val Loss: 0.5911, Val Accuracy: 0.6632
Epoch 9/150, Loss: 0.5666, Accuracy: 0.6842, Val Loss: 0.5877, Val Accuracy: 0.6875
Epoch 10/150, Loss: 0.5581, Accuracy: 0.6946, Val Loss: 0.5733, Val Accuracy: 0.7014
Epoch 11/150, Loss: 0.5485, Accuracy: 0.7129, Val Loss: 0.6328, Val Accuracy: 0.6528
Epoch 12/150, Loss: 0.5580, Accuracy: 0.7007, Val Loss: 0.5828, Val Accura



Epoch 1/150, Loss: 0.6867, Accuracy: 0.5604, Val Loss: 0.6964, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6824, Accuracy: 0.5604, Val Loss: 0.6910, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6772, Accuracy: 0.5647, Val Loss: 0.6705, Val Accuracy: 0.6042
Epoch 4/150, Loss: 0.6109, Accuracy: 0.6373, Val Loss: 0.5903, Val Accuracy: 0.6944
Epoch 5/150, Loss: 0.5720, Accuracy: 0.6859, Val Loss: 0.5813, Val Accuracy: 0.7014
Epoch 6/150, Loss: 0.5860, Accuracy: 0.6685, Val Loss: 0.5890, Val Accuracy: 0.6979
Epoch 7/150, Loss: 0.5671, Accuracy: 0.6911, Val Loss: 0.5897, Val Accuracy: 0.6771
Epoch 8/150, Loss: 0.5644, Accuracy: 0.6907, Val Loss: 0.5798, Val Accuracy: 0.6979
Epoch 9/150, Loss: 0.5609, Accuracy: 0.6890, Val Loss: 0.6271, Val Accuracy: 0.6562
Epoch 10/150, Loss: 0.5608, Accuracy: 0.6985, Val Loss: 0.5719, Val Accuracy: 0.6944
Epoch 11/150, Loss: 0.5596, Accuracy: 0.6959, Val Loss: 0.5768, Val Accuracy: 0.6979
Epoch 12/150, Loss: 0.5578, Accuracy: 0.7011, Val Loss: 0.5894, Val Accura



Epoch 1/150, Loss: 0.6897, Accuracy: 0.5413, Val Loss: 0.6924, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6832, Accuracy: 0.5604, Val Loss: 0.6924, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6719, Accuracy: 0.5604, Val Loss: 0.6667, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.5971, Accuracy: 0.6507, Val Loss: 0.5938, Val Accuracy: 0.6771
Epoch 5/150, Loss: 0.5744, Accuracy: 0.6772, Val Loss: 0.6122, Val Accuracy: 0.6493
Epoch 6/150, Loss: 0.5744, Accuracy: 0.6768, Val Loss: 0.5939, Val Accuracy: 0.6944
Epoch 7/150, Loss: 0.5650, Accuracy: 0.6994, Val Loss: 0.5847, Val Accuracy: 0.6910
Epoch 8/150, Loss: 0.5645, Accuracy: 0.6924, Val Loss: 0.5868, Val Accuracy: 0.6875
Epoch 9/150, Loss: 0.5655, Accuracy: 0.6881, Val Loss: 0.5814, Val Accuracy: 0.6771
Epoch 10/150, Loss: 0.5634, Accuracy: 0.6972, Val Loss: 0.6091, Val Accuracy: 0.6424
Epoch 11/150, Loss: 0.5601, Accuracy: 0.6942, Val Loss: 0.5890, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.5635, Accuracy: 0.6950, Val Loss: 0.5807, Val Accura



Epoch 1/150, Loss: 0.6870, Accuracy: 0.5604, Val Loss: 0.6947, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6822, Accuracy: 0.5604, Val Loss: 0.6897, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6791, Accuracy: 0.5604, Val Loss: 0.7038, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6498, Accuracy: 0.5743, Val Loss: 0.5995, Val Accuracy: 0.6562
Epoch 5/150, Loss: 0.5992, Accuracy: 0.6394, Val Loss: 0.6064, Val Accuracy: 0.6771
Epoch 6/150, Loss: 0.5815, Accuracy: 0.6742, Val Loss: 0.5920, Val Accuracy: 0.6806
Epoch 7/150, Loss: 0.5771, Accuracy: 0.6638, Val Loss: 0.5853, Val Accuracy: 0.6875
Epoch 8/150, Loss: 0.5708, Accuracy: 0.6859, Val Loss: 0.5810, Val Accuracy: 0.6910
Epoch 9/150, Loss: 0.5651, Accuracy: 0.6851, Val Loss: 0.5818, Val Accuracy: 0.7083
Epoch 10/150, Loss: 0.5601, Accuracy: 0.6859, Val Loss: 0.5794, Val Accuracy: 0.6979
Epoch 11/150, Loss: 0.5602, Accuracy: 0.7046, Val Loss: 0.5811, Val Accuracy: 0.7014
Epoch 12/150, Loss: 0.5529, Accuracy: 0.7011, Val Loss: 0.5885, Val Accura



Epoch 1/150, Loss: 0.6886, Accuracy: 0.5604, Val Loss: 0.6922, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6839, Accuracy: 0.5604, Val Loss: 0.6914, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6817, Accuracy: 0.5604, Val Loss: 0.6894, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6689, Accuracy: 0.5669, Val Loss: 0.6220, Val Accuracy: 0.6424
Epoch 5/150, Loss: 0.5882, Accuracy: 0.6646, Val Loss: 0.5864, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.5767, Accuracy: 0.6725, Val Loss: 0.6089, Val Accuracy: 0.6354
Epoch 7/150, Loss: 0.5736, Accuracy: 0.6725, Val Loss: 0.5878, Val Accuracy: 0.6840
Epoch 8/150, Loss: 0.5675, Accuracy: 0.6959, Val Loss: 0.5814, Val Accuracy: 0.6875
Epoch 9/150, Loss: 0.5675, Accuracy: 0.6907, Val Loss: 0.5817, Val Accuracy: 0.6875
Epoch 10/150, Loss: 0.5630, Accuracy: 0.6872, Val Loss: 0.5805, Val Accuracy: 0.6875
Epoch 11/150, Loss: 0.5679, Accuracy: 0.6777, Val Loss: 0.5837, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.5613, Accuracy: 0.6968, Val Loss: 0.5795, Val Accura



Mean Accuracy: 71.18%, (SD=0.09529690481695455)
Mean Loss: 0.5996, (SD=0.1354647814016106)
Mean Precision: 0.7375, (SD=0.08566806659345935)
Mean Recall: 0.7118, (SD=0.09529690481695455)
Mean F1 Score: 0.7003, (SD=0.10329263445760903)


## Architecture 4: TCN+MHA

In [16]:
# Temporal Convolutional Network (TCN) Block
class TCNBlock(nn.Module):
    def __init__(self, input_size, num_channels):
        super(TCNBlock, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels)

    def forward(self, x):
        return self.tcn(x)

# Multihead Attention Fusion Block
class MultiheadAttentionFusion(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.3):
        super(MultiheadAttentionFusion, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.permute(1, 0, 2)  # Prepare dimensions for MultiheadAttention
        attn_output, _ = self.multihead_attn(x, x, x)
        attn_output = self.dropout(attn_output)
        attn_output = x + attn_output  # Add residual connection
        attn_output = self.layer_norm(attn_output)
        return attn_output.permute(1, 0, 2)

# Fusion Model with Multihead Attention
class FusionMultiheadAttentionModel(nn.Module):
    def __init__(self, input_dim, tcn_channels, embed_dim, num_heads, num_classes=2):
        super(FusionMultiheadAttentionModel, self).__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])
        tcn_output_dim = tcn_channels[-1]
        self.fc1 = nn.Linear(tcn_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, embed_dim)
        self.mha = MultiheadAttentionFusion(embed_dim, num_heads)
        self.fc3 = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)
        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined = self.mha(combined)
        combined = combined.squeeze(1)
        output = self.fc3(combined)
        return output

# Training function
def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, embed_dim, num_heads, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionMultiheadAttentionModel(input_dim, tcn_channels, embed_dim, num_heads, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5, verbose=True)
    writer = SummaryWriter(f'runs/fold_{fold_index}')

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Adjust the learning rate based on the validation loss

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()
    return model

# This function aggregates model predictions and calculates accuracy, precision, recall, and F1 score
def evaluate_model(models, X_test, y_test, batch_temperature=32):
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Group K-Fold evaluation can be performed
NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # TCN channels
gru_hidden_dim = 64
num_layers = 2
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, gru_hidden_dim, num_layers, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")



Epoch 1/150, Loss: 0.6767, Accuracy: 0.5943, Val Loss: 0.6359, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6515, Accuracy: 0.6173, Val Loss: 0.6057, Val Accuracy: 0.7222
Epoch 3/150, Loss: 0.5733, Accuracy: 0.6859, Val Loss: 0.5330, Val Accuracy: 0.6840
Epoch 4/150, Loss: 0.5103, Accuracy: 0.7324, Val Loss: 0.5675, Val Accuracy: 0.6910
Epoch 5/150, Loss: 0.5059, Accuracy: 0.7324, Val Loss: 0.5919, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.4925, Accuracy: 0.7589, Val Loss: 0.5076, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.4697, Accuracy: 0.7676, Val Loss: 0.5136, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.4694, Accuracy: 0.7772, Val Loss: 0.5074, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4687, Accuracy: 0.7824, Val Loss: 0.5036, Val Accuracy: 0.7188
Epoch 10/150, Loss: 0.4734, Accuracy: 0.7663, Val Loss: 0.5031, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4717, Accuracy: 0.7698, Val Loss: 0.4914, Val Accuracy: 0.7292
Epoch 12/150, Loss: 0.4694, Accuracy: 0.7667, Val Loss: 0.5375, Val Accura



Epoch 1/150, Loss: 0.6736, Accuracy: 0.5969, Val Loss: 0.6341, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6571, Accuracy: 0.6156, Val Loss: 0.6633, Val Accuracy: 0.6111
Epoch 3/150, Loss: 0.5561, Accuracy: 0.6994, Val Loss: 0.5190, Val Accuracy: 0.7257
Epoch 4/150, Loss: 0.4939, Accuracy: 0.7515, Val Loss: 0.5044, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.4931, Accuracy: 0.7533, Val Loss: 0.5681, Val Accuracy: 0.7049
Epoch 6/150, Loss: 0.4775, Accuracy: 0.7663, Val Loss: 0.5076, Val Accuracy: 0.7431
Epoch 7/150, Loss: 0.4740, Accuracy: 0.7698, Val Loss: 0.5153, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.4824, Accuracy: 0.7602, Val Loss: 0.5255, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4704, Accuracy: 0.7750, Val Loss: 0.4960, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4693, Accuracy: 0.7832, Val Loss: 0.4886, Val Accuracy: 0.7361
Epoch 11/150, Loss: 0.4725, Accuracy: 0.7676, Val Loss: 0.5221, Val Accuracy: 0.7292
Epoch 12/150, Loss: 0.4626, Accuracy: 0.7858, Val Loss: 0.5066, Val Accura



Epoch 1/150, Loss: 0.6880, Accuracy: 0.5808, Val Loss: 0.6303, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6618, Accuracy: 0.6108, Val Loss: 0.6252, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6183, Accuracy: 0.6633, Val Loss: 0.5891, Val Accuracy: 0.6701
Epoch 4/150, Loss: 0.5176, Accuracy: 0.7389, Val Loss: 0.5271, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.4910, Accuracy: 0.7554, Val Loss: 0.5120, Val Accuracy: 0.7396
Epoch 6/150, Loss: 0.4769, Accuracy: 0.7663, Val Loss: 0.5084, Val Accuracy: 0.7396
Epoch 7/150, Loss: 0.4754, Accuracy: 0.7732, Val Loss: 0.5141, Val Accuracy: 0.7153
Epoch 8/150, Loss: 0.4725, Accuracy: 0.7728, Val Loss: 0.5178, Val Accuracy: 0.7292
Epoch 9/150, Loss: 0.4671, Accuracy: 0.7767, Val Loss: 0.5100, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4676, Accuracy: 0.7837, Val Loss: 0.5149, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4652, Accuracy: 0.7711, Val Loss: 0.4967, Val Accuracy: 0.7188
Epoch 12/150, Loss: 0.4700, Accuracy: 0.7780, Val Loss: 0.5193, Val Accura



Epoch 1/150, Loss: 0.6666, Accuracy: 0.5969, Val Loss: 0.6527, Val Accuracy: 0.6875
Epoch 2/150, Loss: 0.6661, Accuracy: 0.5964, Val Loss: 0.6059, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.5676, Accuracy: 0.6868, Val Loss: 0.5882, Val Accuracy: 0.6806
Epoch 4/150, Loss: 0.4885, Accuracy: 0.7541, Val Loss: 0.4998, Val Accuracy: 0.7361
Epoch 5/150, Loss: 0.4936, Accuracy: 0.7524, Val Loss: 0.5541, Val Accuracy: 0.7083
Epoch 6/150, Loss: 0.4749, Accuracy: 0.7663, Val Loss: 0.5834, Val Accuracy: 0.7014
Epoch 7/150, Loss: 0.4727, Accuracy: 0.7750, Val Loss: 0.5055, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4703, Accuracy: 0.7785, Val Loss: 0.4902, Val Accuracy: 0.7292
Epoch 9/150, Loss: 0.4756, Accuracy: 0.7615, Val Loss: 0.5097, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4741, Accuracy: 0.7719, Val Loss: 0.5305, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4757, Accuracy: 0.7624, Val Loss: 0.4932, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.4655, Accuracy: 0.7724, Val Loss: 0.4945, Val Accura



Epoch 1/150, Loss: 0.6745, Accuracy: 0.6025, Val Loss: 0.6324, Val Accuracy: 0.6771
Epoch 2/150, Loss: 0.6542, Accuracy: 0.6247, Val Loss: 0.6043, Val Accuracy: 0.6806
Epoch 3/150, Loss: 0.5785, Accuracy: 0.6894, Val Loss: 0.5342, Val Accuracy: 0.7118
Epoch 4/150, Loss: 0.5138, Accuracy: 0.7281, Val Loss: 0.5248, Val Accuracy: 0.7257
Epoch 5/150, Loss: 0.4900, Accuracy: 0.7502, Val Loss: 0.5250, Val Accuracy: 0.7257
Epoch 6/150, Loss: 0.4833, Accuracy: 0.7589, Val Loss: 0.5387, Val Accuracy: 0.7118
Epoch 7/150, Loss: 0.4806, Accuracy: 0.7611, Val Loss: 0.5039, Val Accuracy: 0.7431
Epoch 8/150, Loss: 0.4708, Accuracy: 0.7763, Val Loss: 0.5143, Val Accuracy: 0.7257
Epoch 9/150, Loss: 0.4708, Accuracy: 0.7719, Val Loss: 0.5277, Val Accuracy: 0.7222
Epoch 10/150, Loss: 0.4590, Accuracy: 0.7750, Val Loss: 0.5151, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.4623, Accuracy: 0.7711, Val Loss: 0.5501, Val Accuracy: 0.7153
Epoch 12/150, Loss: 0.4574, Accuracy: 0.7772, Val Loss: 0.4932, Val Accura



Epoch 1/150, Loss: 0.6835, Accuracy: 0.5686, Val Loss: 0.6695, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6565, Accuracy: 0.5995, Val Loss: 0.6473, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6077, Accuracy: 0.6520, Val Loss: 0.5974, Val Accuracy: 0.6285
Epoch 4/150, Loss: 0.5751, Accuracy: 0.6838, Val Loss: 0.5251, Val Accuracy: 0.7604
Epoch 5/150, Loss: 0.5473, Accuracy: 0.7142, Val Loss: 0.5212, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5345, Accuracy: 0.7333, Val Loss: 0.5324, Val Accuracy: 0.7222
Epoch 7/150, Loss: 0.5213, Accuracy: 0.7359, Val Loss: 0.5310, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.5371, Accuracy: 0.7337, Val Loss: 0.4975, Val Accuracy: 0.7604
Epoch 9/150, Loss: 0.5118, Accuracy: 0.7589, Val Loss: 0.4961, Val Accuracy: 0.7708
Epoch 10/150, Loss: 0.5110, Accuracy: 0.7498, Val Loss: 0.5057, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.5102, Accuracy: 0.7515, Val Loss: 0.4978, Val Accuracy: 0.7708
Epoch 12/150, Loss: 0.5065, Accuracy: 0.7606, Val Loss: 0.4982, Val Accura



Epoch 1/150, Loss: 0.6779, Accuracy: 0.5695, Val Loss: 0.6689, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6557, Accuracy: 0.6160, Val Loss: 0.6588, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6156, Accuracy: 0.6460, Val Loss: 0.5733, Val Accuracy: 0.6840
Epoch 4/150, Loss: 0.5554, Accuracy: 0.7163, Val Loss: 0.5580, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.5416, Accuracy: 0.7194, Val Loss: 0.5528, Val Accuracy: 0.7257
Epoch 6/150, Loss: 0.5234, Accuracy: 0.7381, Val Loss: 0.5122, Val Accuracy: 0.7188
Epoch 7/150, Loss: 0.5139, Accuracy: 0.7507, Val Loss: 0.5006, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.5077, Accuracy: 0.7619, Val Loss: 0.5007, Val Accuracy: 0.7604
Epoch 9/150, Loss: 0.5068, Accuracy: 0.7572, Val Loss: 0.5082, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.5023, Accuracy: 0.7598, Val Loss: 0.4981, Val Accuracy: 0.7743
Epoch 11/150, Loss: 0.4999, Accuracy: 0.7698, Val Loss: 0.5070, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.5054, Accuracy: 0.7563, Val Loss: 0.5288, Val Accura



Epoch 1/150, Loss: 0.6748, Accuracy: 0.5821, Val Loss: 0.6485, Val Accuracy: 0.6007
Epoch 2/150, Loss: 0.6506, Accuracy: 0.6234, Val Loss: 0.6338, Val Accuracy: 0.6319
Epoch 3/150, Loss: 0.6053, Accuracy: 0.6625, Val Loss: 0.5662, Val Accuracy: 0.7083
Epoch 4/150, Loss: 0.5718, Accuracy: 0.6929, Val Loss: 0.5260, Val Accuracy: 0.7431
Epoch 5/150, Loss: 0.5405, Accuracy: 0.7324, Val Loss: 0.4985, Val Accuracy: 0.7639
Epoch 6/150, Loss: 0.5258, Accuracy: 0.7424, Val Loss: 0.5025, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5165, Accuracy: 0.7398, Val Loss: 0.5057, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.5105, Accuracy: 0.7454, Val Loss: 0.4988, Val Accuracy: 0.7396
Epoch 9/150, Loss: 0.5221, Accuracy: 0.7368, Val Loss: 0.4945, Val Accuracy: 0.7708
Epoch 10/150, Loss: 0.5099, Accuracy: 0.7511, Val Loss: 0.4962, Val Accuracy: 0.7569
Epoch 11/150, Loss: 0.5026, Accuracy: 0.7611, Val Loss: 0.5007, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.4948, Accuracy: 0.7680, Val Loss: 0.4895, Val Accura



Epoch 1/150, Loss: 0.6875, Accuracy: 0.5691, Val Loss: 0.6871, Val Accuracy: 0.5347
Epoch 2/150, Loss: 0.6729, Accuracy: 0.5752, Val Loss: 0.6623, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6628, Accuracy: 0.6034, Val Loss: 0.6690, Val Accuracy: 0.5486
Epoch 4/150, Loss: 0.6311, Accuracy: 0.6394, Val Loss: 0.5986, Val Accuracy: 0.7292
Epoch 5/150, Loss: 0.5675, Accuracy: 0.7016, Val Loss: 0.5751, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.5571, Accuracy: 0.7003, Val Loss: 0.5368, Val Accuracy: 0.7153
Epoch 7/150, Loss: 0.5179, Accuracy: 0.7398, Val Loss: 0.5132, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.5180, Accuracy: 0.7368, Val Loss: 0.4957, Val Accuracy: 0.7778
Epoch 9/150, Loss: 0.5179, Accuracy: 0.7493, Val Loss: 0.4939, Val Accuracy: 0.7396
Epoch 10/150, Loss: 0.5113, Accuracy: 0.7502, Val Loss: 0.4952, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.5069, Accuracy: 0.7554, Val Loss: 0.5114, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.5032, Accuracy: 0.7619, Val Loss: 0.4919, Val Accura



Epoch 1/150, Loss: 0.7050, Accuracy: 0.5669, Val Loss: 0.6762, Val Accuracy: 0.6111
Epoch 2/150, Loss: 0.6714, Accuracy: 0.5817, Val Loss: 0.6926, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6692, Accuracy: 0.5851, Val Loss: 0.6988, Val Accuracy: 0.5035
Epoch 4/150, Loss: 0.6442, Accuracy: 0.6255, Val Loss: 0.6047, Val Accuracy: 0.7049
Epoch 5/150, Loss: 0.5797, Accuracy: 0.6820, Val Loss: 0.5424, Val Accuracy: 0.7153
Epoch 6/150, Loss: 0.5408, Accuracy: 0.7341, Val Loss: 0.5027, Val Accuracy: 0.7778
Epoch 7/150, Loss: 0.5179, Accuracy: 0.7502, Val Loss: 0.4995, Val Accuracy: 0.7604
Epoch 8/150, Loss: 0.5279, Accuracy: 0.7328, Val Loss: 0.4938, Val Accuracy: 0.7778
Epoch 9/150, Loss: 0.5092, Accuracy: 0.7554, Val Loss: 0.4934, Val Accuracy: 0.7708
Epoch 10/150, Loss: 0.5125, Accuracy: 0.7498, Val Loss: 0.4876, Val Accuracy: 0.7812
Epoch 11/150, Loss: 0.5084, Accuracy: 0.7541, Val Loss: 0.4958, Val Accuracy: 0.7604
Epoch 12/150, Loss: 0.4923, Accuracy: 0.7754, Val Loss: 0.4850, Val Accura



Epoch 1/150, Loss: 0.6833, Accuracy: 0.5820, Val Loss: 0.7378, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6815, Accuracy: 0.5747, Val Loss: 0.6865, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6672, Accuracy: 0.5959, Val Loss: 0.6742, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6269, Accuracy: 0.6315, Val Loss: 0.5457, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.5485, Accuracy: 0.6923, Val Loss: 0.5840, Val Accuracy: 0.5972
Epoch 6/150, Loss: 0.5281, Accuracy: 0.7096, Val Loss: 0.5172, Val Accuracy: 0.7257
Epoch 7/150, Loss: 0.5197, Accuracy: 0.7170, Val Loss: 0.5144, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.5058, Accuracy: 0.7261, Val Loss: 0.4967, Val Accuracy: 0.7188
Epoch 9/150, Loss: 0.4950, Accuracy: 0.7300, Val Loss: 0.4808, Val Accuracy: 0.7569
Epoch 10/150, Loss: 0.4942, Accuracy: 0.7465, Val Loss: 0.4923, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4907, Accuracy: 0.7361, Val Loss: 0.5038, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.4818, Accuracy: 0.7565, Val Loss: 0.4807, Val Accura



Epoch 1/150, Loss: 0.6852, Accuracy: 0.5760, Val Loss: 0.6908, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6757, Accuracy: 0.5838, Val Loss: 0.6881, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6678, Accuracy: 0.5955, Val Loss: 0.6858, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6278, Accuracy: 0.6367, Val Loss: 0.6597, Val Accuracy: 0.6076
Epoch 5/150, Loss: 0.5559, Accuracy: 0.6931, Val Loss: 0.5119, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.5196, Accuracy: 0.7270, Val Loss: 0.5027, Val Accuracy: 0.7396
Epoch 7/150, Loss: 0.5173, Accuracy: 0.7235, Val Loss: 0.5094, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.5091, Accuracy: 0.7322, Val Loss: 0.4901, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.5105, Accuracy: 0.7322, Val Loss: 0.4919, Val Accuracy: 0.7535
Epoch 10/150, Loss: 0.4920, Accuracy: 0.7513, Val Loss: 0.4892, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4898, Accuracy: 0.7500, Val Loss: 0.4965, Val Accuracy: 0.7222
Epoch 12/150, Loss: 0.4847, Accuracy: 0.7509, Val Loss: 0.4770, Val Accura



Epoch 1/150, Loss: 0.6859, Accuracy: 0.5825, Val Loss: 0.6874, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6874, Accuracy: 0.5712, Val Loss: 0.6915, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6714, Accuracy: 0.5946, Val Loss: 0.6890, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6526, Accuracy: 0.6120, Val Loss: 0.5801, Val Accuracy: 0.7153
Epoch 5/150, Loss: 0.5624, Accuracy: 0.6736, Val Loss: 0.5342, Val Accuracy: 0.6944
Epoch 6/150, Loss: 0.5208, Accuracy: 0.7218, Val Loss: 0.4929, Val Accuracy: 0.7465
Epoch 7/150, Loss: 0.5100, Accuracy: 0.7257, Val Loss: 0.4897, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.5045, Accuracy: 0.7348, Val Loss: 0.5310, Val Accuracy: 0.7083
Epoch 9/150, Loss: 0.5044, Accuracy: 0.7335, Val Loss: 0.5084, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4926, Accuracy: 0.7543, Val Loss: 0.4987, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.5012, Accuracy: 0.7396, Val Loss: 0.4851, Val Accuracy: 0.7639
Epoch 12/150, Loss: 0.4895, Accuracy: 0.7426, Val Loss: 0.4964, Val Accura



Epoch 1/150, Loss: 0.6863, Accuracy: 0.5660, Val Loss: 0.7097, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6785, Accuracy: 0.5872, Val Loss: 0.6888, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6698, Accuracy: 0.5933, Val Loss: 0.7117, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6168, Accuracy: 0.6476, Val Loss: 0.5389, Val Accuracy: 0.7049
Epoch 5/150, Loss: 0.5369, Accuracy: 0.7036, Val Loss: 0.5138, Val Accuracy: 0.7396
Epoch 6/150, Loss: 0.5188, Accuracy: 0.7179, Val Loss: 0.5265, Val Accuracy: 0.7118
Epoch 7/150, Loss: 0.5028, Accuracy: 0.7378, Val Loss: 0.5298, Val Accuracy: 0.6979
Epoch 8/150, Loss: 0.5028, Accuracy: 0.7318, Val Loss: 0.4954, Val Accuracy: 0.7257
Epoch 9/150, Loss: 0.4992, Accuracy: 0.7387, Val Loss: 0.5029, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4915, Accuracy: 0.7513, Val Loss: 0.4861, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4878, Accuracy: 0.7491, Val Loss: 0.4893, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.4908, Accuracy: 0.7478, Val Loss: 0.4850, Val Accura



Epoch 1/150, Loss: 0.6942, Accuracy: 0.5599, Val Loss: 0.7056, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6789, Accuracy: 0.5898, Val Loss: 0.6956, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6777, Accuracy: 0.5734, Val Loss: 0.6905, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6698, Accuracy: 0.5829, Val Loss: 0.6632, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5920, Accuracy: 0.6771, Val Loss: 0.5065, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5266, Accuracy: 0.7157, Val Loss: 0.4881, Val Accuracy: 0.7500
Epoch 7/150, Loss: 0.4980, Accuracy: 0.7444, Val Loss: 0.4857, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.5060, Accuracy: 0.7365, Val Loss: 0.4927, Val Accuracy: 0.7500
Epoch 9/150, Loss: 0.4961, Accuracy: 0.7474, Val Loss: 0.4842, Val Accuracy: 0.7604
Epoch 10/150, Loss: 0.4944, Accuracy: 0.7452, Val Loss: 0.5002, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.4983, Accuracy: 0.7422, Val Loss: 0.5240, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.4896, Accuracy: 0.7413, Val Loss: 0.4965, Val Accura



Epoch 1/150, Loss: 0.6951, Accuracy: 0.5413, Val Loss: 0.6867, Val Accuracy: 0.5382
Epoch 2/150, Loss: 0.6861, Accuracy: 0.5465, Val Loss: 0.6765, Val Accuracy: 0.6319
Epoch 3/150, Loss: 0.6594, Accuracy: 0.5795, Val Loss: 0.6231, Val Accuracy: 0.6632
Epoch 4/150, Loss: 0.5996, Accuracy: 0.6438, Val Loss: 0.5769, Val Accuracy: 0.6771
Epoch 5/150, Loss: 0.5834, Accuracy: 0.6677, Val Loss: 0.5875, Val Accuracy: 0.6667
Epoch 6/150, Loss: 0.5745, Accuracy: 0.6629, Val Loss: 0.6092, Val Accuracy: 0.6319
Epoch 7/150, Loss: 0.5676, Accuracy: 0.6772, Val Loss: 0.5782, Val Accuracy: 0.7014
Epoch 8/150, Loss: 0.5710, Accuracy: 0.6838, Val Loss: 0.5831, Val Accuracy: 0.6910
Epoch 9/150, Loss: 0.5667, Accuracy: 0.6851, Val Loss: 0.6019, Val Accuracy: 0.6667
Epoch 10/150, Loss: 0.5595, Accuracy: 0.6842, Val Loss: 0.5794, Val Accuracy: 0.6771
Epoch 11/150, Loss: 0.5533, Accuracy: 0.6916, Val Loss: 0.5797, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.5548, Accuracy: 0.7024, Val Loss: 0.5794, Val Accura



Epoch 1/150, Loss: 0.6972, Accuracy: 0.5369, Val Loss: 0.6941, Val Accuracy: 0.5660
Epoch 2/150, Loss: 0.6860, Accuracy: 0.5426, Val Loss: 0.6885, Val Accuracy: 0.5972
Epoch 3/150, Loss: 0.6675, Accuracy: 0.5769, Val Loss: 0.6562, Val Accuracy: 0.5486
Epoch 4/150, Loss: 0.6317, Accuracy: 0.6047, Val Loss: 0.5884, Val Accuracy: 0.6979
Epoch 5/150, Loss: 0.5858, Accuracy: 0.6564, Val Loss: 0.5762, Val Accuracy: 0.6910
Epoch 6/150, Loss: 0.5742, Accuracy: 0.6681, Val Loss: 0.5870, Val Accuracy: 0.6632
Epoch 7/150, Loss: 0.5780, Accuracy: 0.6725, Val Loss: 0.5815, Val Accuracy: 0.6840
Epoch 8/150, Loss: 0.5691, Accuracy: 0.6772, Val Loss: 0.5808, Val Accuracy: 0.6771
Epoch 9/150, Loss: 0.5631, Accuracy: 0.6803, Val Loss: 0.5862, Val Accuracy: 0.6875
Epoch 10/150, Loss: 0.5650, Accuracy: 0.6890, Val Loss: 0.5769, Val Accuracy: 0.6979
Epoch 11/150, Loss: 0.5668, Accuracy: 0.6764, Val Loss: 0.5825, Val Accuracy: 0.6771
Epoch 12/150, Loss: 0.5556, Accuracy: 0.6937, Val Loss: 0.5806, Val Accura



Epoch 1/150, Loss: 0.6953, Accuracy: 0.5408, Val Loss: 0.6976, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6797, Accuracy: 0.5686, Val Loss: 0.6813, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6517, Accuracy: 0.6021, Val Loss: 0.6032, Val Accuracy: 0.6597
Epoch 4/150, Loss: 0.6056, Accuracy: 0.6399, Val Loss: 0.5776, Val Accuracy: 0.6736
Epoch 5/150, Loss: 0.5863, Accuracy: 0.6529, Val Loss: 0.5951, Val Accuracy: 0.6632
Epoch 6/150, Loss: 0.5806, Accuracy: 0.6707, Val Loss: 0.5735, Val Accuracy: 0.6944
Epoch 7/150, Loss: 0.5693, Accuracy: 0.6846, Val Loss: 0.5855, Val Accuracy: 0.6806
Epoch 8/150, Loss: 0.5671, Accuracy: 0.6898, Val Loss: 0.5758, Val Accuracy: 0.7014
Epoch 9/150, Loss: 0.5658, Accuracy: 0.6911, Val Loss: 0.6197, Val Accuracy: 0.6493
Epoch 10/150, Loss: 0.5635, Accuracy: 0.6977, Val Loss: 0.5815, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.5568, Accuracy: 0.6898, Val Loss: 0.5843, Val Accuracy: 0.6944
Epoch 12/150, Loss: 0.5620, Accuracy: 0.6820, Val Loss: 0.5769, Val Accura



Epoch 1/150, Loss: 0.7004, Accuracy: 0.5339, Val Loss: 0.6928, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6847, Accuracy: 0.5565, Val Loss: 0.6908, Val Accuracy: 0.5104
Epoch 3/150, Loss: 0.6644, Accuracy: 0.5704, Val Loss: 0.6425, Val Accuracy: 0.5278
Epoch 4/150, Loss: 0.5860, Accuracy: 0.6603, Val Loss: 0.5804, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.5802, Accuracy: 0.6642, Val Loss: 0.6072, Val Accuracy: 0.6562
Epoch 6/150, Loss: 0.5691, Accuracy: 0.6838, Val Loss: 0.5999, Val Accuracy: 0.6493
Epoch 7/150, Loss: 0.5705, Accuracy: 0.6838, Val Loss: 0.5864, Val Accuracy: 0.6840
Epoch 8/150, Loss: 0.5686, Accuracy: 0.6742, Val Loss: 0.5964, Val Accuracy: 0.6597
Epoch 9/150, Loss: 0.5646, Accuracy: 0.6950, Val Loss: 0.6021, Val Accuracy: 0.6528
Epoch 10/150, Loss: 0.5671, Accuracy: 0.6842, Val Loss: 0.5878, Val Accuracy: 0.6910
Epoch 11/150, Loss: 0.5556, Accuracy: 0.7011, Val Loss: 0.5963, Val Accuracy: 0.6632
Epoch 12/150, Loss: 0.5570, Accuracy: 0.7016, Val Loss: 0.5965, Val Accura



Epoch 1/150, Loss: 0.7015, Accuracy: 0.5391, Val Loss: 0.6926, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6894, Accuracy: 0.5452, Val Loss: 0.7155, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6732, Accuracy: 0.5643, Val Loss: 0.6354, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.6082, Accuracy: 0.6442, Val Loss: 0.5965, Val Accuracy: 0.6458
Epoch 5/150, Loss: 0.5940, Accuracy: 0.6603, Val Loss: 0.5855, Val Accuracy: 0.6910
Epoch 6/150, Loss: 0.5780, Accuracy: 0.6668, Val Loss: 0.5812, Val Accuracy: 0.6840
Epoch 7/150, Loss: 0.5695, Accuracy: 0.6764, Val Loss: 0.5786, Val Accuracy: 0.7083
Epoch 8/150, Loss: 0.5700, Accuracy: 0.6885, Val Loss: 0.5828, Val Accuracy: 0.6910
Epoch 9/150, Loss: 0.5627, Accuracy: 0.6885, Val Loss: 0.5746, Val Accuracy: 0.7049
Epoch 10/150, Loss: 0.5713, Accuracy: 0.6820, Val Loss: 0.5885, Val Accuracy: 0.6944
Epoch 11/150, Loss: 0.5553, Accuracy: 0.7024, Val Loss: 0.5978, Val Accuracy: 0.6736
Epoch 12/150, Loss: 0.5657, Accuracy: 0.6798, Val Loss: 0.5860, Val Accura



Mean Accuracy: 71.08%, (SD=0.09566260538106067)


## Architecture 5: TCN+Transformer Encoder

In [13]:
# TCN Block Definition
class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# Define the TCN block for each sensor
class TCNBlock(nn.Module):
    def __init__(self, input_size, num_channels):
        super(TCNBlock, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels)

    def forward(self, x):
        return self.tcn(x)

# Define the fusion and Transformer model
class FusionTransformerModel(nn.Module):
    def __init__(self, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes=2):
        super(FusionTransformerModel, self).__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])

        tcn_output_dim = tcn_channels[-1]  # The output dimension of the last TCN layer
        self.fc1 = nn.Linear(tcn_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, (tcn_output_dim * input_dim) // num_heads * num_heads)  # Ensure divisibility

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=(tcn_output_dim * input_dim) // num_heads * num_heads, nhead=num_heads, dim_feedforward=transformer_hidden_dim),
            num_layers=num_layers
        )
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear((tcn_output_dim * input_dim) // num_heads * num_heads, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined = self.transformer_encoder(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last time step output from the Transformer

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionTransformerModel(input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Learning rate scheduler

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Step the learning rate scheduler based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Calculate FLOPs
    input_size = (1, input_dim, X_train.shape[2])  # Using batch size 1 for FLOPs calculation
    flops = torchprofile.profile_macs(model, torch.randn(input_size).to(device))
    print(f"FLOPs: {flops}")

    return model, epoch_loss

def evaluate_model(models, X_test, y_test, batch_size=32):
    # Aggregate predictions from all models
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # TCN channels
transformer_hidden_dim = 64
num_layers = 2
num_heads = 4  # Set the number of heads to a value that divides the embedding dimension
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")




Epoch 1/150, Loss: 0.7181, Accuracy: 0.5599, Val Loss: 0.6371, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6803, Accuracy: 0.5925, Val Loss: 0.6533, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6641, Accuracy: 0.6069, Val Loss: 0.6333, Val Accuracy: 0.6875
Epoch 4/150, Loss: 0.5670, Accuracy: 0.7011, Val Loss: 0.5182, Val Accuracy: 0.6875
Epoch 5/150, Loss: 0.5060, Accuracy: 0.7411, Val Loss: 0.5179, Val Accuracy: 0.7326
Epoch 6/150, Loss: 0.4915, Accuracy: 0.7693, Val Loss: 0.5421, Val Accuracy: 0.7188
Epoch 7/150, Loss: 0.4878, Accuracy: 0.7650, Val Loss: 0.4994, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.4833, Accuracy: 0.7711, Val Loss: 0.4977, Val Accuracy: 0.7292
Epoch 9/150, Loss: 0.4856, Accuracy: 0.7672, Val Loss: 0.5034, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4751, Accuracy: 0.7776, Val Loss: 0.5391, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4722, Accuracy: 0.7663, Val Loss: 0.5522, Val Accuracy: 0.7188
Epoch 12/150, Loss: 0.4789, Accuracy: 0.7698, Val Loss: 0.5007, Val Accura



Epoch 1/150, Loss: 0.7224, Accuracy: 0.5465, Val Loss: 0.6330, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6825, Accuracy: 0.5930, Val Loss: 0.6569, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6665, Accuracy: 0.6030, Val Loss: 0.6221, Val Accuracy: 0.6667
Epoch 4/150, Loss: 0.5922, Accuracy: 0.6781, Val Loss: 0.5296, Val Accuracy: 0.7326
Epoch 5/150, Loss: 0.5223, Accuracy: 0.7420, Val Loss: 0.5297, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4968, Accuracy: 0.7624, Val Loss: 0.5201, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4855, Accuracy: 0.7711, Val Loss: 0.5100, Val Accuracy: 0.7222
Epoch 8/150, Loss: 0.4812, Accuracy: 0.7689, Val Loss: 0.5250, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4774, Accuracy: 0.7806, Val Loss: 0.5134, Val Accuracy: 0.7326
Epoch 10/150, Loss: 0.4876, Accuracy: 0.7619, Val Loss: 0.4989, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.4798, Accuracy: 0.7737, Val Loss: 0.4957, Val Accuracy: 0.7292
Epoch 12/150, Loss: 0.4747, Accuracy: 0.7732, Val Loss: 0.4921, Val Accura



Epoch 1/150, Loss: 0.7274, Accuracy: 0.5639, Val Loss: 0.6659, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6831, Accuracy: 0.5912, Val Loss: 0.6316, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6583, Accuracy: 0.6151, Val Loss: 0.6133, Val Accuracy: 0.6736
Epoch 4/150, Loss: 0.5994, Accuracy: 0.6751, Val Loss: 0.5330, Val Accuracy: 0.6806
Epoch 5/150, Loss: 0.5086, Accuracy: 0.7346, Val Loss: 0.5038, Val Accuracy: 0.7500
Epoch 6/150, Loss: 0.4983, Accuracy: 0.7476, Val Loss: 0.5349, Val Accuracy: 0.7188
Epoch 7/150, Loss: 0.4879, Accuracy: 0.7606, Val Loss: 0.5172, Val Accuracy: 0.7188
Epoch 8/150, Loss: 0.4834, Accuracy: 0.7724, Val Loss: 0.4988, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4811, Accuracy: 0.7745, Val Loss: 0.5137, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4815, Accuracy: 0.7693, Val Loss: 0.5038, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4746, Accuracy: 0.7785, Val Loss: 0.5300, Val Accuracy: 0.7118
Epoch 12/150, Loss: 0.4818, Accuracy: 0.7711, Val Loss: 0.5034, Val Accura



Epoch 1/150, Loss: 0.6922, Accuracy: 0.5795, Val Loss: 0.6334, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6741, Accuracy: 0.6003, Val Loss: 0.6635, Val Accuracy: 0.6250
Epoch 3/150, Loss: 0.6075, Accuracy: 0.6651, Val Loss: 0.5227, Val Accuracy: 0.7569
Epoch 4/150, Loss: 0.5178, Accuracy: 0.7394, Val Loss: 0.5117, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.4990, Accuracy: 0.7524, Val Loss: 0.5103, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4910, Accuracy: 0.7589, Val Loss: 0.5049, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.4894, Accuracy: 0.7585, Val Loss: 0.5154, Val Accuracy: 0.7222
Epoch 8/150, Loss: 0.4744, Accuracy: 0.7793, Val Loss: 0.5515, Val Accuracy: 0.7257
Epoch 9/150, Loss: 0.4820, Accuracy: 0.7750, Val Loss: 0.4978, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4785, Accuracy: 0.7702, Val Loss: 0.5038, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4714, Accuracy: 0.7754, Val Loss: 0.4938, Val Accuracy: 0.7326
Epoch 12/150, Loss: 0.4714, Accuracy: 0.7750, Val Loss: 0.4906, Val Accura



Epoch 1/150, Loss: 0.7071, Accuracy: 0.5630, Val Loss: 0.6595, Val Accuracy: 0.6667
Epoch 2/150, Loss: 0.6772, Accuracy: 0.5930, Val Loss: 0.6313, Val Accuracy: 0.6667
Epoch 3/150, Loss: 0.6740, Accuracy: 0.6034, Val Loss: 0.6267, Val Accuracy: 0.6806
Epoch 4/150, Loss: 0.5899, Accuracy: 0.6829, Val Loss: 0.5244, Val Accuracy: 0.7292
Epoch 5/150, Loss: 0.5100, Accuracy: 0.7394, Val Loss: 0.5105, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.4941, Accuracy: 0.7472, Val Loss: 0.5301, Val Accuracy: 0.7153
Epoch 7/150, Loss: 0.4790, Accuracy: 0.7693, Val Loss: 0.5226, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4851, Accuracy: 0.7615, Val Loss: 0.5224, Val Accuracy: 0.7326
Epoch 9/150, Loss: 0.4778, Accuracy: 0.7732, Val Loss: 0.5808, Val Accuracy: 0.7118
Epoch 10/150, Loss: 0.4814, Accuracy: 0.7750, Val Loss: 0.5343, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4791, Accuracy: 0.7763, Val Loss: 0.4939, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.4658, Accuracy: 0.7793, Val Loss: 0.5188, Val Accura



Epoch 1/150, Loss: 0.7204, Accuracy: 0.5443, Val Loss: 0.6818, Val Accuracy: 0.5972
Epoch 2/150, Loss: 0.6803, Accuracy: 0.5778, Val Loss: 0.6691, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6670, Accuracy: 0.5804, Val Loss: 0.6488, Val Accuracy: 0.5938
Epoch 4/150, Loss: 0.5998, Accuracy: 0.6564, Val Loss: 0.5576, Val Accuracy: 0.7118
Epoch 5/150, Loss: 0.5506, Accuracy: 0.7224, Val Loss: 0.5084, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5265, Accuracy: 0.7385, Val Loss: 0.4925, Val Accuracy: 0.7778
Epoch 7/150, Loss: 0.5174, Accuracy: 0.7515, Val Loss: 0.4978, Val Accuracy: 0.7639
Epoch 8/150, Loss: 0.5164, Accuracy: 0.7450, Val Loss: 0.5002, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.5246, Accuracy: 0.7459, Val Loss: 0.5018, Val Accuracy: 0.7639
Epoch 10/150, Loss: 0.5180, Accuracy: 0.7450, Val Loss: 0.5123, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.5036, Accuracy: 0.7632, Val Loss: 0.5453, Val Accuracy: 0.7361
Epoch 12/150, Loss: 0.5022, Accuracy: 0.7750, Val Loss: 0.4814, Val Accura



Epoch 1/150, Loss: 0.6964, Accuracy: 0.5626, Val Loss: 0.6781, Val Accuracy: 0.5833
Epoch 2/150, Loss: 0.6860, Accuracy: 0.5582, Val Loss: 0.6880, Val Accuracy: 0.5243
Epoch 3/150, Loss: 0.6685, Accuracy: 0.5838, Val Loss: 0.6603, Val Accuracy: 0.6250
Epoch 4/150, Loss: 0.6452, Accuracy: 0.6203, Val Loss: 0.5940, Val Accuracy: 0.7118
Epoch 5/150, Loss: 0.5806, Accuracy: 0.6820, Val Loss: 0.5469, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5474, Accuracy: 0.7081, Val Loss: 0.5212, Val Accuracy: 0.7431
Epoch 7/150, Loss: 0.5220, Accuracy: 0.7459, Val Loss: 0.4941, Val Accuracy: 0.7708
Epoch 8/150, Loss: 0.5246, Accuracy: 0.7437, Val Loss: 0.5107, Val Accuracy: 0.7674
Epoch 9/150, Loss: 0.5085, Accuracy: 0.7550, Val Loss: 0.4950, Val Accuracy: 0.7535
Epoch 10/150, Loss: 0.5235, Accuracy: 0.7376, Val Loss: 0.4975, Val Accuracy: 0.7743
Epoch 11/150, Loss: 0.5053, Accuracy: 0.7624, Val Loss: 0.4999, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.5032, Accuracy: 0.7650, Val Loss: 0.4841, Val Accura



Epoch 1/150, Loss: 0.7209, Accuracy: 0.5400, Val Loss: 0.6780, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6747, Accuracy: 0.5734, Val Loss: 0.6732, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6674, Accuracy: 0.6047, Val Loss: 0.6601, Val Accuracy: 0.6597
Epoch 4/150, Loss: 0.5959, Accuracy: 0.6764, Val Loss: 0.5140, Val Accuracy: 0.7431
Epoch 5/150, Loss: 0.5416, Accuracy: 0.7363, Val Loss: 0.5096, Val Accuracy: 0.7708
Epoch 6/150, Loss: 0.5184, Accuracy: 0.7563, Val Loss: 0.5091, Val Accuracy: 0.7569
Epoch 7/150, Loss: 0.5173, Accuracy: 0.7493, Val Loss: 0.5263, Val Accuracy: 0.7188
Epoch 8/150, Loss: 0.5162, Accuracy: 0.7567, Val Loss: 0.4896, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.5160, Accuracy: 0.7576, Val Loss: 0.4885, Val Accuracy: 0.7812
Epoch 10/150, Loss: 0.5053, Accuracy: 0.7624, Val Loss: 0.4904, Val Accuracy: 0.7812
Epoch 11/150, Loss: 0.4939, Accuracy: 0.7672, Val Loss: 0.5030, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4996, Accuracy: 0.7767, Val Loss: 0.4999, Val Accura



Epoch 1/150, Loss: 0.7013, Accuracy: 0.5678, Val Loss: 0.6753, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6741, Accuracy: 0.5778, Val Loss: 0.6758, Val Accuracy: 0.5729
Epoch 3/150, Loss: 0.6631, Accuracy: 0.5973, Val Loss: 0.6448, Val Accuracy: 0.6910
Epoch 4/150, Loss: 0.6046, Accuracy: 0.6733, Val Loss: 0.5257, Val Accuracy: 0.7812
Epoch 5/150, Loss: 0.5454, Accuracy: 0.7202, Val Loss: 0.5336, Val Accuracy: 0.7535
Epoch 6/150, Loss: 0.5283, Accuracy: 0.7376, Val Loss: 0.5134, Val Accuracy: 0.7500
Epoch 7/150, Loss: 0.5267, Accuracy: 0.7411, Val Loss: 0.5032, Val Accuracy: 0.7604
Epoch 8/150, Loss: 0.5308, Accuracy: 0.7324, Val Loss: 0.4871, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.5192, Accuracy: 0.7489, Val Loss: 0.4924, Val Accuracy: 0.7674
Epoch 10/150, Loss: 0.5104, Accuracy: 0.7619, Val Loss: 0.4908, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.5133, Accuracy: 0.7437, Val Loss: 0.4956, Val Accuracy: 0.7674
Epoch 12/150, Loss: 0.5030, Accuracy: 0.7632, Val Loss: 0.5000, Val Accura



Epoch 1/150, Loss: 0.7248, Accuracy: 0.5474, Val Loss: 0.6793, Val Accuracy: 0.5938
Epoch 2/150, Loss: 0.6785, Accuracy: 0.5695, Val Loss: 0.6735, Val Accuracy: 0.5938
Epoch 3/150, Loss: 0.6673, Accuracy: 0.5912, Val Loss: 0.6520, Val Accuracy: 0.6250
Epoch 4/150, Loss: 0.5987, Accuracy: 0.6716, Val Loss: 0.5147, Val Accuracy: 0.7639
Epoch 5/150, Loss: 0.5406, Accuracy: 0.7272, Val Loss: 0.5177, Val Accuracy: 0.7639
Epoch 6/150, Loss: 0.5254, Accuracy: 0.7463, Val Loss: 0.5120, Val Accuracy: 0.7639
Epoch 7/150, Loss: 0.5305, Accuracy: 0.7328, Val Loss: 0.5179, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.5149, Accuracy: 0.7537, Val Loss: 0.5052, Val Accuracy: 0.7639
Epoch 9/150, Loss: 0.5150, Accuracy: 0.7515, Val Loss: 0.5010, Val Accuracy: 0.7708
Epoch 10/150, Loss: 0.5144, Accuracy: 0.7593, Val Loss: 0.4991, Val Accuracy: 0.7639
Epoch 11/150, Loss: 0.5067, Accuracy: 0.7672, Val Loss: 0.4968, Val Accuracy: 0.7569
Epoch 12/150, Loss: 0.5079, Accuracy: 0.7606, Val Loss: 0.4955, Val Accura



Epoch 1/150, Loss: 0.7109, Accuracy: 0.5486, Val Loss: 0.6964, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6879, Accuracy: 0.5686, Val Loss: 0.6914, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6820, Accuracy: 0.5794, Val Loss: 0.6958, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6318, Accuracy: 0.6259, Val Loss: 0.5342, Val Accuracy: 0.7118
Epoch 5/150, Loss: 0.5269, Accuracy: 0.7122, Val Loss: 0.5009, Val Accuracy: 0.7465
Epoch 6/150, Loss: 0.5048, Accuracy: 0.7335, Val Loss: 0.4841, Val Accuracy: 0.7569
Epoch 7/150, Loss: 0.5012, Accuracy: 0.7400, Val Loss: 0.4847, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.5101, Accuracy: 0.7240, Val Loss: 0.4957, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4956, Accuracy: 0.7426, Val Loss: 0.5001, Val Accuracy: 0.7361
Epoch 10/150, Loss: 0.4980, Accuracy: 0.7500, Val Loss: 0.4952, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4958, Accuracy: 0.7491, Val Loss: 0.4847, Val Accuracy: 0.7326
Epoch 12/150, Loss: 0.4890, Accuracy: 0.7591, Val Loss: 0.4923, Val Accura



Epoch 1/150, Loss: 0.7066, Accuracy: 0.5560, Val Loss: 0.7050, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6945, Accuracy: 0.5677, Val Loss: 0.7021, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6850, Accuracy: 0.5651, Val Loss: 0.6935, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6804, Accuracy: 0.5786, Val Loss: 0.6890, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.6338, Accuracy: 0.6350, Val Loss: 0.5671, Val Accuracy: 0.7014
Epoch 6/150, Loss: 0.5480, Accuracy: 0.6927, Val Loss: 0.5040, Val Accuracy: 0.7604
Epoch 7/150, Loss: 0.5231, Accuracy: 0.7261, Val Loss: 0.5128, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.5126, Accuracy: 0.7426, Val Loss: 0.5035, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.5077, Accuracy: 0.7292, Val Loss: 0.5296, Val Accuracy: 0.7188
Epoch 10/150, Loss: 0.5080, Accuracy: 0.7331, Val Loss: 0.4870, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.4952, Accuracy: 0.7439, Val Loss: 0.5372, Val Accuracy: 0.7188
Epoch 12/150, Loss: 0.4852, Accuracy: 0.7565, Val Loss: 0.5033, Val Accura



Epoch 1/150, Loss: 0.6993, Accuracy: 0.5538, Val Loss: 0.6993, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6807, Accuracy: 0.5842, Val Loss: 0.7252, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6794, Accuracy: 0.5890, Val Loss: 0.6917, Val Accuracy: 0.5729
Epoch 4/150, Loss: 0.6798, Accuracy: 0.5981, Val Loss: 0.6935, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.6270, Accuracy: 0.6272, Val Loss: 0.5182, Val Accuracy: 0.7396
Epoch 6/150, Loss: 0.5323, Accuracy: 0.7192, Val Loss: 0.4954, Val Accuracy: 0.7465
Epoch 7/150, Loss: 0.5175, Accuracy: 0.7409, Val Loss: 0.5069, Val Accuracy: 0.7431
Epoch 8/150, Loss: 0.5132, Accuracy: 0.7209, Val Loss: 0.5470, Val Accuracy: 0.6562
Epoch 9/150, Loss: 0.5047, Accuracy: 0.7461, Val Loss: 0.4992, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.4934, Accuracy: 0.7483, Val Loss: 0.5180, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.4920, Accuracy: 0.7535, Val Loss: 0.4920, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4811, Accuracy: 0.7491, Val Loss: 0.4798, Val Accura



Epoch 1/150, Loss: 0.7238, Accuracy: 0.5477, Val Loss: 0.7006, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6917, Accuracy: 0.5725, Val Loss: 0.6987, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6761, Accuracy: 0.5838, Val Loss: 0.6918, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6742, Accuracy: 0.5751, Val Loss: 0.6915, Val Accuracy: 0.5590
Epoch 5/150, Loss: 0.5780, Accuracy: 0.6832, Val Loss: 0.5200, Val Accuracy: 0.6875
Epoch 6/150, Loss: 0.5119, Accuracy: 0.7296, Val Loss: 0.4987, Val Accuracy: 0.7361
Epoch 7/150, Loss: 0.5090, Accuracy: 0.7309, Val Loss: 0.5100, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.5060, Accuracy: 0.7335, Val Loss: 0.5234, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4968, Accuracy: 0.7370, Val Loss: 0.5094, Val Accuracy: 0.7292
Epoch 10/150, Loss: 0.4980, Accuracy: 0.7387, Val Loss: 0.5104, Val Accuracy: 0.7292
Epoch 11/150, Loss: 0.4929, Accuracy: 0.7452, Val Loss: 0.5092, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.4888, Accuracy: 0.7496, Val Loss: 0.4994, Val Accura



Epoch 1/150, Loss: 0.7030, Accuracy: 0.5556, Val Loss: 0.7181, Val Accuracy: 0.5590
Epoch 2/150, Loss: 0.6891, Accuracy: 0.5686, Val Loss: 0.7115, Val Accuracy: 0.5590
Epoch 3/150, Loss: 0.6799, Accuracy: 0.5951, Val Loss: 0.6984, Val Accuracy: 0.5590
Epoch 4/150, Loss: 0.6529, Accuracy: 0.6137, Val Loss: 0.5529, Val Accuracy: 0.7188
Epoch 5/150, Loss: 0.5486, Accuracy: 0.6819, Val Loss: 0.5168, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.5229, Accuracy: 0.7127, Val Loss: 0.4971, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5127, Accuracy: 0.7287, Val Loss: 0.5085, Val Accuracy: 0.7257
Epoch 8/150, Loss: 0.5200, Accuracy: 0.7244, Val Loss: 0.4858, Val Accuracy: 0.7535
Epoch 9/150, Loss: 0.5039, Accuracy: 0.7405, Val Loss: 0.4958, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.5045, Accuracy: 0.7387, Val Loss: 0.5086, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4971, Accuracy: 0.7509, Val Loss: 0.4844, Val Accuracy: 0.7465
Epoch 12/150, Loss: 0.4975, Accuracy: 0.7431, Val Loss: 0.5390, Val Accura



Epoch 1/150, Loss: 0.7324, Accuracy: 0.5182, Val Loss: 0.6963, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.7017, Accuracy: 0.5304, Val Loss: 0.7075, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6872, Accuracy: 0.5504, Val Loss: 0.6916, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6836, Accuracy: 0.5469, Val Loss: 0.6463, Val Accuracy: 0.6250
Epoch 5/150, Loss: 0.6109, Accuracy: 0.6438, Val Loss: 0.5790, Val Accuracy: 0.6910
Epoch 6/150, Loss: 0.5973, Accuracy: 0.6642, Val Loss: 0.5888, Val Accuracy: 0.6597
Epoch 7/150, Loss: 0.5842, Accuracy: 0.6607, Val Loss: 0.5769, Val Accuracy: 0.6944
Epoch 8/150, Loss: 0.5732, Accuracy: 0.6746, Val Loss: 0.5986, Val Accuracy: 0.6736
Epoch 9/150, Loss: 0.5752, Accuracy: 0.6833, Val Loss: 0.5805, Val Accuracy: 0.6875
Epoch 10/150, Loss: 0.5693, Accuracy: 0.6868, Val Loss: 0.5766, Val Accuracy: 0.6944
Epoch 11/150, Loss: 0.5596, Accuracy: 0.6959, Val Loss: 0.5797, Val Accuracy: 0.7222
Epoch 12/150, Loss: 0.5579, Accuracy: 0.7024, Val Loss: 0.5767, Val Accura



Epoch 1/150, Loss: 0.7179, Accuracy: 0.5300, Val Loss: 0.6917, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.7008, Accuracy: 0.5226, Val Loss: 0.6925, Val Accuracy: 0.5174
Epoch 3/150, Loss: 0.6883, Accuracy: 0.5543, Val Loss: 0.6876, Val Accuracy: 0.5035
Epoch 4/150, Loss: 0.6632, Accuracy: 0.5760, Val Loss: 0.6226, Val Accuracy: 0.6215
Epoch 5/150, Loss: 0.5923, Accuracy: 0.6468, Val Loss: 0.5734, Val Accuracy: 0.6875
Epoch 6/150, Loss: 0.5925, Accuracy: 0.6560, Val Loss: 0.5847, Val Accuracy: 0.6875
Epoch 7/150, Loss: 0.5816, Accuracy: 0.6712, Val Loss: 0.5847, Val Accuracy: 0.6771
Epoch 8/150, Loss: 0.5768, Accuracy: 0.6677, Val Loss: 0.5981, Val Accuracy: 0.6875
Epoch 9/150, Loss: 0.5654, Accuracy: 0.6877, Val Loss: 0.5723, Val Accuracy: 0.6736
Epoch 10/150, Loss: 0.5693, Accuracy: 0.6911, Val Loss: 0.5689, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.5634, Accuracy: 0.6920, Val Loss: 0.5745, Val Accuracy: 0.7083
Epoch 12/150, Loss: 0.5595, Accuracy: 0.6959, Val Loss: 0.6194, Val Accura



Epoch 1/150, Loss: 0.7376, Accuracy: 0.5148, Val Loss: 0.6936, Val Accuracy: 0.5035
Epoch 2/150, Loss: 0.6946, Accuracy: 0.5404, Val Loss: 0.6861, Val Accuracy: 0.5833
Epoch 3/150, Loss: 0.6924, Accuracy: 0.5382, Val Loss: 0.6849, Val Accuracy: 0.5868
Epoch 4/150, Loss: 0.6844, Accuracy: 0.5504, Val Loss: 0.6832, Val Accuracy: 0.5243
Epoch 5/150, Loss: 0.6433, Accuracy: 0.6121, Val Loss: 0.5769, Val Accuracy: 0.6736
Epoch 6/150, Loss: 0.5856, Accuracy: 0.6633, Val Loss: 0.5802, Val Accuracy: 0.6771
Epoch 7/150, Loss: 0.5794, Accuracy: 0.6664, Val Loss: 0.5834, Val Accuracy: 0.6944
Epoch 8/150, Loss: 0.5667, Accuracy: 0.6933, Val Loss: 0.5819, Val Accuracy: 0.6979
Epoch 9/150, Loss: 0.5716, Accuracy: 0.6929, Val Loss: 0.5893, Val Accuracy: 0.6979
Epoch 10/150, Loss: 0.5664, Accuracy: 0.6916, Val Loss: 0.5896, Val Accuracy: 0.6910
Epoch 11/150, Loss: 0.5651, Accuracy: 0.6959, Val Loss: 0.5777, Val Accuracy: 0.6771
Epoch 12/150, Loss: 0.5594, Accuracy: 0.7007, Val Loss: 0.5779, Val Accura



Epoch 1/150, Loss: 0.7200, Accuracy: 0.5274, Val Loss: 0.6868, Val Accuracy: 0.5208
Epoch 2/150, Loss: 0.6952, Accuracy: 0.5321, Val Loss: 0.6869, Val Accuracy: 0.5347
Epoch 3/150, Loss: 0.6870, Accuracy: 0.5517, Val Loss: 0.6849, Val Accuracy: 0.5104
Epoch 4/150, Loss: 0.6754, Accuracy: 0.5678, Val Loss: 0.6396, Val Accuracy: 0.6528
Epoch 5/150, Loss: 0.6213, Accuracy: 0.6290, Val Loss: 0.5788, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.5821, Accuracy: 0.6646, Val Loss: 0.5722, Val Accuracy: 0.6771
Epoch 7/150, Loss: 0.5810, Accuracy: 0.6677, Val Loss: 0.5913, Val Accuracy: 0.6632
Epoch 8/150, Loss: 0.5729, Accuracy: 0.6768, Val Loss: 0.5905, Val Accuracy: 0.6771
Epoch 9/150, Loss: 0.5667, Accuracy: 0.6942, Val Loss: 0.5833, Val Accuracy: 0.7049
Epoch 10/150, Loss: 0.5707, Accuracy: 0.6807, Val Loss: 0.6643, Val Accuracy: 0.5590
Epoch 11/150, Loss: 0.5678, Accuracy: 0.6859, Val Loss: 0.6105, Val Accuracy: 0.6806
Epoch 12/150, Loss: 0.5646, Accuracy: 0.6907, Val Loss: 0.5735, Val Accura



Epoch 1/150, Loss: 0.7267, Accuracy: 0.5230, Val Loss: 0.7038, Val Accuracy: 0.5174
Epoch 2/150, Loss: 0.6971, Accuracy: 0.5443, Val Loss: 0.6869, Val Accuracy: 0.5625
Epoch 3/150, Loss: 0.6881, Accuracy: 0.5526, Val Loss: 0.6889, Val Accuracy: 0.5174
Epoch 4/150, Loss: 0.6779, Accuracy: 0.5586, Val Loss: 0.6828, Val Accuracy: 0.5174
Epoch 5/150, Loss: 0.6216, Accuracy: 0.6303, Val Loss: 0.5772, Val Accuracy: 0.6771
Epoch 6/150, Loss: 0.5887, Accuracy: 0.6668, Val Loss: 0.5796, Val Accuracy: 0.6979
Epoch 7/150, Loss: 0.5787, Accuracy: 0.6742, Val Loss: 0.5805, Val Accuracy: 0.6667
Epoch 8/150, Loss: 0.5817, Accuracy: 0.6681, Val Loss: 0.5809, Val Accuracy: 0.6806
Epoch 9/150, Loss: 0.5727, Accuracy: 0.6833, Val Loss: 0.5969, Val Accuracy: 0.7118
Epoch 10/150, Loss: 0.5706, Accuracy: 0.6825, Val Loss: 0.6036, Val Accuracy: 0.6771
Epoch 11/150, Loss: 0.5629, Accuracy: 0.6972, Val Loss: 0.5781, Val Accuracy: 0.7118
Epoch 12/150, Loss: 0.5582, Accuracy: 0.6981, Val Loss: 0.5829, Val Accura



Mean Accuracy: 69.67%, (SD=0.08469296489922801)
Mean Loss: 0.6286, (SD=0.10823353908416396)
Mean Precision: 0.7158, (SD=0.08271813243129823)
Mean Recall: 0.6967, (SD=0.08469296489922801)
Mean F1 Score: 0.6894, (SD=0.09059199653311845)


## Architecture 6: ResNet + GRU Fusion

In [14]:
# ResNet1D Block Definition
class BasicBlock1D(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock1D, self).__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )

    def forward(self, x):
        residual = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out

class ResNet1D(nn.Module):
    def __init__(self, block, layers, in_channels=1):
        super(ResNet1D, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    def _make_layer(self, block, planes, blocks, stride=1):
        layers = []
        layers.append(block(self.in_planes, planes, stride))
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

# Define the ResNet block for each sensor
class ResNetBlock(nn.Module):
    def __init__(self, input_size):
        super(ResNetBlock, self).__init__()
        self.resnet = ResNet1D(BasicBlock1D, [2, 2, 2, 2], in_channels=input_size)

    def forward(self, x):
        return self.resnet(x)

# Define the fusion and GRU model
class FusionGRUModel(nn.Module):
    def __init__(self, input_dim, resnet_channels, gru_hidden_dim, num_layers, num_classes=2):
        super(FusionGRUModel, self).__init__()
        self.frontends = nn.ModuleList([ResNetBlock(input_size=1) for _ in range(input_dim)])

        resnet_output_dim = resnet_channels[-1]  # The output dimension of the last ResNet layer
        self.fc1 = nn.Linear(resnet_output_dim * input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, resnet_output_dim * input_dim)

        self.gru = nn.GRU(resnet_output_dim * input_dim, gru_hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(gru_hidden_dim * 2, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1).to(device)
        
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension
        combined, _ = self.gru(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last time step output from the GRU

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, resnet_channels, gru_hidden_dim, num_layers, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = FusionGRUModel(input_dim, resnet_channels, gru_hidden_dim, num_layers, num_classes).to(device)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)  # Learning rate scheduler

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_outputs = model(val_inputs)
                loss = criterion(val_outputs, val_labels)
                val_loss += loss.item() * val_inputs.size(0)
                _, val_predicted = torch.max(val_outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (val_predicted == val_labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / val_total

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')
        scheduler.step(val_loss)  # Step the learning rate scheduler based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Calculate FLOPs
    input_size = (1, input_dim, X_train.shape[2])  # Using batch size 1 for FLOPs calculation
    flops = torchprofile.profile_macs(model, torch.randn(input_size).to(device))
    print(f"FLOPs: {flops}")

    return model, epoch_loss

def evaluate_model(models, X_test, y_test, batch_size=32):
    # Aggregate predictions from all models
    test_preds = []
    test_losses = []
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    for model in models:
        model.eval()  # Set the model to evaluation mode

        # Create DataLoader
        dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = []
        running_loss = 0.0
        total = 0

        with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
            for data, labels in test_loader:
                outputs = model(data)
                preds.append(outputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * data.size(0)
                total += labels.size(0)

        test_losses.append(running_loss / total)
        test_preds.append(torch.cat(preds, dim=0))

    # Average predictions
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    precision = precision_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    recall = recall_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    f1 = f1_score(y_test_tensor.cpu(), predicted.cpu(), average='weighted')
    mean_loss = np.mean(test_losses)
    
    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
resnet_channels = [64, 128, 256, 512]  # ResNet channels
gru_hidden_dim = 64
num_layers = 2
num_classes = 2

# Assuming I_TRAINS and I_TESTS are provided for Group K-Fold

for fold_index, I_train in enumerate(I_TRAINS):
    models = []
    for _ in range(NUM_MODELS):
        # Split training data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)
        
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        signal_length = X_train.shape[2]  # Use the actual signal length
        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, resnet_channels, gru_hidden_dim, num_layers, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies = []
mean_losses = []
precisions = []
recalls = []
f1_scores = []
for models, I_test in zip(DNNS, I_TESTS):
    X_test, y_test = X_res[I_test], y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)
    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
mean_accuracy = np.mean(accuracies)
std_dev_accuracy = np.std(accuracies)
mean_loss = np.mean(mean_losses)
std_dev_loss = np.std(mean_losses)
mean_precision = np.mean(precisions)
std_dev_precision = np.std(precisions)
mean_recall = np.mean(recalls)
std_dev_recall = np.std(recalls)
mean_f1 = np.mean(f1_scores)
std_dev_f1 = np.std(f1_scores)

print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")




Epoch 1/150, Loss: 0.5153, Accuracy: 0.7428, Val Loss: 0.4767, Val Accuracy: 0.7500
Epoch 2/150, Loss: 0.4873, Accuracy: 0.7602, Val Loss: 0.5382, Val Accuracy: 0.7257
Epoch 3/150, Loss: 0.5072, Accuracy: 0.7454, Val Loss: 0.4978, Val Accuracy: 0.7396
Epoch 4/150, Loss: 0.4796, Accuracy: 0.7728, Val Loss: 0.5252, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.4563, Accuracy: 0.7906, Val Loss: 0.5313, Val Accuracy: 0.7257
Epoch 6/150, Loss: 0.4479, Accuracy: 0.7880, Val Loss: 0.4743, Val Accuracy: 0.7500
Epoch 7/150, Loss: 0.4532, Accuracy: 0.7884, Val Loss: 0.4997, Val Accuracy: 0.7118
Epoch 8/150, Loss: 0.4541, Accuracy: 0.7871, Val Loss: 0.4976, Val Accuracy: 0.7639
Epoch 9/150, Loss: 0.4449, Accuracy: 0.8032, Val Loss: 0.5698, Val Accuracy: 0.6979
Epoch 10/150, Loss: 0.4592, Accuracy: 0.7841, Val Loss: 0.5319, Val Accuracy: 0.7153
Epoch 11/150, Loss: 0.4431, Accuracy: 0.7911, Val Loss: 0.5432, Val Accuracy: 0.7083
Epoch 12/150, Loss: 0.4661, Accuracy: 0.7845, Val Loss: 0.5086, Val Accura



Epoch 1/150, Loss: 0.5100, Accuracy: 0.7472, Val Loss: 0.5102, Val Accuracy: 0.6979
Epoch 2/150, Loss: 0.4861, Accuracy: 0.7659, Val Loss: 0.5444, Val Accuracy: 0.7049
Epoch 3/150, Loss: 0.4943, Accuracy: 0.7602, Val Loss: 0.5026, Val Accuracy: 0.7326
Epoch 4/150, Loss: 0.4723, Accuracy: 0.7850, Val Loss: 0.5404, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.4992, Accuracy: 0.7750, Val Loss: 0.5042, Val Accuracy: 0.7292
Epoch 6/150, Loss: 0.4952, Accuracy: 0.7672, Val Loss: 0.5143, Val Accuracy: 0.7118
Epoch 7/150, Loss: 0.4854, Accuracy: 0.7824, Val Loss: 0.5095, Val Accuracy: 0.7396
Epoch 8/150, Loss: 0.4776, Accuracy: 0.7732, Val Loss: 0.5161, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4556, Accuracy: 0.7928, Val Loss: 0.5053, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4675, Accuracy: 0.7880, Val Loss: 0.5033, Val Accuracy: 0.7431
Epoch 11/150, Loss: 0.4563, Accuracy: 0.7919, Val Loss: 0.5096, Val Accuracy: 0.7604
Epoch 12/150, Loss: 0.4221, Accuracy: 0.8262, Val Loss: 0.5005, Val Accura



Epoch 1/150, Loss: 0.5279, Accuracy: 0.7433, Val Loss: 0.4551, Val Accuracy: 0.8021
Epoch 2/150, Loss: 0.4805, Accuracy: 0.7793, Val Loss: 0.4926, Val Accuracy: 0.7465
Epoch 3/150, Loss: 0.4555, Accuracy: 0.7937, Val Loss: 0.5231, Val Accuracy: 0.7292
Epoch 4/150, Loss: 0.4647, Accuracy: 0.7828, Val Loss: 0.5306, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.4501, Accuracy: 0.7919, Val Loss: 0.4894, Val Accuracy: 0.7431
Epoch 6/150, Loss: 0.4610, Accuracy: 0.7906, Val Loss: 0.4741, Val Accuracy: 0.7569
Epoch 7/150, Loss: 0.4475, Accuracy: 0.8019, Val Loss: 0.4744, Val Accuracy: 0.7674
Epoch 8/150, Loss: 0.4533, Accuracy: 0.7971, Val Loss: 0.4797, Val Accuracy: 0.7708
Epoch 9/150, Loss: 0.4134, Accuracy: 0.8175, Val Loss: 0.4990, Val Accuracy: 0.7569
Epoch 10/150, Loss: 0.4188, Accuracy: 0.8145, Val Loss: 0.4799, Val Accuracy: 0.7674
Epoch 11/150, Loss: 0.4098, Accuracy: 0.8158, Val Loss: 0.5109, Val Accuracy: 0.7535
Early stopping at epoch 11
Total parameters: 20960706
FLOPs: 20103296
Fold



Epoch 1/150, Loss: 0.5112, Accuracy: 0.7554, Val Loss: 0.5288, Val Accuracy: 0.7361
Epoch 2/150, Loss: 0.4666, Accuracy: 0.7867, Val Loss: 0.4698, Val Accuracy: 0.7674
Epoch 3/150, Loss: 0.4933, Accuracy: 0.7615, Val Loss: 0.4820, Val Accuracy: 0.7465
Epoch 4/150, Loss: 0.4590, Accuracy: 0.7897, Val Loss: 0.5095, Val Accuracy: 0.7396
Epoch 5/150, Loss: 0.4617, Accuracy: 0.7867, Val Loss: 0.5169, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.4497, Accuracy: 0.7997, Val Loss: 0.5032, Val Accuracy: 0.7326
Epoch 7/150, Loss: 0.4438, Accuracy: 0.8015, Val Loss: 0.5902, Val Accuracy: 0.7361
Epoch 8/150, Loss: 0.4615, Accuracy: 0.7871, Val Loss: 0.5244, Val Accuracy: 0.7326
Epoch 9/150, Loss: 0.4269, Accuracy: 0.8054, Val Loss: 0.5283, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.3990, Accuracy: 0.8284, Val Loss: 0.5085, Val Accuracy: 0.7639
Epoch 11/150, Loss: 0.3746, Accuracy: 0.8414, Val Loss: 0.5098, Val Accuracy: 0.7674
Epoch 12/150, Loss: 0.3772, Accuracy: 0.8328, Val Loss: 0.5274, Val Accura



Epoch 1/150, Loss: 0.5105, Accuracy: 0.7459, Val Loss: 0.5155, Val Accuracy: 0.7431
Epoch 2/150, Loss: 0.4807, Accuracy: 0.7663, Val Loss: 0.5276, Val Accuracy: 0.7188
Epoch 3/150, Loss: 0.4652, Accuracy: 0.7754, Val Loss: 0.6094, Val Accuracy: 0.6597
Epoch 4/150, Loss: 0.4640, Accuracy: 0.7745, Val Loss: 0.4917, Val Accuracy: 0.7743
Epoch 5/150, Loss: 0.4818, Accuracy: 0.7606, Val Loss: 0.5044, Val Accuracy: 0.7153
Epoch 6/150, Loss: 0.4791, Accuracy: 0.7663, Val Loss: 0.5746, Val Accuracy: 0.6319
Epoch 7/150, Loss: 0.4801, Accuracy: 0.7728, Val Loss: 0.4913, Val Accuracy: 0.7292
Epoch 8/150, Loss: 0.4674, Accuracy: 0.7793, Val Loss: 0.5049, Val Accuracy: 0.7292
Epoch 9/150, Loss: 0.4782, Accuracy: 0.7672, Val Loss: 0.5127, Val Accuracy: 0.7257
Epoch 10/150, Loss: 0.4663, Accuracy: 0.7867, Val Loss: 0.5346, Val Accuracy: 0.7396
Epoch 11/150, Loss: 0.4610, Accuracy: 0.7824, Val Loss: 0.4987, Val Accuracy: 0.7222
Epoch 12/150, Loss: 0.4654, Accuracy: 0.7871, Val Loss: 0.5238, Val Accura



Epoch 1/150, Loss: 0.5367, Accuracy: 0.7354, Val Loss: 0.4876, Val Accuracy: 0.7396
Epoch 2/150, Loss: 0.5222, Accuracy: 0.7515, Val Loss: 0.5121, Val Accuracy: 0.7500
Epoch 3/150, Loss: 0.5227, Accuracy: 0.7567, Val Loss: 0.5357, Val Accuracy: 0.7188
Epoch 4/150, Loss: 0.5068, Accuracy: 0.7567, Val Loss: 0.5185, Val Accuracy: 0.7604
Epoch 5/150, Loss: 0.5031, Accuracy: 0.7798, Val Loss: 0.5222, Val Accuracy: 0.7639
Epoch 6/150, Loss: 0.5073, Accuracy: 0.7728, Val Loss: 0.4984, Val Accuracy: 0.7535
Epoch 7/150, Loss: 0.5011, Accuracy: 0.7819, Val Loss: 0.5308, Val Accuracy: 0.7431
Epoch 8/150, Loss: 0.4900, Accuracy: 0.7806, Val Loss: 0.5052, Val Accuracy: 0.7604
Epoch 9/150, Loss: 0.4551, Accuracy: 0.8050, Val Loss: 0.4917, Val Accuracy: 0.7639
Epoch 10/150, Loss: 0.4774, Accuracy: 0.7884, Val Loss: 0.4950, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4592, Accuracy: 0.8006, Val Loss: 0.4970, Val Accuracy: 0.7639
Early stopping at epoch 11
Total parameters: 20960706
FLOPs: 20103296
Fold



Epoch 1/150, Loss: 0.5463, Accuracy: 0.7263, Val Loss: 0.4994, Val Accuracy: 0.7674
Epoch 2/150, Loss: 0.5309, Accuracy: 0.7394, Val Loss: 0.5022, Val Accuracy: 0.7292
Epoch 3/150, Loss: 0.5044, Accuracy: 0.7554, Val Loss: 0.5492, Val Accuracy: 0.7153
Epoch 4/150, Loss: 0.4995, Accuracy: 0.7606, Val Loss: 0.5184, Val Accuracy: 0.7500
Epoch 5/150, Loss: 0.4945, Accuracy: 0.7680, Val Loss: 0.5606, Val Accuracy: 0.7257
Epoch 6/150, Loss: 0.5045, Accuracy: 0.7632, Val Loss: 0.5702, Val Accuracy: 0.7292
Epoch 7/150, Loss: 0.5178, Accuracy: 0.7576, Val Loss: 0.5741, Val Accuracy: 0.7222
Epoch 8/150, Loss: 0.4472, Accuracy: 0.7889, Val Loss: 0.5425, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.4553, Accuracy: 0.8010, Val Loss: 0.4941, Val Accuracy: 0.7743
Epoch 10/150, Loss: 0.4355, Accuracy: 0.8050, Val Loss: 0.5327, Val Accuracy: 0.7465
Epoch 11/150, Loss: 0.4010, Accuracy: 0.8297, Val Loss: 0.4917, Val Accuracy: 0.7639
Epoch 12/150, Loss: 0.3926, Accuracy: 0.8375, Val Loss: 0.5312, Val Accura



Epoch 1/150, Loss: 0.5265, Accuracy: 0.7354, Val Loss: 0.5668, Val Accuracy: 0.6910
Epoch 2/150, Loss: 0.5135, Accuracy: 0.7541, Val Loss: 0.5545, Val Accuracy: 0.6736
Epoch 3/150, Loss: 0.5232, Accuracy: 0.7485, Val Loss: 0.4957, Val Accuracy: 0.7708
Epoch 4/150, Loss: 0.5062, Accuracy: 0.7659, Val Loss: 0.5008, Val Accuracy: 0.7708
Epoch 5/150, Loss: 0.4695, Accuracy: 0.7967, Val Loss: 0.5752, Val Accuracy: 0.7431
Epoch 6/150, Loss: 0.5079, Accuracy: 0.7728, Val Loss: 0.5231, Val Accuracy: 0.7674
Epoch 7/150, Loss: 0.4968, Accuracy: 0.7654, Val Loss: 0.5069, Val Accuracy: 0.7535
Epoch 8/150, Loss: 0.4834, Accuracy: 0.7815, Val Loss: 0.5703, Val Accuracy: 0.7431
Epoch 9/150, Loss: 0.4974, Accuracy: 0.7798, Val Loss: 0.5246, Val Accuracy: 0.7639
Epoch 10/150, Loss: 0.4682, Accuracy: 0.7945, Val Loss: 0.4956, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4598, Accuracy: 0.7871, Val Loss: 0.5131, Val Accuracy: 0.7535
Epoch 12/150, Loss: 0.4303, Accuracy: 0.8084, Val Loss: 0.5395, Val Accura



Epoch 1/150, Loss: 0.5542, Accuracy: 0.7029, Val Loss: 0.4831, Val Accuracy: 0.7604
Epoch 2/150, Loss: 0.5312, Accuracy: 0.7428, Val Loss: 0.5105, Val Accuracy: 0.7674
Epoch 3/150, Loss: 0.5323, Accuracy: 0.7337, Val Loss: 0.5231, Val Accuracy: 0.7118
Epoch 4/150, Loss: 0.5076, Accuracy: 0.7493, Val Loss: 0.5349, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.4967, Accuracy: 0.7585, Val Loss: 0.4933, Val Accuracy: 0.7639
Epoch 6/150, Loss: 0.5240, Accuracy: 0.7472, Val Loss: 0.5222, Val Accuracy: 0.7361
Epoch 7/150, Loss: 0.5345, Accuracy: 0.7289, Val Loss: 0.5353, Val Accuracy: 0.7118
Epoch 8/150, Loss: 0.5006, Accuracy: 0.7624, Val Loss: 0.5124, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.5029, Accuracy: 0.7637, Val Loss: 0.5148, Val Accuracy: 0.7465
Epoch 10/150, Loss: 0.5020, Accuracy: 0.7585, Val Loss: 0.5123, Val Accuracy: 0.7535
Epoch 11/150, Loss: 0.4690, Accuracy: 0.7806, Val Loss: 0.4996, Val Accuracy: 0.7639
Early stopping at epoch 11
Total parameters: 20960706
FLOPs: 20103296




Fold 2, Model Loss: 0.4690
Epoch 1/150, Loss: 0.5578, Accuracy: 0.7233, Val Loss: 0.5049, Val Accuracy: 0.7569
Epoch 2/150, Loss: 0.5291, Accuracy: 0.7394, Val Loss: 0.5017, Val Accuracy: 0.7674
Epoch 3/150, Loss: 0.5324, Accuracy: 0.7315, Val Loss: 0.4969, Val Accuracy: 0.7708
Epoch 4/150, Loss: 0.5351, Accuracy: 0.7333, Val Loss: 0.4938, Val Accuracy: 0.7743
Epoch 5/150, Loss: 0.5229, Accuracy: 0.7376, Val Loss: 0.5386, Val Accuracy: 0.7361
Epoch 6/150, Loss: 0.5327, Accuracy: 0.7450, Val Loss: 0.5181, Val Accuracy: 0.7396
Epoch 7/150, Loss: 0.5262, Accuracy: 0.7511, Val Loss: 0.5353, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.5243, Accuracy: 0.7593, Val Loss: 0.5268, Val Accuracy: 0.7674
Epoch 9/150, Loss: 0.5782, Accuracy: 0.6577, Val Loss: 0.6048, Val Accuracy: 0.5972
Epoch 10/150, Loss: 0.5670, Accuracy: 0.6937, Val Loss: 0.5291, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.5478, Accuracy: 0.7085, Val Loss: 0.5036, Val Accuracy: 0.7118
Epoch 12/150, Loss: 0.5145, Accuracy: 0.7606, V



Epoch 1/150, Loss: 0.5210, Accuracy: 0.7326, Val Loss: 0.4867, Val Accuracy: 0.7465
Epoch 2/150, Loss: 0.5030, Accuracy: 0.7426, Val Loss: 0.5213, Val Accuracy: 0.7674
Epoch 3/150, Loss: 0.4876, Accuracy: 0.7543, Val Loss: 0.4763, Val Accuracy: 0.7674
Epoch 4/150, Loss: 0.4809, Accuracy: 0.7535, Val Loss: 0.4698, Val Accuracy: 0.7604
Epoch 5/150, Loss: 0.4928, Accuracy: 0.7457, Val Loss: 0.4587, Val Accuracy: 0.7951
Epoch 6/150, Loss: 0.4870, Accuracy: 0.7439, Val Loss: 0.5169, Val Accuracy: 0.7153
Epoch 7/150, Loss: 0.4686, Accuracy: 0.7565, Val Loss: 0.4512, Val Accuracy: 0.7708
Epoch 8/150, Loss: 0.4895, Accuracy: 0.7431, Val Loss: 0.4692, Val Accuracy: 0.7639
Epoch 9/150, Loss: 0.4644, Accuracy: 0.7626, Val Loss: 0.5169, Val Accuracy: 0.7326
Epoch 10/150, Loss: 0.4687, Accuracy: 0.7622, Val Loss: 0.4808, Val Accuracy: 0.7500
Epoch 11/150, Loss: 0.4774, Accuracy: 0.7522, Val Loss: 0.4561, Val Accuracy: 0.7743
Epoch 12/150, Loss: 0.4865, Accuracy: 0.7487, Val Loss: 0.4758, Val Accura



Epoch 1/150, Loss: 0.5099, Accuracy: 0.7396, Val Loss: 0.4725, Val Accuracy: 0.7431
Epoch 2/150, Loss: 0.4989, Accuracy: 0.7344, Val Loss: 0.4668, Val Accuracy: 0.7847
Epoch 3/150, Loss: 0.4738, Accuracy: 0.7687, Val Loss: 0.4803, Val Accuracy: 0.7535
Epoch 4/150, Loss: 0.4900, Accuracy: 0.7461, Val Loss: 0.5237, Val Accuracy: 0.7361
Epoch 5/150, Loss: 0.4753, Accuracy: 0.7695, Val Loss: 0.4340, Val Accuracy: 0.7882
Epoch 6/150, Loss: 0.4651, Accuracy: 0.7687, Val Loss: 0.4630, Val Accuracy: 0.7708
Epoch 7/150, Loss: 0.4694, Accuracy: 0.7556, Val Loss: 0.4878, Val Accuracy: 0.7500
Epoch 8/150, Loss: 0.4679, Accuracy: 0.7591, Val Loss: 0.5231, Val Accuracy: 0.7361
Epoch 9/150, Loss: 0.4634, Accuracy: 0.7717, Val Loss: 0.4585, Val Accuracy: 0.7708
Epoch 10/150, Loss: 0.4552, Accuracy: 0.7743, Val Loss: 0.5218, Val Accuracy: 0.7083
Epoch 11/150, Loss: 0.4655, Accuracy: 0.7626, Val Loss: 0.4806, Val Accuracy: 0.7326
Epoch 12/150, Loss: 0.4317, Accuracy: 0.8012, Val Loss: 0.4891, Val Accura



Epoch 1/150, Loss: 0.5409, Accuracy: 0.7040, Val Loss: 0.4992, Val Accuracy: 0.7257
Epoch 2/150, Loss: 0.4917, Accuracy: 0.7418, Val Loss: 0.4622, Val Accuracy: 0.7708
Epoch 3/150, Loss: 0.4946, Accuracy: 0.7322, Val Loss: 0.4881, Val Accuracy: 0.7535
Epoch 4/150, Loss: 0.4808, Accuracy: 0.7539, Val Loss: 0.4529, Val Accuracy: 0.7569
Epoch 5/150, Loss: 0.4849, Accuracy: 0.7535, Val Loss: 0.4776, Val Accuracy: 0.7674
Epoch 6/150, Loss: 0.4996, Accuracy: 0.7387, Val Loss: 0.4794, Val Accuracy: 0.7951
Epoch 7/150, Loss: 0.4977, Accuracy: 0.7396, Val Loss: 0.4490, Val Accuracy: 0.7917
Epoch 8/150, Loss: 0.4901, Accuracy: 0.7630, Val Loss: 0.4664, Val Accuracy: 0.7986
Epoch 9/150, Loss: 0.4903, Accuracy: 0.7539, Val Loss: 0.4993, Val Accuracy: 0.7743
Epoch 10/150, Loss: 0.4886, Accuracy: 0.7539, Val Loss: 0.4722, Val Accuracy: 0.7604
Epoch 11/150, Loss: 0.4784, Accuracy: 0.7465, Val Loss: 0.4281, Val Accuracy: 0.7917
Epoch 12/150, Loss: 0.4905, Accuracy: 0.7517, Val Loss: 0.4846, Val Accura



Epoch 1/150, Loss: 0.5223, Accuracy: 0.7322, Val Loss: 0.4824, Val Accuracy: 0.7604
Epoch 2/150, Loss: 0.5087, Accuracy: 0.7274, Val Loss: 0.4815, Val Accuracy: 0.7431
Epoch 3/150, Loss: 0.4909, Accuracy: 0.7400, Val Loss: 0.4574, Val Accuracy: 0.7951
Epoch 4/150, Loss: 0.4743, Accuracy: 0.7635, Val Loss: 0.4812, Val Accuracy: 0.7465
Epoch 5/150, Loss: 0.4771, Accuracy: 0.7539, Val Loss: 0.4718, Val Accuracy: 0.7500
Epoch 6/150, Loss: 0.4674, Accuracy: 0.7496, Val Loss: 0.4880, Val Accuracy: 0.7361
Epoch 7/150, Loss: 0.4753, Accuracy: 0.7517, Val Loss: 0.4809, Val Accuracy: 0.7326
Epoch 8/150, Loss: 0.4562, Accuracy: 0.7630, Val Loss: 0.4652, Val Accuracy: 0.7743
Epoch 9/150, Loss: 0.4583, Accuracy: 0.7595, Val Loss: 0.4925, Val Accuracy: 0.7431
Epoch 10/150, Loss: 0.4673, Accuracy: 0.7656, Val Loss: 0.4641, Val Accuracy: 0.7431
Epoch 11/150, Loss: 0.4424, Accuracy: 0.7808, Val Loss: 0.4824, Val Accuracy: 0.7396
Epoch 12/150, Loss: 0.4192, Accuracy: 0.7977, Val Loss: 0.4891, Val Accura



FLOPs: 20103296
Fold 3, Model Loss: 0.4408
Epoch 1/150, Loss: 0.5282, Accuracy: 0.7279, Val Loss: 0.4665, Val Accuracy: 0.7778
Epoch 2/150, Loss: 0.4976, Accuracy: 0.7422, Val Loss: 0.4860, Val Accuracy: 0.7465
Epoch 3/150, Loss: 0.4952, Accuracy: 0.7283, Val Loss: 0.4843, Val Accuracy: 0.7153
Epoch 4/150, Loss: 0.4851, Accuracy: 0.7513, Val Loss: 0.5246, Val Accuracy: 0.7153
Epoch 5/150, Loss: 0.5075, Accuracy: 0.7318, Val Loss: 0.4690, Val Accuracy: 0.7431
Epoch 6/150, Loss: 0.4967, Accuracy: 0.7426, Val Loss: 0.4949, Val Accuracy: 0.7500
Epoch 7/150, Loss: 0.4814, Accuracy: 0.7561, Val Loss: 0.4903, Val Accuracy: 0.7569
Epoch 8/150, Loss: 0.4674, Accuracy: 0.7674, Val Loss: 0.4559, Val Accuracy: 0.7812
Epoch 9/150, Loss: 0.4561, Accuracy: 0.7708, Val Loss: 0.4533, Val Accuracy: 0.7812
Epoch 10/150, Loss: 0.4592, Accuracy: 0.7652, Val Loss: 0.4714, Val Accuracy: 0.7674
Epoch 11/150, Loss: 0.4531, Accuracy: 0.7695, Val Loss: 0.4502, Val Accuracy: 0.7951
Epoch 12/150, Loss: 0.4433, Acc



Epoch 1/150, Loss: 0.5868, Accuracy: 0.6903, Val Loss: 0.6128, Val Accuracy: 0.6528
Epoch 2/150, Loss: 0.5552, Accuracy: 0.7050, Val Loss: 0.5928, Val Accuracy: 0.6840
Epoch 3/150, Loss: 0.5627, Accuracy: 0.7176, Val Loss: 0.5946, Val Accuracy: 0.6806
Epoch 4/150, Loss: 0.5660, Accuracy: 0.7046, Val Loss: 0.6365, Val Accuracy: 0.6597
Epoch 5/150, Loss: 0.5633, Accuracy: 0.7107, Val Loss: 0.6151, Val Accuracy: 0.6076
Epoch 6/150, Loss: 0.5766, Accuracy: 0.6725, Val Loss: 0.6086, Val Accuracy: 0.6910
Epoch 7/150, Loss: 0.5782, Accuracy: 0.6707, Val Loss: 0.6083, Val Accuracy: 0.6562
Epoch 8/150, Loss: 0.5573, Accuracy: 0.6977, Val Loss: 0.5852, Val Accuracy: 0.7014
Epoch 9/150, Loss: 0.5475, Accuracy: 0.7137, Val Loss: 0.5881, Val Accuracy: 0.7049
Epoch 10/150, Loss: 0.5642, Accuracy: 0.7011, Val Loss: 0.6218, Val Accuracy: 0.6667
Epoch 11/150, Loss: 0.5606, Accuracy: 0.7042, Val Loss: 0.6032, Val Accuracy: 0.6910
Epoch 12/150, Loss: 0.5631, Accuracy: 0.6903, Val Loss: 0.6130, Val Accura



Epoch 1/150, Loss: 0.5684, Accuracy: 0.6946, Val Loss: 0.6110, Val Accuracy: 0.6701
Epoch 2/150, Loss: 0.5691, Accuracy: 0.6972, Val Loss: 0.5972, Val Accuracy: 0.6562
Epoch 3/150, Loss: 0.5616, Accuracy: 0.7107, Val Loss: 0.5556, Val Accuracy: 0.7118
Epoch 4/150, Loss: 0.5646, Accuracy: 0.6968, Val Loss: 0.6002, Val Accuracy: 0.6736
Epoch 5/150, Loss: 0.5635, Accuracy: 0.7011, Val Loss: 0.5693, Val Accuracy: 0.7118
Epoch 6/150, Loss: 0.5570, Accuracy: 0.7142, Val Loss: 0.5536, Val Accuracy: 0.7083
Epoch 7/150, Loss: 0.5643, Accuracy: 0.6994, Val Loss: 0.5821, Val Accuracy: 0.6944
Epoch 8/150, Loss: 0.5509, Accuracy: 0.7142, Val Loss: 0.5762, Val Accuracy: 0.7222
Epoch 9/150, Loss: 0.5338, Accuracy: 0.7259, Val Loss: 0.5812, Val Accuracy: 0.6979
Epoch 10/150, Loss: 0.5278, Accuracy: 0.7389, Val Loss: 0.6071, Val Accuracy: 0.6736
Epoch 11/150, Loss: 0.5382, Accuracy: 0.7250, Val Loss: 0.5704, Val Accuracy: 0.7083
Epoch 12/150, Loss: 0.5369, Accuracy: 0.7433, Val Loss: 0.6639, Val Accura



Epoch 1/150, Loss: 0.5914, Accuracy: 0.6746, Val Loss: 0.5836, Val Accuracy: 0.7118
Epoch 2/150, Loss: 0.5715, Accuracy: 0.7103, Val Loss: 0.5982, Val Accuracy: 0.6771
Epoch 3/150, Loss: 0.5552, Accuracy: 0.7172, Val Loss: 0.6332, Val Accuracy: 0.6562
Epoch 4/150, Loss: 0.5516, Accuracy: 0.7237, Val Loss: 0.5769, Val Accuracy: 0.7014
Epoch 5/150, Loss: 0.5391, Accuracy: 0.7289, Val Loss: 0.5897, Val Accuracy: 0.6979
Epoch 6/150, Loss: 0.5596, Accuracy: 0.7189, Val Loss: 0.6673, Val Accuracy: 0.6528
Epoch 7/150, Loss: 0.5591, Accuracy: 0.7133, Val Loss: 0.6364, Val Accuracy: 0.6493
Epoch 8/150, Loss: 0.5498, Accuracy: 0.7159, Val Loss: 0.6033, Val Accuracy: 0.6840
Epoch 9/150, Loss: 0.5524, Accuracy: 0.7050, Val Loss: 0.6242, Val Accuracy: 0.5972
Epoch 10/150, Loss: 0.5460, Accuracy: 0.7168, Val Loss: 0.6136, Val Accuracy: 0.6597
Epoch 11/150, Loss: 0.5167, Accuracy: 0.7428, Val Loss: 0.5988, Val Accuracy: 0.6771
Epoch 12/150, Loss: 0.5118, Accuracy: 0.7459, Val Loss: 0.6288, Val Accura



Epoch 1/150, Loss: 0.5910, Accuracy: 0.6751, Val Loss: 0.5941, Val Accuracy: 0.6875
Epoch 2/150, Loss: 0.5630, Accuracy: 0.6985, Val Loss: 0.5850, Val Accuracy: 0.7083
Epoch 3/150, Loss: 0.5467, Accuracy: 0.7237, Val Loss: 0.6223, Val Accuracy: 0.6771
Epoch 4/150, Loss: 0.5761, Accuracy: 0.6742, Val Loss: 0.6278, Val Accuracy: 0.6389
Epoch 5/150, Loss: 0.5644, Accuracy: 0.6998, Val Loss: 0.5679, Val Accuracy: 0.7188
Epoch 6/150, Loss: 0.5547, Accuracy: 0.7046, Val Loss: 0.6323, Val Accuracy: 0.7014
Epoch 7/150, Loss: 0.5589, Accuracy: 0.7259, Val Loss: 0.5799, Val Accuracy: 0.7014
Epoch 8/150, Loss: 0.5370, Accuracy: 0.7276, Val Loss: 0.6184, Val Accuracy: 0.6632
Epoch 9/150, Loss: 0.5544, Accuracy: 0.6994, Val Loss: 0.5834, Val Accuracy: 0.6840
Epoch 10/150, Loss: 0.5730, Accuracy: 0.6829, Val Loss: 0.6134, Val Accuracy: 0.6597
Epoch 11/150, Loss: 0.5750, Accuracy: 0.6903, Val Loss: 0.5943, Val Accuracy: 0.6910
Epoch 12/150, Loss: 0.5287, Accuracy: 0.7350, Val Loss: 0.5769, Val Accura



FLOPs: 20103296
Fold 4, Model Loss: 0.4894
Epoch 1/150, Loss: 0.5907, Accuracy: 0.6720, Val Loss: 0.5972, Val Accuracy: 0.6701
Epoch 2/150, Loss: 0.5644, Accuracy: 0.6972, Val Loss: 0.5969, Val Accuracy: 0.6806
Epoch 3/150, Loss: 0.5621, Accuracy: 0.7063, Val Loss: 0.6081, Val Accuracy: 0.6528
Epoch 4/150, Loss: 0.5396, Accuracy: 0.7255, Val Loss: 0.5943, Val Accuracy: 0.7222
Epoch 5/150, Loss: 0.5522, Accuracy: 0.7211, Val Loss: 0.5731, Val Accuracy: 0.7049
Epoch 6/150, Loss: 0.5453, Accuracy: 0.7176, Val Loss: 0.5896, Val Accuracy: 0.6910
Epoch 7/150, Loss: 0.5401, Accuracy: 0.7185, Val Loss: 0.5830, Val Accuracy: 0.7118
Epoch 8/150, Loss: 0.5294, Accuracy: 0.7307, Val Loss: 0.5791, Val Accuracy: 0.7049
Epoch 9/150, Loss: 0.5517, Accuracy: 0.7189, Val Loss: 0.5793, Val Accuracy: 0.7153
Epoch 10/150, Loss: 0.5783, Accuracy: 0.6855, Val Loss: 0.5784, Val Accuracy: 0.7222
Epoch 11/150, Loss: 0.5617, Accuracy: 0.7055, Val Loss: 0.6136, Val Accuracy: 0.6597
Epoch 12/150, Loss: 0.5401, Acc



Mean Accuracy: 70.71%, (SD=0.09926880481726288)
Mean Loss: 0.6857, (SD=0.16867505020405527)
Mean Precision: 0.7110, (SD=0.09858764374039526)
Mean Recall: 0.7071, (SD=0.09926880481726288)
Mean F1 Score: 0.6972, (SD=0.10775521000269046)
