<a href="https://colab.research.google.com/github/Frodo-Swaggins/COMP702PROJECT/blob/main/NovelConformerForSpeakerIdentificationWithDWSC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Step 1: Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Step 2: Import tarfile and extract the file
import os
import tarfile

# Specify the path to the .tar.gz file in your Google Drive
file_path = '/content/drive/MyDrive/train-other-500.tar.gz'  # Change this to the actual path of your .tar.gz file in Google Drive

# Specify the destination folder where you want to extract the contents
destination_path = '/content/train-other-500-uncompressed/'  # You can change this to your desired destination folder

# Open the .tar.gz file and extract all contents
with tarfile.open(file_path, 'r:gz') as tar:
    # Use the 'data' extraction filter where the running Python provides it. It rejects
    # members with absolute paths, '..' components or links that would escape
    # destination_path. Older interpreters fall back to unfiltered extraction.
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path=destination_path, filter='data')
    else:
        tar.extractall(path=destination_path)

# Verify by listing the contents of the destination folder
print(os.listdir(destination_path))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import numpy as np
import librosa
import soundfile as sf

def add_noise(audio, noise_factor=0.005, rng=None):
    """Add zero-mean, unit-variance Gaussian (white) noise scaled by ``noise_factor``.

    Args:
        audio: array of samples of any shape (mono 1-D or multi-channel).
        noise_factor: multiplier applied to the standard-normal noise.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When omitted,
            the global NumPy RNG is used. For 1-D input this draws exactly the same
            values as ``np.random.randn(len(audio))``.

    Returns:
        A new array ``audio + noise_factor * noise`` with the same shape as ``audio``.
    """
    audio = np.asarray(audio)
    # Draw noise for the full shape; randn(len(audio)) cannot broadcast against (n, channels).
    if rng is None:
        noise = np.random.randn(*audio.shape)
    else:
        noise = rng.standard_normal(audio.shape)
    return audio + noise_factor * noise



def create_noisy_dataset(original_folder, noisy_folder, noise_factor):
    """Mirror every ``.flac`` file under ``original_folder`` into ``noisy_folder`` with added noise.

    The relative directory layout is preserved. Each file is loaded at its native
    sample rate, passed through ``add_noise`` and written back with ``soundfile``.

    Args:
        original_folder (str): root of the clean dataset.
        noisy_folder (str): root that receives the noisy copies (created if absent).
        noise_factor (float): scale of the added white noise.
    """
    if not os.path.exists(noisy_folder):
        os.makedirs(noisy_folder)

    for dirpath, _subdirs, filenames in os.walk(original_folder):
        flac_names = [name for name in filenames if name.endswith('.flac')]
        # Same sub-path below noisy_folder as dirpath has below original_folder.
        target_dir = os.path.join(noisy_folder, os.path.relpath(dirpath, original_folder))

        for flac_name in flac_names:
            samples, rate = librosa.load(os.path.join(dirpath, flac_name), sr=None)
            noisy_samples = add_noise(samples, noise_factor=noise_factor)

            if not os.path.exists(target_dir):
                os.makedirs(target_dir)

            out_path = os.path.join(target_dir, flac_name)
            sf.write(out_path, noisy_samples, rate)
            print(f"Created noisy file: {out_path}")


def main():
    """Build the noisy copy of the extracted LibriSpeech data on Google Drive.

    Reads from the Colab-local extraction folder and writes to
    ``MyDrive/noisy_dataset`` with a noise factor of 0.05.
    """
    create_noisy_dataset(
        original_folder='/content/train-other-500-uncompressed/LibriSpeech',
        noisy_folder='/content/drive/MyDrive/noisy_dataset',
        noise_factor=0.05,
    )

# Entry-point guard: calls main() when this code runs as the top-level module.
if __name__ == '__main__':
    main()


In [None]:
!pip install torch torchaudio

In [None]:
import torch
print("CUDA available:", torch.cuda.is_available())
# torch.cuda.current_device() raises on a CPU-only runtime, so only query it when
# CUDA is actually usable.
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
else:
    print("Current device: cpu")


In [None]:
import os
import librosa
import numpy as np

def remove_silence(audio, sr, top_db=30):
    """Trim silence from ``audio`` with ``librosa.effects.trim`` at the given dB threshold.

    ``sr`` is accepted for symmetry with the other helpers but is not used.
    The trimmed length is printed and the trimmed signal returned.
    """
    kept = librosa.effects.trim(audio, top_db=top_db)[0]
    print(f"Audio length after silence removal: {len(kept)} samples")
    return kept

