<a href="https://colab.research.google.com/github/Rito43/Rito43/blob/main/M23CSA021_Ass_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code is useful for data loading, visualization, and exploration. You are free to modify the code. The code depends on the PyTorch Lightning data module; however, you may use plain PyTorch as well.

In [None]:
from google.colab import drive

# Unmount Google Drive
# (run before the mount cell below, which mounts the drive again with force_remount=True)
drive.flush_and_unmount()

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; every dataset path used below lives under this mount point.
drive.mount('/content/drive',force_remount=True)

**Introduction to Dataset**

The dataset has a total of 10 classes with 40 samples each. While working with the data, make sure that **esc10=True**. In the assignment, you are required to perform 4-fold validation. The dataset has already been divided into 5 folds; the column 'fold' in the metafile denotes the fold each sample belongs to. The first fold is reserved for testing, and the remaining four folds are used for 4-fold validation.

In [None]:
# DL Assignment 2
# Authors: Kopal Rastogi, Ishan Mishra
# Keywords: None
# Assumptions: None

In [None]:
# Installing the requirements
# NOTE: `!pip` is a notebook shell escape, so this cell only runs inside IPython/Colab.
# NOTE(review): this installs `lightning`, while the code below imports `pytorch_lightning`
# (installed separately in a later cell) — confirm which package is intended.
print('Installing Requirements... ',end='')
!pip install lightning
print('Done')

In [None]:
# Extract data
#with zipfile.ZipFile("/content/master.zip", 'r') as zip_ref:
    #zip_ref.extractall("/content/")

In [None]:
# Importing Libraries
print('Importing Libraries... ',end='')
import os
from pathlib import Path
import pandas as pd
import torchaudio
import zipfile
from torchaudio.transforms import Resample
import IPython.display as ipd
from matplotlib import pyplot as plt
from tqdm import tqdm
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import torch
print('Done')

In [None]:
# Download data
# Placeholder cell: no download is performed here; the cells below read the
# archive from the mounted Google Drive instead.
print('Downloading data... ', end='')
# Your code here
print('Done')

In [None]:
# Archive to unpack and the directory it is unpacked into (both on the mounted Drive).
zip_path = '/content/drive/MyDrive/Archive.zip'
extract_path = '/content/drive/MyDrive/'

In [None]:
import os

# Sanity check: report whether the dataset archive is present on the mounted
# Drive before the extraction cell tries to open it.
drive_path = '/content/drive/MyDrive/'
file_name = 'Archive.zip'
file_path = os.path.join(drive_path, file_name)

_archive_status = "File exists at:" if os.path.exists(file_path) else "File not found:"
print(_archive_status, file_path)

In [None]:
# Extract data
import zipfile
# Unpack every member of the archive into extract_path; the `with` block
# closes the zip file once extraction finishes.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

In [None]:
# Loading dataset
# `path` is the dataset root (audio files are looked up under path / "audio");
# `df` is the metadata table read from meta/esc50.csv under the same root.
path = Path('/content/drive/MyDrive/')
df = pd.read_csv('/content/drive/MyDrive/meta/esc50.csv')

In [None]:
# Getting list of raw audio files
# Sorted so that the example shown below is the same file on every run
# (the order in which glob yields entries is not guaranteed).
wavs = sorted(path.glob('audio/*'))
if not wavs:
    # Fail with a clear message instead of an IndexError on wavs[0].
    raise FileNotFoundError(f"No audio files found under {path / 'audio'}")

# Visualizing data: load the first clip as a waveform tensor plus its sample rate
waveform, sample_rate = torchaudio.load(wavs[0])

print("Shape of waveform: {}".format(waveform.size()))
print("Sample rate of waveform: {}".format(sample_rate))

# Plot the waveform; the tensor is transposed before plotting so that
# matplotlib draws one line per row of the original tensor.
plt.figure()
plt.plot(waveform.t().numpy())

# Interactive audio player. This must stay the last expression of the cell
# for the notebook to render it.
ipd.Audio(waveform, rate=sample_rate)


In [None]:
class CustomDataset(Dataset):
    """Audio clips split into overlapping fixed-length windows, one label per clip.

    Args:
        dataset: Which split to build: "train", "val" or "test". Any other
            value leaves the (optionally ESC-10 filtered) table unsplit.
        **kwargs: Configuration. The keys read here are ``data_directory``,
            ``data_frame``, ``validation_fold``, ``testing_fold``,
            ``esc_10_flag``, ``file_column``, ``label_column``,
            ``sampling_rate``, ``new_sampling_rate`` and
            ``sample_length_seconds``. Extra keys are ignored.

    Attributes set by ``__init__``:
        categories: Sorted category names.
        category_to_index / index_to_category: Label encodings. They are built
            before the split is selected, so every split uses the same
            category -> index assignment.
        file_names / labels: Parallel lists of audio paths and integer labels.
    """

    def __init__(self, dataset, **kwargs):
        # Extract parameters from kwargs
        self.data_directory = kwargs["data_directory"]
        self.data_frame = kwargs["data_frame"]
        self.validation_fold = kwargs["validation_fold"]
        self.testing_fold = kwargs["testing_fold"]
        self.esc_10_flag = kwargs["esc_10_flag"]
        self.file_column = kwargs["file_column"]
        self.label_column = kwargs["label_column"]
        self.sampling_rate = kwargs["sampling_rate"]
        self.new_sampling_rate = kwargs["new_sampling_rate"]
        self.sample_length_seconds = kwargs["sample_length_seconds"]

        # Keep only the rows flagged as ESC-10 when requested.
        if self.esc_10_flag:
            self.data_frame = self.data_frame.loc[self.data_frame['esc10'].eq(True)]

        # Build the label encoding BEFORE choosing a split. Deriving it from a
        # single split would give different indices to the same category
        # whenever a split happens to miss one of the classes.
        self.categories = sorted(self.data_frame[self.label_column].unique())
        self.category_to_index = {
            category: i for i, category in enumerate(self.categories)
        }
        self.index_to_category = dict(enumerate(self.categories))

        # Select the rows that belong to the requested split.
        fold_column = self.data_frame['fold']
        if dataset == "train":
            self.data_frame = self.data_frame.loc[
                (fold_column != self.validation_fold) & (fold_column != self.testing_fold)]
        elif dataset == "val":
            self.data_frame = self.data_frame.loc[fold_column == self.validation_fold]
        elif dataset == "test":
            self.data_frame = self.data_frame.loc[fold_column == self.testing_fold]

        # Collect the audio path and the integer label of every selected row.
        self.file_names = []
        self.labels = []
        rows = zip(self.data_frame[self.file_column], self.data_frame[self.label_column])
        for file_name, category in tqdm(rows, total=len(self.data_frame)):
            self.file_names.append(self.data_directory / "audio" / file_name)
            self.labels.append(self.category_to_index[category])

        self.resampler = torchaudio.transforms.Resample(self.sampling_rate, self.new_sampling_rate)

        # Window and hop sizes (in samples) for the rolling-window split:
        # 2 s windows with a 0.75 s hop, otherwise 1 s windows with a 0.5 s hop.
        if self.sample_length_seconds == 2:
            self.window_size = self.new_sampling_rate * 2
            self.step_size = int(self.new_sampling_rate * 0.75)
        else:
            self.window_size = self.new_sampling_rate
            self.step_size = int(self.new_sampling_rate * 0.5)

    def __getitem__(self, index):
        """Return ``(samples, label)`` for the clip at ``index``.

        The clip is loaded, resampled and cut into overlapping windows along
        dimension 1. The window dimension is then moved to the front, so
        ``samples`` is laid out as (windows, dim 0 of the loaded tensor,
        window_size).
        """
        path = self.file_names[index]
        audio_file = torchaudio.load(path, format=None, normalize=True)
        # audio_file is (waveform, sample_rate); only the waveform is used.
        audio_tensor = self.resampler(audio_file[0])
        splits = audio_tensor.unfold(1, self.window_size, self.step_size)
        samples = splits.permute(1, 0, 2)
        return samples, self.labels[index]

    def __len__(self):
        """Return the number of clips in this split."""
        return len(self.file_names)