def extract_mfcc(audio, sr, segment_length_sec=4.0, overlap_fraction=0.5, n_mfcc=13):
    """Cut ``audio`` into fixed-length overlapping windows and compute MFCCs for each.

    Windows are ``int(segment_length_sec * sr)`` samples long. Consecutive windows
    start ``window - int(overlap_fraction * window)`` samples apart, and a trailing
    partial window is dropped. Returns a list of ``librosa.feature.mfcc`` matrices,
    one per window.
    """
    window = int(segment_length_sec * sr)
    hop = window - int(overlap_fraction * window)

    print("Extracting MFCCs...")

    starts = range(0, len(audio) - window + 1, hop)
    features = [
        librosa.feature.mfcc(y=audio[begin:begin + window], sr=sr, n_mfcc=n_mfcc)
        for begin in starts
    ]

    print(f"Extracted {len(features)} MFCC segments")
    return features

def process_speaker_files(dataset_path, speaker_subfolder, max_mfcc_per_speaker=500, segment_length_sec=4.0, overlap_fraction=0.5, top_db=30):
    """Collect up to ``max_mfcc_per_speaker`` MFCC segments for one speaker.

    Walks ``<dataset_path>/<speaker_subfolder>`` recursively. Each ``.flac`` file is
    loaded at its native sample rate and skipped if empty. Silence is then trimmed,
    and the file is skipped if nothing remains. Otherwise its MFCC segments are
    appended. Processing stops once the cap is reached, and the result is cut to
    exactly the cap.
    """
    speaker_path = os.path.join(dataset_path, speaker_subfolder)
    segments = []

    print(f"Processing speaker folder: {speaker_path}")

    for current_dir, _subdirs, filenames in os.walk(speaker_path):
        print(f"Searching in: {current_dir}")
        for flac_name in (name for name in filenames if name.endswith('.flac')):
            flac_path = os.path.join(current_dir, flac_name)
            print(f"Processing file: {flac_path}")

            samples, rate = librosa.load(flac_path, sr=None)
            print(f"Audio loaded, length: {len(samples)} samples")
            if not len(samples):
                print(f"Warning: Empty audio file {flac_path}")
                continue

            voiced = remove_silence(samples, rate, top_db=top_db)
            if not len(voiced):
                print(f"Warning: Silence removal resulted in empty audio for {flac_path}")
                continue

            segments.extend(extract_mfcc(voiced, rate, segment_length_sec, overlap_fraction))
            if len(segments) >= max_mfcc_per_speaker:
                del segments[max_mfcc_per_speaker:]
                break

        if len(segments) >= max_mfcc_per_speaker:
            break

    print(f"Total MFCC segments for speaker {speaker_subfolder}: {len(segments)}")
    return segments

def save_mfccs_to_drive(mfccs, output_folder, speaker_subfolder, dataset_type):
    """Save one speaker's MFCC segments to ``<output_folder>/<speaker>/<speaker>_<dataset_type>_mfccs.npy``.

    The speaker folder is created (and reported) if it does not exist yet.
    ``mfccs`` is converted with ``np.array`` before saving.
    """
    target_dir = os.path.join(output_folder, speaker_subfolder)
    needs_dir = not os.path.exists(target_dir)
    if needs_dir:
        os.makedirs(target_dir)
        print(f"Created folder: {target_dir}")

    destination = os.path.join(target_dir, f"{speaker_subfolder}_{dataset_type}_mfccs.npy")
    stacked = np.array(mfccs)
    np.save(destination, stacked)
    print(f"Saved MFCCs for speaker {speaker_subfolder} ({dataset_type}) to {destination}")