In [None]:
class CustomDataModule(pl.LightningDataModule):
    """LightningDataModule that builds train/val/test ``CustomDataset`` splits.

    Args:
        **kwargs: Must contain ``batch_size`` and ``num_workers``. The whole
            dict is also forwarded unchanged to ``CustomDataset``, which reads
            the dataset-specific keys from it.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.batch_size = kwargs["batch_size"]
        self.num_workers = kwargs["num_workers"]
        self.data_module_kwargs = kwargs

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test', or None for all)."""
        # If in 'fit' or None stage, create training and validation datasets
        if stage == 'fit' or stage is None:
            self.training_dataset = CustomDataset(dataset="train", **self.data_module_kwargs)
            self.validation_dataset = CustomDataset(dataset="val", **self.data_module_kwargs)

        # If in 'test' or None stage, create testing dataset
        if stage == 'test' or stage is None:
            self.testing_dataset = CustomDataset(dataset="test", **self.data_module_kwargs)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size, workers and collate function."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          collate_fn=self.collate_function,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled DataLoader for the training dataset."""
        return self._make_loader(self.training_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the DataLoader for the validation dataset (not shuffled)."""
        return self._make_loader(self.validation_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the DataLoader for the testing dataset (not shuffled).

        Uses the configured ``batch_size`` like the other loaders, instead of
        a hard-coded 32 that ignored the constructor argument.
        """
        return self._make_loader(self.testing_dataset, shuffle=False)

    def collate_function(self, data):
        """Merge a list of ``(samples, label)`` pairs into one batch.

        Args:
            data: Sequence of ``(samples, label)`` pairs as returned by
                ``CustomDataset.__getitem__``. All ``samples`` tensors must
                have the same shape, because they are stacked.

        Returns:
            ``[examples, labels]`` where ``examples`` has shape
            (batch, 1, -1), i.e. all windows of a clip are flattened into one
            single-channel sequence, and ``labels`` is a 1-D tensor with one
            entry per clip.
        """
        examples, labels = zip(*data)
        examples = torch.stack(examples)
        examples = examples.reshape(examples.size(0), 1, -1)
        labels = torch.flatten(torch.tensor(labels))

        return [examples, labels]


In [None]:
# Data Setup
# Builds the data module used by the cells below. The fold numbers refer to
# the 'fold' column of the metadata table: `test_samp` is the held-out test
# fold and `valid_samp` the fold used for validation; all remaining folds are
# used for training.

test_samp = 1  # Do not change this!!
valid_samp = 2 # Use any value ranging from 2 to 5 for k-fold validation (valid_fold)
batch_size = 32 # Free to change
num_workers = 2 # Free to change
custom_data_module = CustomDataModule(batch_size=batch_size,
                                      num_workers=num_workers,
                                      data_directory=path,
                                      data_frame=df,
                                      validation_fold=valid_samp,
                                      testing_fold=test_samp,  # set to 0 for no test set
                                      esc_10_flag=True,
                                      file_column='filename',
                                      label_column='category',
                                      sampling_rate=44100,
                                      new_sampling_rate=16000,  # new sample rate for input
                                      sample_length_seconds=1  # new length of input in seconds
                                      )

# Called without a stage, so the train, validation and test datasets are all created.
custom_data_module.setup()


In [None]:
# Data Exploration: inspect the first training example (windowed audio + label).
first_samples, first_label = custom_data_module.training_dataset[0]
print('Class Label: ', first_label)  # integer class index
print('Shape of data sample tensor: ', first_samples.shape)  # (Frames, Channel, Features)


In [None]:
# Dataloader(s): pull the first batch from every split and print it.
x = next(iter(custom_data_module.train_dataloader()))
y = next(iter(custom_data_module.val_dataloader()))
z = next(iter(custom_data_module.test_dataloader()))
for _title, _first_batch in (('Train Dataloader:', x),
                             ('Validation Dataloader:', y),
                             ('Test Dataloader:', z)):
    print(_title)
    print(_first_batch)


In [None]:
# Install the Weights & Biases client (imported as `wandb` in the next code cell).
!pip install wandb

In [None]:
# NOTE(review): `pytorch_lightning` is already imported in an earlier cell, so
# this install is presumably redundant at this point — confirm.
!pip install pytorch_lightning



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb
class CNNModel(nn.Module):
    """1-D CNN classifier: three convolutions, global average pooling, two linear layers.

    Args:
        num_classes: Size of the output (logit) vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        # All convolutions share the same kernel geometry; only the channel widths differ.
        conv_geometry = dict(kernel_size=16, stride=1, padding=8)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, **conv_geometry)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, **conv_geometry)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, **conv_geometry)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a single-channel signal batch (batch, 1, length) to (batch, num_classes) logits."""
        # Two conv -> max-pool(2) -> ReLU stages.
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool1d(conv(x), 2))
        feature_map = F.relu(self.conv3(x))
        # Global average over the length axis, then drop it: (batch, 128).
        pooled = self.avg_pool(feature_map).view(-1, 128)
        hidden = F.relu(self.fc1(pooled))
        return self.fc2(hidden)


class CNNModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a classifier with cross-entropy.

    Besides the training/validation steps it defines ``test_step`` (logging
    ``test_loss`` and ``test_acc``) and ``predict_step`` (returning class
    indices), so that ``Trainer.test`` and ``Trainer.predict`` can be used
    with the ``[examples, labels]`` batches produced by the data module.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        lr: Learning rate handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the model on an ``(inputs, targets)`` batch; return (logits, targets, loss)."""
        x, y = batch
        y_hat = self(x)
        return y_hat, y, F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        _, _, val_loss = self._logits_and_loss(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and accuracy for one batch."""
        y_hat, y, test_loss = self._logits_and_loss(batch)
        test_acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('test_loss', test_loss)
        self.log('test_acc', test_acc)
        return test_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return predicted class indices for one batch.

        The batch is ``[examples, labels]``; the labels are ignored. Without
        this override the whole batch list would be passed to ``forward``.
        """
        x, _ = batch
        return self(x).argmax(dim=1)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# Build the CNN and wrap it in its Lightning module (default learning rate).
num_classes = 10
cnn_model = CNNModel(num_classes)
cnn_module = CNNModule(cnn_model)


# Weights & Biases logger shared by the trainers created in this notebook.
# NOTE(review): 'your_project_name' looks like a placeholder — set the real project name.
wandb_logger = pl.loggers.WandbLogger(project='your_project_name', log_model=True)


trainer = pl.Trainer(max_epochs=3, logger=wandb_logger)


# Fit using the train/validation loaders provided by custom_data_module.
trainer.fit(cnn_module, custom_data_module)

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_auc_score
import seaborn as sns
import matplotlib.pyplot as plt
import itertools


# 4-fold validation of the CNN with a fixed, held-out test fold.
#
# For every validation fold a fresh model and a fresh Trainer are created
# (a reused Trainer has already reached max_epochs, and a reused model would
# carry weights trained on data from the other folds). The trained model is
# then scored on the test fold with a plain inference loop, so this cell does
# not depend on which Lightning hooks the module implements.
k = 4
# Validation rotates over the folds that are NOT the test fold (2..5 when test_samp == 1).
validation_folds = [fold for fold in range(1, k + 2) if fold != test_samp]
eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fold_results = []
all_preds = []
all_probs = []
all_labels = []

for fold_number, validation_fold in enumerate(validation_folds, start=1):
    print(f"Fold {fold_number}/{k}")
    custom_data_module = CustomDataModule(batch_size=batch_size,
                                          num_workers=num_workers,
                                          data_directory=path,
                                          data_frame=df,
                                          validation_fold=validation_fold,
                                          testing_fold=test_samp,
                                          esc_10_flag=True,
                                          file_column='filename',
                                          label_column='category',
                                          sampling_rate=44100,
                                          new_sampling_rate=16000,
                                          sample_length_seconds=1)

    custom_data_module.setup()

    cnn_module = CNNModule(CNNModel(num_classes))
    trainer = pl.Trainer(max_epochs=3, logger=wandb_logger)
    trainer.fit(cnn_module, custom_data_module)

    # Score this fold's model on the test fold.
    cnn_module.to(eval_device)
    cnn_module.eval()
    fold_logits = []
    fold_labels = []
    with torch.no_grad():
        for examples, labels in custom_data_module.test_dataloader():
            fold_logits.append(cnn_module(examples.to(eval_device)).cpu())
            fold_labels.append(labels)
    fold_logits = torch.cat(fold_logits)
    fold_labels = torch.cat(fold_labels)
    fold_preds = fold_logits.argmax(dim=1)

    fold_results.append({
        'test_loss': F.cross_entropy(fold_logits, fold_labels).item(),
        'test_acc': (fold_preds == fold_labels).float().mean().item(),
    })
    all_preds.append(fold_preds.numpy())
    # Softmax in float64 so that each row sums to 1 within sklearn's tolerance.
    all_probs.append(F.softmax(fold_logits.double(), dim=1).numpy())
    all_labels.append(fold_labels.numpy())


test_results = {f'Fold_{i + 1}': result for i, result in enumerate(fold_results)}


test_loss = np.mean([result['test_loss'] for result in fold_results])
print("Overall Test Loss:", test_loss)


# Pool the test-fold predictions of all k models.
all_preds = np.concatenate(all_preds)
all_probs = np.concatenate(all_probs)
all_labels = np.concatenate(all_labels)

# The category names live on the dataset, not on the data module.
class_names = list(custom_data_module.testing_dataset.categories)
class_ids = list(range(len(class_names)))


accuracy = accuracy_score(all_labels, all_preds)
print("Accuracy:", accuracy)


cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()


report = classification_report(all_labels, all_preds, labels=class_ids,
                               target_names=class_names, zero_division=0)
print("Classification Report:")
print(report)


# With output_dict=True the report is keyed by class name; collect the per-class F1 values.
report_dict = classification_report(all_labels, all_preds, labels=class_ids,
                                    target_names=class_names, output_dict=True,
                                    zero_division=0)
f1_scores = {name: report_dict[name]['f1-score'] for name in class_names}
print("F1-Scores:")
print(f1_scores)


# Multiclass ROC-AUC needs per-class scores, not hard predictions.
auc_roc_score = roc_auc_score(all_labels, all_probs, multi_class='ovr',
                              average='macro', labels=class_ids)
print("AUC-ROC Score:", auc_roc_score)


In [None]:
# Parameter counts of the current CNN module.
total_params = sum(p.numel() for p in cnn_module.parameters())
trainable_params = sum(p.numel() for p in cnn_module.parameters() if p.requires_grad)
non_trainable_params = total_params - trainable_params
print("Total Trainable Parameters:", trainable_params)
print("Total Non-Trainable Parameters:", non_trainable_params)


# Learning-rate sweep. Every learning rate gets a freshly initialised model;
# reusing one module would start each run from the weights left behind by the
# previous learning rate and make the comparison meaningless.
# NOTE(review): the learning rate is selected on the TEST fold, as in the
# original cell; selecting on the validation fold would avoid test leakage.
lr_values = [1e-2, 1e-3, 1e-4]
best_lr = None
best_accuracy = 0.0
eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for lr in lr_values:

    wandb_logger = pl.loggers.WandbLogger(project='your_project_name', log_model=True, name=f'lr_{lr}')

    trainer = pl.Trainer(max_epochs=3, logger=wandb_logger)

    cnn_module = CNNModule(CNNModel(num_classes), lr=lr)

    trainer.fit(cnn_module, custom_data_module)

    # Accuracy on the test fold, computed directly so that the cell does not
    # depend on a logged 'test_acc' metric.
    cnn_module.to(eval_device)
    cnn_module.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for examples, labels in custom_data_module.test_dataloader():
            preds = cnn_module(examples.to(eval_device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    accuracy = correct / total if total else 0.0

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_lr = lr

print("Best Learning Rate:", best_lr)
print("Best Accuracy:", best_accuracy)

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Self-attention followed by a two-layer MLP, each with a residual connection and LayerNorm.

    Args:
        input_size: Embedding dimension of the tokens.
        num_heads: Number of attention heads.
    """

    def __init__(self, input_size, num_heads=1):
        super().__init__()
        hidden_size = 4 * input_size
        self.self_attention = nn.MultiheadAttention(input_size, num_heads)
        # NOTE(review): this single LayerNorm is applied after both sub-layers.
        self.norm = nn.LayerNorm(input_size)
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        """Apply attention and the MLP to ``x``; the output has the same shape as ``x``."""
        # The attention module returns (output, weights); only the output is used.
        attended = self.self_attention(x, x, x)[0]
        after_attention = self.norm(x + attended)
        return self.norm(after_attention + self.mlp(after_attention))


class TransformerEncoder(nn.Module):
    """A stack of ``num_blocks`` TransformerEncoderBlock layers applied in sequence.

    Args:
        input_size: Embedding dimension passed to every block.
        num_heads: Number of attention heads passed to every block.
        num_blocks: How many blocks to stack.
    """

    def __init__(self, input_size, num_heads=1, num_blocks=2):
        super().__init__()
        self.encoder_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.encoder_blocks.append(TransformerEncoderBlock(input_size, num_heads))

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        hidden = x
        for encoder_block in self.encoder_blocks:
            hidden = encoder_block(hidden)
        return hidden


class ConvTransformer(pl.LightningModule):
    """CNN feature extractor followed by transformer encoders and a linear classifier.

    The previous ``forward`` fed ``cnn_base(x)``, i.e. (batch, num_classes)
    logits, into attention layers built for an embedding size of 256, which
    fails with a shape error. Here the convolutional part of ``cnn_base`` is
    used as a feature extractor instead: its feature map is pooled to
    ``num_tokens`` positions, every position is projected to 256 dimensions
    with ``cnn_base.fc1``, and the resulting token sequence is passed through
    the transformer encoders. ``cnn_base.avg_pool`` and ``cnn_base.fc2`` are
    not used by this forward pass.

    Args:
        num_classes: Number of output classes.
        input_channels: Accepted for interface compatibility; not used.
        input_length: Accepted for interface compatibility; not used.
        num_heads: Iterable with one entry per stacked TransformerEncoder,
            giving that encoder's number of attention heads (each must divide 256).
        num_tokens: Length of the token sequence handed to the transformer.
    """

    def __init__(self, num_classes, input_channels, input_length, num_heads=(1, 2, 4), num_tokens=64):
        super().__init__()
        self.cnn_base = CNNModel(num_classes)
        self.num_tokens = num_tokens
        self.transformer_heads = nn.ModuleList([
            TransformerEncoder(input_size=256, num_heads=num_head) for num_head in num_heads
        ])
        self.mlp_head = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a (batch, 1, length) signal batch to (batch, num_classes) logits."""
        cnn = self.cnn_base
        # Same convolutional stages as CNNModel.forward, stopping before its global pooling.
        x = F.relu(F.max_pool1d(cnn.conv1(x), 2))
        x = F.relu(F.max_pool1d(cnn.conv2(x), 2))
        x = F.relu(cnn.conv3(x))                        # (batch, 128, length')
        # Pool to a fixed, short sequence: attention cost grows quadratically with length.
        x = F.adaptive_avg_pool1d(x, self.num_tokens)   # (batch, 128, num_tokens)
        # nn.MultiheadAttention defaults to (sequence, batch, embedding) inputs.
        x = x.permute(2, 0, 1)                          # (num_tokens, batch, 128)
        x = F.relu(cnn.fc1(x))                          # (num_tokens, batch, 256)
        for transformer_head in self.transformer_heads:
            x = transformer_head(x)
        # Average over the tokens to obtain one vector per clip, then classify.
        return self.mlp_head(x.mean(dim=0))


class ConvTransformerModule(pl.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy and Adam.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        lr: Learning rate handed to Adam.
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _batch_loss(self, batch):
        """Cross-entropy between the model output and the targets of an (inputs, targets) batch."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        val_loss = self._batch_loss(batch)
        self.log('val_loss', val_loss)
        return val_loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def evaluate_on_test_set(model, test_dataloader):
    """Evaluate a classifier on a dataloader of ``(inputs, labels)`` batches.

    Args:
        model: Module returning (batch, n_classes) logits.
        test_dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(accuracy, f1, auc_roc, cm)``: accuracy, weighted F1, weighted
        one-vs-rest ROC-AUC computed from softmax probabilities, and an
        (n_classes, n_classes) confusion matrix whose rows/columns follow the
        class indices 0..n_classes-1.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches.
    """
    model.eval()
    # Run inference on whatever device the model currently lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    logit_chunks = []
    label_chunks = []
    with torch.no_grad():
        for x, y in test_dataloader:
            logit_chunks.append(model(x.to(device)).cpu())
            label_chunks.append(y.cpu())

    if not logit_chunks:
        raise ValueError("test_dataloader yielded no batches")

    # Imported locally: f1_score is not part of the notebook's sklearn import line.
    from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score

    logits = torch.cat(logit_chunks)
    all_labels = torch.cat(label_chunks).numpy()
    all_preds = logits.argmax(dim=1).numpy()
    # Multiclass ROC-AUC needs per-class scores; float64 keeps row sums at 1
    # within sklearn's tolerance.
    all_probs = torch.softmax(logits.double(), dim=1).numpy()
    # Fixing the label set keeps the confusion-matrix shape identical across
    # calls, so matrices from different folds can be averaged.
    class_ids = list(range(logits.shape[1]))

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    auc_roc = roc_auc_score(all_labels, all_probs, average='weighted',
                            multi_class='ovr', labels=class_ids)
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    return accuracy, f1, auc_roc, cm