def process_datasets(noisy_dataset, clean_dataset, output_folder, flag=None, max_mfcc_per_speaker=500, segment_length_sec=4.0, overlap_fraction=0.5, top_db=30):
    """Extract and save noisy and clean MFCCs for every speaker present in both datasets.

    For each speaker folder under ``noisy_dataset`` that also exists under
    ``clean_dataset``, segments are extracted from both with ``process_speaker_files``.
    They are saved as ``<output_folder>/<speaker>/<speaker>_{noisy,clean}_mfccs.npy``.
    Speakers whose two output files already exist are skipped, so a re-run resumes
    where it stopped.

    Args:
        noisy_dataset, clean_dataset: roots containing one folder per speaker.
        output_folder: root for the per-speaker ``.npy`` outputs.
        flag: unused. Optional (default ``None``) so callers may omit it; a required
            but ignored argument made such calls fail with ``TypeError``.
        max_mfcc_per_speaker, segment_length_sec, overlap_fraction, top_db:
            forwarded to ``process_speaker_files``.
    """
    print("Starting dataset processing...")
    print(f"Noisy dataset path: {noisy_dataset}")
    print(f"Clean dataset path: {clean_dataset}")

    # sorted(): os.listdir order is arbitrary; a fixed order makes progress and resume predictable.
    for speaker_subfolder in sorted(os.listdir(noisy_dataset)):
        noisy_speaker_path = os.path.join(noisy_dataset, speaker_subfolder)
        clean_speaker_path = os.path.join(clean_dataset, speaker_subfolder)
        if not (os.path.isdir(noisy_speaker_path) and os.path.isdir(clean_speaker_path)):
            continue

        print(f"Processing speaker: {speaker_subfolder}")

        # Check if MFCC files already exist for this speaker
        noisy_output_file = os.path.join(output_folder, speaker_subfolder, f"{speaker_subfolder}_noisy_mfccs.npy")
        clean_output_file = os.path.join(output_folder, speaker_subfolder, f"{speaker_subfolder}_clean_mfccs.npy")
        if os.path.exists(noisy_output_file) and os.path.exists(clean_output_file):
            print(f"Skipping speaker {speaker_subfolder} as MFCC files already exist.")
            continue

        noisy_mfccs = process_speaker_files(noisy_dataset, speaker_subfolder, max_mfcc_per_speaker, segment_length_sec, overlap_fraction, top_db)
        print(f"Extracted {len(noisy_mfccs)} MFCCs from noisy dataset for speaker {speaker_subfolder}")
        save_mfccs_to_drive(noisy_mfccs, output_folder, speaker_subfolder, dataset_type='noisy')

        clean_mfccs = process_speaker_files(clean_dataset, speaker_subfolder, max_mfcc_per_speaker, segment_length_sec, overlap_fraction, top_db)
        print(f"Extracted {len(clean_mfccs)} MFCCs from clean dataset for speaker {speaker_subfolder}")
        save_mfccs_to_drive(clean_mfccs, output_folder, speaker_subfolder, dataset_type='clean')


noisy_dataset_path = '/content/drive/MyDrive/noisy_dataset/train-other-500'
clean_dataset_path = '/content/train-other-500-uncompressed/LibriSpeech/train-other-500'
output_folder = '/content/drive/MyDrive/mfcc_datasets_new'

# `flag` is not used by process_datasets; it is passed explicitly so this call does
# not depend on the parameter having a default value.
process_datasets(noisy_dataset_path, clean_dataset_path, output_folder, flag=None)


In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set paths
dataset_path = '/content/drive/MyDrive/mfcc_datasets_new'  # Replace with the path to your dataset
target_length = 500  # Target number of segments per MFCC file

# Pad (or truncate) each raw per-speaker MFCC array along its first axis to
# target_length, and save a boolean mask that is True for real entries and False
# for padding.
for speaker_folder in os.listdir(dataset_path):
    speaker_path = os.path.join(dataset_path, speaker_folder)
    if not os.path.isdir(speaker_path):
        continue

    for mfcc_file in os.listdir(speaker_path):
        # Skip this cell's own outputs so re-runs only process the raw MFCC files.
        if not mfcc_file.endswith('.npy') or mfcc_file.endswith(('_padded.npy', '_mask.npy')):
            continue

        mfcc_file_path = os.path.join(speaker_path, mfcc_file)
        stem = os.path.splitext(mfcc_file)[0]
        padded_mfcc_path = os.path.join(speaker_path, f"{stem}_padded.npy")
        mask_path = os.path.join(speaker_path, f"{stem}_mask.npy")

        mfcc = torch.tensor(np.load(mfcc_file_path), device=device)

        # np.array([]) (shape (0,)) is what gets saved when no segments were extracted
        # for a speaker; it has no trailing axes to pad, and indexing shape[2] would
        # raise IndexError.
        if mfcc.dim() != 3 or mfcc.shape[0] == 0:
            print(f"Skipping {mfcc_file}: expected a non-empty 3-D array, got shape {tuple(mfcc.shape)}.")
            continue

        current_length = mfcc.shape[0]
        mask = torch.ones((target_length, *mfcc.shape[1:]), dtype=torch.bool, device=device)

        if current_length < target_length:
            # Zero-pad only the first axis.
            padded_mfcc = F.pad(mfcc, (0, 0, 0, 0, 0, target_length - current_length), "constant", 0)
            mask[current_length:] = False
        else:
            padded_mfcc = mfcc[:target_length]

        # Save the padded MFCC and mask, overwriting existing files if they exist
        np.save(padded_mfcc_path, padded_mfcc.cpu().numpy())
        np.save(mask_path, mask.cpu().numpy())

        print(f"Processed {mfcc_file}: Saved and overwritten padded MFCC and mask.")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