# 4-fold validation of the ConvTransformer with the held-out test fold.
#
# The folds are the predefined ones from the metadata table: a data module is
# rebuilt for every validation fold, and its own loaders (which carry the
# required collate function) are used for training. The previous version
# tried to run sklearn's KFold / torch Subset on the LightningDataModule
# itself, which is neither sized nor indexable, and never called setup().
batch_size = 32
num_workers = 8

kfolds = 4
# Validation rotates over the folds that are NOT the test fold.
validation_folds = [fold for fold in range(1, kfolds + 2) if fold != test_samp]
test_accuracies = []
test_f1_scores = []
test_auc_rocs = []
confusion_matrices = []

for fold, validation_fold in enumerate(validation_folds):
    print(f"Fold {fold+1}/{kfolds}")

    custom_data_module = CustomDataModule(batch_size=batch_size,
                                          num_workers=num_workers,
                                          data_directory=path,
                                          data_frame=df,
                                          validation_fold=validation_fold,
                                          testing_fold=test_samp,
                                          esc_10_flag=True,
                                          file_column='filename',
                                          label_column='category',
                                          sampling_rate=44100,
                                          new_sampling_rate=16000,
                                          sample_length_seconds=1
                                          )
    custom_data_module.setup()

    # Fresh model and trainer for every fold.
    conv_transformer_model = ConvTransformer(num_classes, input_channels=1, input_length=16000)
    conv_transformer_module = ConvTransformerModule(conv_transformer_model)
    trainer = pl.Trainer(max_epochs=3, logger=wandb_logger)

    trainer.fit(conv_transformer_module, custom_data_module)

    test_dataloader = custom_data_module.test_dataloader()
    accuracy, f1, auc_roc, cm = evaluate_on_test_set(conv_transformer_model, test_dataloader)

    test_accuracies.append(accuracy)
    test_f1_scores.append(f1)
    test_auc_rocs.append(auc_roc)
    confusion_matrices.append(cm)


mean_accuracy = np.mean(test_accuracies)
std_accuracy = np.std(test_accuracies)
mean_f1 = np.mean(test_f1_scores)
std_f1 = np.std(test_f1_scores)
mean_auc_roc = np.mean(test_auc_rocs)
std_auc_roc = np.std(test_auc_rocs)

print(f"Mean Accuracy: {mean_accuracy} ± {std_accuracy}")
print(f"Mean F1 Score: {mean_f1} ± {std_f1}")
print(f"Mean AUC-ROC Score: {mean_auc_roc} ± {std_auc_roc}")


# The averaged matrix holds floats, so it is annotated with a float format
# ("d" only accepts integers).
mean_cm = np.mean(confusion_matrices, axis=0)
plt.figure(figsize=(10, 8))
sns.heatmap(mean_cm, annot=True, fmt=".1f", cmap="Blues", xticklabels=custom_data_module.testing_dataset.categories, yticklabels=custom_data_module.testing_dataset.categories)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Mean Confusion Matrix")
plt.show()

In [None]:

# Count the parameters of a freshly built ConvTransformer in a single pass,
# split by whether they receive gradients.
conv_transformer_model = ConvTransformer(num_classes, input_channels=1, input_length=16000)

trainable_params = 0
non_trainable_params = 0
for _parameter in conv_transformer_model.parameters():
    if _parameter.requires_grad:
        trainable_params += _parameter.numel()
    else:
        non_trainable_params += _parameter.numel()