# Swish Activation Function
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

# Updated ConvSubsampling layer to handle channel dimension properly
class ConvSubsampling(nn.Module):
    """Two stride-2 Conv2d + ReLU layers that subsample the time and feature axes.

    Accepts either ``(batch, time, features)`` (a size-1 channel axis is added) or an
    explicit ``(batch, in_channels, time, features)``. Each conv maps an axis of
    length ``n`` to ``(n - 1) // 2 + 1``. The output is
    ``(batch, time', out_channels * features')``.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvSubsampling, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.out_channels = out_channels

    def forward(self, x):
        # Add the channel axis only when it is missing. An unconditional unsqueeze
        # turns an already 4-D input into 5-D, which Conv2d rejects.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Fold channels into the feature axis: (batch, time', out_channels * features').
        x = x.permute(0, 2, 1, 3).contiguous().view(b, t, c * f)
        return x



# Feed Forward Module
class FeedForwardModule(nn.Module):
    """Pre-norm position-wise feed-forward block with a residual connection.

    LayerNorm -> Linear (d_model -> d_model * expansion_factor) -> Swish -> Dropout
    -> Linear (back to d_model) -> Dropout, added to the input.
    """

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()
        hidden = d_model * expansion_factor
        # Registration order is kept so state_dict / optimizer state stay compatible.
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, hidden)
        self.activation = Swish()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        h = self.layer_norm(x)
        h = self.dropout1(self.activation(self.linear1(h)))
        h = self.dropout2(self.linear2(h))
        return h + x

# Multi-Head Self-Attention Module (pre-norm, residual; no positional encoding is applied here)
class MHSA(nn.Module):
    """Pre-norm multi-head self-attention over the time axis with a residual connection.

    Input and output are ``(batch, time, d_model)``.

    Args:
        d_model: model width.
        n_heads: number of attention heads (must divide ``d_model``).
        dropout: dropout on attention weights and on the output.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MHSA, self).__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        # batch_first=True: inputs here are (batch, time, d_model). With the default
        # (seq, batch, feature) layout, attention would run across the batch instead
        # of across time.
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention.

        Args:
            x: ``(batch, time, d_model)`` tensor.
            attention_mask: optional ``(batch, time)`` bool key-padding mask; ``True``
                marks positions to ignore (``nn.MultiheadAttention`` convention).
        """
        residual = x
        x = self.layer_norm(x)
        x, _ = self.attention(x, x, x, key_padding_mask=attention_mask, need_weights=False)
        x = self.dropout(x)
        return x + residual

#Convolution Module
class ConvolutionModule(nn.Module):
    """Conformer convolution module with a pre-norm residual connection.

    LayerNorm -> pointwise Conv1d (2x channels) -> GLU -> depthwise Conv1d ->
    BatchNorm -> Swish -> pointwise Conv1d -> Dropout, added to the input.
    Input and output are ``(batch, time, d_model)``. The time length is preserved for
    both odd and even ``kernel_size``.
    """

    def __init__(self, d_model, kernel_size=31, dropout=0.1):
        super(ConvolutionModule, self).__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        # "Same" padding is applied in forward() so it can be asymmetric. A symmetric
        # (kernel_size - 1) // 2 drops one frame for even kernels, e.g. kernel_size=32,
        # and the residual then has to be truncated to match.
        self.pad_left = (kernel_size - 1) // 2
        self.pad_right = kernel_size - 1 - self.pad_left
        self.depthwise_conv = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, padding=0, groups=d_model)
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.swish = Swish()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        x = x.transpose(1, 2)  # (batch, d_model, time)
        x = self.pointwise_conv1(x)
        x = self.glu(x)  # halves channels back to d_model
        x = F.pad(x, (self.pad_left, self.pad_right))
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.swish(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        x = x.transpose(1, 2)  # (batch, time, d_model)
        return x + residual


# Conformer Block
class ConformerBlock(nn.Module):
    """One Conformer block: FFN -> self-attention -> convolution -> FFN, then LayerNorm.

    Each sub-module adds its own residual; the final LayerNorm normalizes the result.
    """

    def __init__(self, d_model, n_heads, kernel_size, dropout=0.1):
        super().__init__()
        self.ffn1 = FeedForwardModule(d_model, dropout=dropout)
        self.mhsa = MHSA(d_model, n_heads, dropout=dropout)
        self.conv = ConvolutionModule(d_model, kernel_size=kernel_size, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        for stage in (self.ffn1, self.mhsa, self.conv, self.ffn2):
            x = stage(x)
        return self.layer_norm(x)

class ConformerEncoder(nn.Module):
    """Conformer-based utterance classifier.

    ConvSubsampling -> Linear projection to ``d_model`` -> Dropout ->
    ``num_blocks`` ConformerBlocks -> mean over time -> Linear to ``num_classes`` logits.

    Args:
        input_dim: size of the last (feature) axis of the ``(batch, time, input_dim)`` input.
        d_model: model width.
        n_heads: attention heads per block.
        num_blocks: number of Conformer blocks.
        kernel_size: depthwise kernel size of each block's convolution module.
        dropout: dropout probability.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, d_model, n_heads, num_blocks, kernel_size=31, dropout=0.1, num_classes=10):
        super(ConformerEncoder, self).__init__()
        # ConvSubsampling adds a single channel axis to the 3-D input, so its first
        # Conv2d must take 1 input channel. Passing input_dim here made every forward
        # pass fail with a channel mismatch.
        self.subsampling = ConvSubsampling(1, d_model)
        # Feature width after the two stride-2, kernel-3, padding-1 convs:
        # each maps n -> (n - 1) // 2 + 1.
        sub_features = input_dim
        for _ in range(2):
            sub_features = (sub_features - 1) // 2 + 1
        # Sized here rather than lazily in forward(). That way the projection is
        # registered before the optimizer is built, gets trained, and is saved and
        # restored with checkpoints.
        self.linear = nn.Linear(d_model * sub_features, d_model)
        self.dropout = nn.Dropout(dropout)

        # Define Conformer Blocks
        self.conformer_blocks = nn.ModuleList([
            ConformerBlock(d_model, n_heads, kernel_size, dropout=dropout) for _ in range(num_blocks)
        ])

        # Output layer for classification
        self.output_layer = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.subsampling(x)  # (batch, time', d_model * features')
        features = x.shape[-1]
        if features != self.linear.in_features:
            raise ValueError(
                f"Subsampled feature size {features} does not match the projection input "
                f"size {self.linear.in_features}; the input's last axis must equal input_dim."
            )

        x = self.dropout(self.linear(x))
        for block in self.conformer_blocks:
            x = block(x)

        # Mean-pool over time, then classify.
        return self.output_layer(x.mean(dim=1))