print(f"Total Trainable Parameters: {trainable_params}")
print(f"Total Non-Trainable Parameters: {non_trainable_params}")


In [None]:
import itertools


# Grid search over learning rate, batch size and attention-head configuration.
# Every configuration is scored with 4-fold validation on the predefined
# folds (see the 'fold' column), using a fresh model per fold.
# NOTE(review): configurations are ranked by TEST-fold accuracy, as in the
# original cell; ranking on validation data would avoid test leakage.
learning_rates = [1e-3, 5e-4, 1e-4]
batch_sizes = [16, 32, 64]
num_heads_options = [[1], [2], [4], [1, 2], [2, 4], [1, 2, 4]]


best_hyperparameters = None
best_metrics = {
    'accuracy': 0,
    'f1_score': 0,
    'auc_roc': 0
}

kfolds = 4
# Validation rotates over the folds that are NOT the test fold.
validation_folds = [fold for fold in range(1, kfolds + 2) if fold != test_samp]


for lr, batch_size, num_heads in itertools.product(learning_rates, batch_sizes, num_heads_options):
    print(f"Testing hyperparameters: LR={lr}, Batch Size={batch_size}, Num Heads={num_heads}")

    test_accuracies = []
    test_f1_scores = []
    test_auc_rocs = []

    for validation_fold in validation_folds:

        # The data module's own loaders carry the collate function the model
        # input format relies on, so they are used instead of hand-built ones.
        custom_data_module = CustomDataModule(batch_size=batch_size,
                                              num_workers=num_workers,
                                              data_directory=path,
                                              data_frame=df,
                                              validation_fold=validation_fold,
                                              testing_fold=test_samp,
                                              esc_10_flag=True,
                                              file_column='filename',
                                              label_column='category',
                                              sampling_rate=44100,
                                              new_sampling_rate=16000,
                                              sample_length_seconds=1
                                              )
        custom_data_module.setup()

        conv_transformer_model = ConvTransformer(num_classes, input_channels=1, input_length=16000, num_heads=num_heads)
        conv_transformer_module = ConvTransformerModule(conv_transformer_model, lr=lr)
        trainer = pl.Trainer(max_epochs=3, logger=wandb_logger)

        trainer.fit(conv_transformer_module, custom_data_module)

        test_dataloader = custom_data_module.test_dataloader()
        accuracy, f1, auc_roc, _ = evaluate_on_test_set(conv_transformer_model, test_dataloader)

        test_accuracies.append(accuracy)
        test_f1_scores.append(f1)
        test_auc_rocs.append(auc_roc)


    mean_accuracy = np.mean(test_accuracies)
    mean_f1 = np.mean(test_f1_scores)
    mean_auc_roc = np.mean(test_auc_rocs)


    print(f"Mean Accuracy: {mean_accuracy}")
    print(f"Mean F1 Score: {mean_f1}")
    print(f"Mean AUC-ROC Score: {mean_auc_roc}")


    # Keep the configuration with the highest mean accuracy seen so far.
    if mean_accuracy > best_metrics['accuracy']:
        best_metrics['accuracy'] = mean_accuracy
        best_metrics['f1_score'] = mean_f1
        best_metrics['auc_roc'] = mean_auc_roc
        best_hyperparameters = {
            'learning_rate': lr,
            'batch_size': batch_size,
            'num_heads': num_heads
        }


print("Best Hyperparameters:")
print(best_hyperparameters)
print("Corresponding Metrics:")
print(best_metrics)