class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable 1-D convolution followed by ReLU.

    A depthwise convolution (one filter per input channel) keeps ``input_dim``
    channels. A 1x1 pointwise convolution then maps them to ``d_model`` channels.
    Input ``(batch, input_dim, time)`` -> output ``(batch, d_model, time)``; the time
    length is preserved for odd and even ``kernel_size``.
    """

    def __init__(self, input_dim, d_model, kernel_size, bias=True):
        super().__init__()
        # The depthwise stage must output input_dim channels. Emitting d_model here
        # broke pointwise_conv (which expects input_dim inputs) whenever d_model != input_dim.
        self.depthwise_conv = nn.Conv1d(in_channels=input_dim, out_channels=input_dim, kernel_size=kernel_size, groups=input_dim, padding=0, bias=False)
        self.pointwise_conv = nn.Conv1d(in_channels=input_dim, out_channels=d_model, kernel_size=1, padding=0, bias=bias)
        # Split kernel_size - 1 between the two sides so even kernels do not add a frame.
        self.pad_left = (kernel_size - 1) // 2
        self.pad_right = kernel_size - 1 - self.pad_left

    def forward(self, x):
        x = F.pad(x, (self.pad_left, self.pad_right))
        return F.relu(self.pointwise_conv(self.depthwise_conv(x)))


In [None]:
import os
import random
import shutil

# Define dataset path and split ratios
data_path = '/content/drive/MyDrive/mfcc_datasets_new'  # Path to your MFCC dataset
train_ratio, val_ratio = 0.8, 0.2
split_seed = 42  # fixed so re-running this cell reproduces the same split

# The 'train' and 'val' folders this cell creates live inside data_path. Exclude them
# so a re-run does not treat them as speakers and copy them into themselves.
split_folders = {'train', 'val'}
# Sort before shuffling: os.listdir order is arbitrary, so a seed alone would not
# make the split reproducible.
speakers = sorted(
    d for d in os.listdir(data_path)
    if d not in split_folders and os.path.isdir(os.path.join(data_path, d))
)
random.Random(split_seed).shuffle(speakers)

# Calculate split indices
total_speakers = len(speakers)
train_idx = int(train_ratio * total_speakers)

# Split speakers.
# NOTE(review): this split is speaker-disjoint, so validation speakers never appear in
# training. Speakers copied by an earlier, differently-shuffled run are not removed
# from train/val either.
train_speakers = speakers[:train_idx]
val_speakers = speakers[train_idx:]

# Function to copy speaker data to designated split folders
def copy_speakers(speakers_list, subset_name):
    """Copy each listed speaker folder from ``data_path`` into ``data_path/<subset_name>/``.

    Existing destination folders are merged into (``dirs_exist_ok=True``).
    """
    destination_root = os.path.join(data_path, subset_name)
    os.makedirs(destination_root, exist_ok=True)
    for spk in speakers_list:
        src = os.path.join(data_path, spk)
        dst = os.path.join(destination_root, spk)
        shutil.copytree(src, dst, dirs_exist_ok=True)

# Create train and val splits
for subset_name, members in (('train', train_speakers), ('val', val_speakers)):
    copy_speakers(members, subset_name)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import numpy as np
import os
import time

# Import necessary library for profiling FLOPs
!pip install torchprofile
import torchprofile

In [None]:
# Paths to save model checkpoints
checkpoint_dir = '/content/drive/MyDrive/model_checkpoints/5'
os.makedirs(checkpoint_dir, exist_ok=True)

# Fixed model input size: desired time dimension x desired feature dimension
target_height, target_width = 500, 126

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Function to calculate FLOPs and parameter count
import torchprofile

# Function to calculate FLOPs and parameter count
def get_model_complexity(model, input_shape):
    """Print and return the MAC count and parameter count of ``model``.

    A random tensor of ``input_shape`` is traced through the model with
    ``torchprofile.profile_macs``. The model is put in eval mode for the probe and
    its previous mode is restored afterwards. This stops BatchNorm running
    statistics from being updated by the random input.

    Args:
        model: the ``nn.Module`` to profile.
        input_shape: full input shape including batch, e.g. ``(1, time, features)``.

    Returns:
        ``(macs, params)``.
    """
    first_param = next(model.parameters(), None)
    # Build the probe where the model lives instead of relying on a global device.
    probe_device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()
    try:
        macs = torchprofile.profile_macs(model, torch.randn(input_shape, device=probe_device))
    finally:
        model.train(was_training)
    params = sum(p.numel() for p in model.parameters())
    print(f"Model MACs: {macs / 1e9:.2f} GMACs, Parameters: {params / 1e6:.2f}M")
    return macs, params

# Define checkpoint saving function
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, filepath):
    """Write model/optimizer state plus epoch and losses to ``filepath`` with ``torch.save``."""
    payload = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        train_loss=train_loss,
        val_loss=val_loss,
    )
    torch.save(payload, filepath)
    print(f"Checkpoint saved: {filepath}")

# Loading from checkpoint if exists
def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint, if one is usable.

    Args:
        checkpoint_path: the checkpoint file itself, or a directory containing
            ``latest_checkpoint.pth``. Callers pass the full file path. Always
            appending the file name made that path never match, so training
            silently restarted.
        model: module whose state is loaded in place.
        optimizer: optimizer whose state is loaded in place.

    Returns:
        The epoch to resume from (saved epoch + 1), or 0 when there is no checkpoint
        or it does not fit the model/optimizer. In the latter case both are left in
        their pre-call state.
    """
    import copy
    import os

    import torch

    if os.path.isdir(checkpoint_path):
        checkpoint_path = os.path.join(checkpoint_path, 'latest_checkpoint.pth')

    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting training from scratch.")
        return 0

    print(f"Attempting to load checkpoint from {checkpoint_path}")
    # Snapshot the current state: load_state_dict can copy matching tensors before
    # raising on a mismatch, which would otherwise leave a half-loaded model.
    model_backup = copy.deepcopy(model.state_dict())
    optimizer_backup = copy.deepcopy(optimizer.state_dict())
    try:
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only runtime;
        # load_state_dict copies the tensors onto the model's own device.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
    except (RuntimeError, KeyError, ValueError) as e:
        print(f"Checkpoint mismatch detected. Error: {e}")
        print("Keeping the current model and starting from scratch.")
        model.load_state_dict(model_backup)
        optimizer.load_state_dict(optimizer_backup)
        return 0

    print(f"Checkpoint loaded successfully, resuming from epoch {start_epoch}")
    return start_epoch



class SpeakerMFCCDataset(Dataset):
    """Dataset of padded MFCC arrays, one sample per ``*_padded.npy`` file.

    Expects ``root_dir/<speaker>/<name>_padded.npy`` with a sibling
    ``<name>_mask.npy``; padded files without a mask are ignored. Items are
    ``(mfcc float32 tensor, mask bool tensor, label_idx)``.

    Args:
        root_dir: directory containing one sub-folder per speaker.
        label_to_idx: optional speaker-name -> index mapping to reuse, e.g. the
            training set's mapping for another split so label ids agree. When
            omitted, the speakers found are numbered in sorted order, so ids are
            stable across runs and processes.

    Raises:
        ValueError: if ``label_to_idx`` is given and lacks a speaker found on disk.
    """

    def __init__(self, root_dir, label_to_idx=None):
        self.root_dir = Path(root_dir)
        self.samples = []

        # Sorted traversal: iterdir()/glob() order depends on the filesystem.
        for speaker_folder in sorted(p for p in self.root_dir.iterdir() if p.is_dir()):
            for mfcc_file in sorted(speaker_folder.glob("*_padded.npy")):
                mask_file = mfcc_file.with_name(mfcc_file.stem.replace("_padded", "_mask") + ".npy")
                if mask_file.exists():
                    self.samples.append((mfcc_file, mask_file, speaker_folder.name))

        speakers_found = {speaker for _, _, speaker in self.samples}
        if label_to_idx is None:
            # sorted() instead of set iteration order: str hashing is randomized per
            # process, so set order (and thus label ids) could change between sessions
            # and silently mismatch a resumed checkpoint's classifier.
            label_to_idx = {speaker: idx for idx, speaker in enumerate(sorted(speakers_found))}
        else:
            unknown = sorted(speakers_found - set(label_to_idx))
            if unknown:
                raise ValueError(f"Speakers missing from label_to_idx: {unknown}")
        self.label_to_idx = dict(label_to_idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        mfcc_path, mask_path, speaker = self.samples[idx]
        mfcc = torch.tensor(np.load(mfcc_path), dtype=torch.float32)
        mask = torch.tensor(np.load(mask_path), dtype=torch.bool)
        return mfcc, mask, self.label_to_idx[speaker]

# Collate function to enforce consistent shapes
def collate_fn(batch):
    """Stack ``(mfcc, mask, label)`` samples into fixed-size batch tensors.

    Each MFCC and mask is brought to ``[1, target_height, target_width]``. A 2-D
    tensor gets a leading axis. A 3-D tensor whose first axis is not 1 is averaged
    over that axis; masks are cast to float first. The last two axes are then
    zero-padded (F.pad with a negative amount crops) and sliced to size.

    Returns ``(mfccs [B, H, W], masks [B, H, W], labels [B])``.

    NOTE(review): the averaged first axis appears to be the padded segment axis, so
    zero-filled padding segments are included in the mean. Confirm this is intended.
    """
    def _fit(tensor, is_mask):
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0)
        elif tensor.dim() == 3 and tensor.shape[0] != 1:
            source = tensor.float() if is_mask else tensor
            tensor = source.mean(dim=0, keepdim=True)
        pad_w = target_width - tensor.shape[-1]
        pad_h = target_height - tensor.shape[-2]
        tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)
        return tensor[:, :target_height, :target_width]

    mfcc_list, mask_list, label_list = [], [], []
    for sample_mfcc, sample_mask, sample_label in batch:
        mfcc_list.append(_fit(sample_mfcc, False))
        mask_list.append(_fit(sample_mask, True))
        label_list.append(sample_label)

    return (
        torch.stack(mfcc_list).squeeze(1),
        torch.stack(mask_list).squeeze(1),
        torch.tensor(label_list),
    )


# Initialize model, optimizer, and loss function
input_dim = target_width
d_model = 144
n_heads = 4
num_blocks = 8
kernel_size = 32

# Size the classifier from the data instead of a hard-coded 10, because any label
# index >= num_classes makes CrossEntropyLoss fail. Each split numbers its own
# speakers, so take the larger label set.
# NOTE(review): train/val are split by speaker, so validation label ids do not refer
# to speakers the model was trained on; val accuracy is not a meaningful
# closed-set metric.
train_dir = '/content/drive/MyDrive/mfcc_datasets_new/train'
val_dir = '/content/drive/MyDrive/mfcc_datasets_new/val'
num_classes = max(
    len(SpeakerMFCCDataset(train_dir).label_to_idx),
    len(SpeakerMFCCDataset(val_dir).label_to_idx),
    1,
)

model = ConformerEncoder(input_dim, d_model, n_heads, num_blocks, kernel_size, 0.1, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

input_shape = (1, target_height, target_width)  # (batch_size, time, features)

# Calculate and display model complexity before training
print("Calculating initial model complexity:")
get_model_complexity(model, input_shape=input_shape)

# load_checkpoint looks for 'latest_checkpoint.pth' inside the directory it is given.
start_epoch = load_checkpoint(checkpoint_dir, model, optimizer)

# Training function with detailed logging and profiling
def train_epoch(model, loader, optimizer, criterion, epoch):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Batches may be ``(data, labels)`` or ``(data, mask, labels)``, which is what
    ``collate_fn`` yields. Data is the first element and labels the last; the mask
    is not used. Loss is averaged over batches and accuracy over samples; an empty
    loader yields ``(0.0, 0.0)``. After the epoch, model complexity is reported for
    one sample of the last batch's shape.
    """
    model.train()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss, correct = 0.0, 0
    start_epoch_time = time.time()
    sample_shape = None

    for batch_idx, batch in enumerate(loader):
        # Unpacking a 3-tuple batch into (data, labels) raised ValueError.
        data, labels = batch[0].to(run_device), batch[-1].to(run_device)
        sample_shape = tuple(data.shape[1:])

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch} - Batch {batch_idx}/{len(loader)}: Loss = {loss.item():.4f}")

    avg_loss = total_loss / max(len(loader), 1)
    accuracy = correct / max(len(loader.dataset), 1)
    epoch_duration = time.time() - start_epoch_time
    print(f"Epoch {epoch} completed in {epoch_duration / 60:.2f} mins")

    # Profile with a real per-sample shape (batch of 1). A 2-D probe such as
    # (8, target_width) cannot pass through the model.
    if sample_shape is not None:
        get_model_complexity(model, input_shape=(1, *sample_shape))
    return avg_loss, accuracy

# Validation function
def validate_epoch(model, loader, criterion, epoch):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(avg_loss, accuracy)``.

    Batches may be ``(data, labels)`` or ``(data, mask, labels)``, as produced by
    ``collate_fn``. Data is the first element and labels the last. Loss is averaged
    over batches and accuracy over samples; an empty loader yields ``(0.0, 0.0)``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            # Unpacking a 3-tuple batch into (data, labels) raised ValueError.
            data, labels = batch[0].to(run_device), batch[-1].to(run_device)
            outputs = model(data)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    avg_loss = total_loss / max(len(loader), 1)
    accuracy = correct / max(len(loader.dataset), 1)
    print(f"Validation - Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")
    return avg_loss, accuracy

# Build datasets and loaders, then train. Model complexity was already reported and the
# checkpoint already loaded when the model was created; repeating both here would
# re-profile the model and reload the same file.
train_dataset = SpeakerMFCCDataset('/content/drive/MyDrive/mfcc_datasets_new/train')
val_dataset = SpeakerMFCCDataset('/content/drive/MyDrive/mfcc_datasets_new/val')

# Fail fast with a readable message rather than an opaque error inside
# CrossEntropyLoss when a label index would exceed the classifier size.
max_labels = max(len(train_dataset.label_to_idx), len(val_dataset.label_to_idx))
if max_labels > num_classes:
    raise ValueError(
        f"Datasets contain {max_labels} speaker labels but the model has "
        f"num_classes={num_classes}; rebuild the model with a larger num_classes."
    )

# Shuffle training batches each epoch; validation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Training loop
num_epochs = 20
for epoch in range(start_epoch, num_epochs):
    print(f"Starting epoch {epoch+1}/{num_epochs}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, epoch)
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, epoch)

    # Save checkpoint at the end of each epoch
    save_checkpoint(model, optimizer, epoch, train_loss, val_loss, os.path.join(checkpoint_dir, 'latest_checkpoint.pth'))

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
