<a href="https://colab.research.google.com/github/Somie12/Speech-Synthesis-for-Low-Resource-Language/blob/main/5.%20LID%20with%20Conformer%20and%20Ecapa-TDNN%20Architecture/Conformer/Conformer_LID_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so files stored on Drive can be read and
# written from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Installing required packages
!pip install torch torchaudio transformers datasets soundfile librosa numpy pandas matplotlib seaborn



Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
# Importing necessary libraries
import os
import time
import json
import shutil
import zipfile
import logging
import torch
import torchaudio
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
from sklearn.preprocessing import label_binarize
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torchsummary import summary

In [4]:
# Path configuration.
# Drive holds everything that must survive the session (source archives, processed
# data, checkpoints); /content is Colab's temporary storage and is used for the
# extracted audio.
DRIVE_PATH = '/content/drive/MyDrive'
OUTPUT_PATH = '/content'

_LID_DRIVE_ROOT = os.path.join(DRIVE_PATH, 'LID_Training')

DATASET_DRIVE_PATH = os.path.join(_LID_DRIVE_ROOT, 'train_datasets')
EXTRACTED_DATA_PATH = os.path.join(OUTPUT_PATH, 'language_data')
PROCESSED_DATA_DIR = os.path.join(_LID_DRIVE_ROOT, 'processed_data')
MODEL_CHECKPOINT_DIR = os.path.join(_LID_DRIVE_ROOT, 'model_checkpoints')

# Create every output location up front; existing folders are left untouched.
for _required_dir in (EXTRACTED_DATA_PATH, PROCESSED_DATA_DIR, MODEL_CHECKPOINT_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Extracting dataset zip files to Colab space
def extract_datasets(languages=('Hindi', 'English', 'Chinese'), dataset_dir=None, output_dir=None):
    """
    Unpack the per-language dataset archives.

    For each language name ``<Lang>`` the archive ``<dataset_dir>/<Lang>_Datasets.zip``
    is extracted into ``<output_dir>/<lang in lower case>``. A missing archive is
    reported with a warning and skipped; the remaining languages are still processed.

    Args:
        languages: Iterable of language names exactly as they appear in the
            archive file names.
        dataset_dir: Folder containing the zip archives. Defaults to the
            notebook-level DATASET_DRIVE_PATH.
        output_dir: Folder that receives one sub-folder per language. Defaults to
            the notebook-level EXTRACTED_DATA_PATH.
    """
    # Resolve the defaults at call time so the notebook constants are only
    # required when the caller does not supply explicit locations.
    if dataset_dir is None:
        dataset_dir = DATASET_DRIVE_PATH
    if output_dir is None:
        output_dir = EXTRACTED_DATA_PATH

    for lang in languages:
        zip_path = os.path.join(dataset_dir, f'{lang}_Datasets.zip')
        if not os.path.exists(zip_path):
            print(f"Warning: {zip_path} not found!")
            continue

        extract_dir = os.path.join(output_dir, lang.lower())
        os.makedirs(extract_dir, exist_ok=True)
        print(f"Extracting {lang} dataset to Colab space...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
        print(f"{lang} dataset extracted to {extract_dir}")


# Unpack the Hindi/English/Chinese archives from Drive into EXTRACTED_DATA_PATH.
extract_datasets()

Extracting Hindi dataset to Colab space...
Hindi dataset extracted to /content/language_data/hindi
Extracting English dataset to Colab space...
English dataset extracted to /content/language_data/english
Extracting Chinese dataset to Colab space...
Chinese dataset extracted to /content/language_data/chinese


In [5]:
# Setting up logging
def setup_logging(log_dir):
    """
    Configure logging to a timestamped file and to the console.

    A new ``training_<YYYYmmdd_HHMMSS>.log`` file is created inside ``log_dir``
    on every call and the root logger is reconfigured to write INFO-level (and
    above) records to that file and to the standard stream handler.

    Args:
        log_dir: Directory in which the log file is created (made if missing).

    Returns:
        The ``'LID_Training'`` logger, which propagates to the configured root logger.
    """
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f'training_{timestamp}.log')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ],
        # Without force=True basicConfig does nothing when the root logger already
        # has handlers (e.g. installed by the notebook runtime, or by a previous call
        # to this function), so nothing would ever be written to the new log file.
        force=True,
    )

    return logging.getLogger('LID_Training')

In [6]:
# Voice Activity Detection (VAD) function for detecting speech segments
def apply_vad(waveform, sample_rate=16000, threshold_db=-35, min_silence_duration_ms=300):
    """
    Applying Voice Activity Detection to remove silence from audio

    Frames whose RMS energy (in dB relative to the loudest frame) is at or below
    ``threshold_db`` are treated as silence. Silent stretches of at least
    ``min_silence_duration_ms`` are cut out; shorter pauses are kept as part of
    the surrounding speech.

    Args:
        waveform: Audio signal, 1-D, or multi-channel with channels on axis 0
        sample_rate: Audio sample rate
        threshold_db: Energy threshold in dB below which frames are considered silence
        min_silence_duration_ms: Minimum silence duration in milliseconds

    Returns:
        Filtered 1-D audio with silence removed. The (down-mixed) input is returned
        unchanged when it is empty or when no speech segment is found.
    """
    waveform = np.asarray(waveform)

    # Converting to mono: average over the channel axis for ANY 2-D input, so a
    # single-channel (1, N) array is flattened to (N,) as well.
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)

    # Nothing to analyse; the mask indexing below would fail on an empty signal.
    if waveform.size == 0:
        return waveform

    # Calculating frame length and hop length
    frame_length = int(sample_rate * 0.025)  # 25ms frames
    hop_length = int(sample_rate * 0.010)    # 10ms hop

    # Calculating energy in dB for each frame (0 dB == loudest frame of this clip)
    energy = librosa.feature.rms(y=waveform, frame_length=frame_length, hop_length=hop_length)[0]
    energy_db = librosa.amplitude_to_db(energy, ref=np.max)

    # Creating mask for frames with energy above threshold
    mask = energy_db > threshold_db

    # Converting frame-level mask to sample-level mask
    # NOTE(review): librosa.feature.rms centres frames by default, so frame i is
    # presumably centred on sample i * hop_length rather than starting there; the
    # mapping below would then lag by about half a frame - confirm whether it matters.
    sample_mask = np.zeros_like(waveform, dtype=bool)
    for i, val in enumerate(mask):
        if val:
            start_sample = i * hop_length
            end_sample = min(start_sample + frame_length, len(waveform))
            sample_mask[start_sample:end_sample] = True

    # Applying minimum silence duration constraint
    min_silence_samples = int(min_silence_duration_ms * sample_rate / 1000)

    # Finding silence segments: a -1 step is speech->silence, a +1 step is silence->speech
    transitions = np.diff(sample_mask.astype(int))
    silence_starts = np.where(transitions == -1)[0] + 1
    silence_ends = np.where(transitions == 1)[0] + 1

    # Handling case where audio starts with silence
    if not sample_mask[0]:
        silence_starts = np.insert(silence_starts, 0, 0)

    # Handling case where audio ends with silence
    if not sample_mask[-1]:
        silence_ends = np.append(silence_ends, len(sample_mask))

    # Keeping only speech segments (inverse of silence)
    speech_segments = []
    last_end = 0

    for start, end in zip(silence_starts, silence_ends):
        # If silence duration is less than threshold, we'll consider it as speech
        if end - start < min_silence_samples:
            continue

        if start > last_end:
            speech_segments.append(waveform[last_end:start])

        last_end = end

    # Adding remaining speech at the end if any
    if last_end < len(waveform):
        speech_segments.append(waveform[last_end:])

    # If no speech detected, return original
    if not speech_segments:
        return waveform

    # Concatenating all speech segments
    return np.concatenate(speech_segments)

class LanguageAudioDataset(Dataset):
    """
    Dataset for language identification from audio files.

    Audio files are discovered recursively under ``<root_dir>/<language>``. Each
    item is one fixed-length waveform segment (``segment_length`` seconds at
    16 kHz) together with the integer label of its language, where the label is
    the position of the language in ``languages``.
    """

    # Lower-case file extensions that are treated as audio.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

    def __init__(self, root_dir, languages=('hindi', 'english', 'chinese'), max_samples_per_lang=15000,
                 segment_length=3, apply_vad=True):
        """
        Index the audio files of every language folder.

        Args:
            root_dir: Root directory containing language folders
            languages: Sequence of language names (folder names)
            max_samples_per_lang: Maximum number of samples per language
            segment_length: Length of audio segments in seconds
            apply_vad: Whether to apply Voice Activity Detection
        """
        self.root_dir = root_dir
        # Copy into a list so the caller's sequence (or the default) is never shared.
        self.languages = list(languages)
        self.segment_length = segment_length  # in seconds
        self.sample_rate = 16000  # standard sample rate
        self.samples = []
        self.labels = []
        self.apply_vad = apply_vad

        for i, lang in enumerate(self.languages):
            lang_dir = os.path.join(root_dir, lang)
            if not os.path.exists(lang_dir):
                print(f"Warning: Directory {lang_dir} not found!")
                continue

            # os.walk yields entries in an unspecified order, so sort the paths:
            # otherwise the subset kept by max_samples_per_lang can differ between runs.
            # The extension check is case-insensitive so e.g. ".WAV" is not skipped.
            audio_files = sorted(
                os.path.join(root, file)
                for root, _, files in os.walk(lang_dir)
                for file in files
                if file.lower().endswith(self.AUDIO_EXTENSIONS)
            )

            # Limiting the number of samples per language
            audio_files = audio_files[:max_samples_per_lang]
            print(f"Found {len(audio_files)} files for {lang}")

            self.samples.extend(audio_files)
            self.labels.extend([i] * len(audio_files))

        print(f"Total samples: {len(self.samples)}")
        self.label_to_lang = {i: lang for i, lang in enumerate(self.languages)}
        self.lang_to_label = {lang: i for i, lang in enumerate(self.languages)}

    def __len__(self):
        """Number of indexed audio files."""
        return len(self.samples)

    def _segment_samples(self):
        """Segment length in samples, as an int even for fractional ``segment_length``."""
        return int(round(self.segment_length * self.sample_rate))

    def __getitem__(self, idx):
        """
        Load one file and return ``(segment, label)``.

        The segment is a float32 tensor of ``segment_length * sample_rate`` samples.
        An all-zero segment is returned (with the file's real label) when the file
        cannot be loaded or is shorter than one segment after VAD.
        """
        audio_path = self.samples[idx]
        label = self.labels[idx]

        try:
            waveform, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)

            # Applying Voice Activity Detection if enabled
            if self.apply_vad:
                waveform = apply_vad(waveform, sample_rate=self.sample_rate)

        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Returns a zero array if file can't be loaded
            return torch.zeros(self._segment_samples()), label

        # Processes audio into fixed-length segments
        segments = self._segment_audio(waveform)

        # If no valid segments, creates a zero segment
        if len(segments) == 0:
            return torch.zeros(self._segment_samples()), label

        # A random segment is drawn on every access whenever the file yields more
        # than one; there is no separate deterministic mode for validation.
        segment_idx = np.random.randint(0, len(segments)) if len(segments) > 1 else 0
        return torch.from_numpy(segments[segment_idx].astype(np.float32)), label

    def _segment_audio(self, waveform):
        """Split audio into non-overlapping fixed-length segments, discarding the remainder."""
        segment_samples = self._segment_samples()
        segments = []

        # If audio is shorter than one segment, skip it
        if len(waveform) < segment_samples:
            return segments

        # Split longer audio into back-to-back segments
        for i in range(0, len(waveform) - segment_samples + 1, segment_samples):
            segments.append(waveform[i:i + segment_samples])

        return segments

In [7]:
"""
Language Identification using Conformer Architecture (PyTorch)

This model processes raw audio or log-mel spectrograms to classify the spoken language.
Key modules:
- extract_features: Converts waveform to normalized log-mel spectrogram.
- ConformerBlock: Combines self-attention, convolution, and feedforward layers.
- Conformer: Stacked blocks with positional encoding and classification head.
- LIDModel: Full pipeline for LID, handling raw audio input and prediction.
"""




import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to a [batch, seq_len, d_model] input.

    Even feature columns hold sines and odd columns hold cosines of the position
    scaled by geometrically spaced frequencies. The table is precomputed for
    ``max_len`` positions and stored as the buffer ``pe``.
    """

    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: Feature dimension of the inputs (even or odd).
            max_len: Number of positions to precompute.
        """
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns, so
        # trim the frequencies to match; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model] so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1), :]

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear(d_model -> d_ff), ReLU, dropout, Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: Size of the input and output features.
            d_ff: Size of the hidden layer.
            dropout: Dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class MultiHeadedAttention(nn.Module):
    """
    Scaled dot-product attention with ``h`` parallel heads.

    Query, key and value are each projected with their own linear layer, split
    into ``h`` heads of size ``d_model // h``, attended, re-joined and passed
    through an output projection.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """
        Args:
            h: Number of attention heads; must divide ``d_model``.
            d_model: Feature dimension of the inputs and outputs.
            dropout: Dropout probability applied to the attention weights.
        """
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0

        self.d_k = d_model // h
        self.h = h
        # Projections for query, key and value, in that order.
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: Tensors of shape [batch, seq_len, d_model].
            mask: Optional tensor broadcastable to [batch, h, query_len, key_len];
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape [batch, query_len, d_model].
        """
        batch_size = query.size(0)

        # Linear projections and split into h heads -> [batch, h, seq_len, d_k]
        query, key, value = [
            l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]

        # Apply attention on all projected vectors in batch
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative value representable in the scores' dtype: a
            # hard-coded -1e9 cannot be represented in float16 and makes
            # masked_fill fail under half precision / autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        x = torch.matmul(p_attn, value)

        # Combine heads -> [batch, query_len, d_model]
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)

class ConvModule(nn.Module):
    """
    Convolution sub-module with a residual connection.

    Pipeline: LayerNorm -> pointwise conv (doubles channels) -> GLU (halves them
    again) -> depthwise conv -> BatchNorm -> SiLU -> pointwise conv -> dropout,
    and finally the module input is added back.
    """

    def __init__(self, d_model, kernel_size=31, dropout=0.1):
        """
        Args:
            d_model: Number of feature channels.
            kernel_size: Kernel width of the depthwise convolution.
            dropout: Dropout probability applied before the residual addition.
        """
        super().__init__()

        self.layer_norm = nn.LayerNorm(d_model)

        # 1x1 convolution producing the two halves consumed by the GLU
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)

        # One filter per channel (groups == channels), padded by (kernel_size - 1) // 2
        self.depthwise_conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=d_model,
        )

        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.SiLU()  # SiLU (Swish) activation

        # 1x1 convolution mixing channels after the depthwise stage
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch_size, seq_len, d_model] to a tensor of the same shape."""
        # Conv1d works on [batch_size, d_model, seq_len]
        channels_first = self.layer_norm(x).transpose(1, 2)

        # Gated linear unit: channel count goes 2 * d_model -> d_model
        gated = F.glu(self.pointwise_conv1(channels_first), dim=1)

        mixed = self.activation(self.batch_norm(self.depthwise_conv(gated)))
        projected = self.dropout(self.pointwise_conv2(mixed))

        # Back to [batch_size, seq_len, d_model], then the residual connection
        return projected.transpose(1, 2) + x

class ConformerBlock(nn.Module):
    """
    One encoder block: half-step feed-forward, self-attention with pre-LayerNorm,
    convolution module, second half-step feed-forward, and a final LayerNorm.
    """

    def __init__(self, d_model, d_ff, heads, kernel_size, dropout=0.1):
        """
        Args:
            d_model: Feature dimension flowing through the block.
            d_ff: Hidden size of both feed-forward modules.
            heads: Number of attention heads.
            kernel_size: Kernel width of the convolution module.
            dropout: Dropout probability passed to every sub-module.
        """
        super().__init__()

        # Both feed-forward outputs are scaled by 0.5 before being added back.
        self.ff1 = FeedForward(d_model, d_ff, dropout)
        self.ff1_factor = 0.5

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout)
        self.attn_layer_norm = nn.LayerNorm(d_model)

        self.conv_module = ConvModule(d_model, kernel_size, dropout)

        self.ff2 = FeedForward(d_model, d_ff, dropout)
        self.ff2_factor = 0.5

        self.final_layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def _scaled_residual(x, module, scale):
        """Return ``x + scale * module(x)``."""
        return x + scale * module(x)

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention module."""
        x = self._scaled_residual(x, self.ff1, self.ff1_factor)

        # Attention sees the normalised activations; the residual uses the raw ones.
        normed = self.attn_layer_norm(x)
        x = x + self.self_attn(normed, normed, normed, mask)

        # The convolution module adds its own residual internally.
        x = self.conv_module(x)

        x = self._scaled_residual(x, self.ff2, self.ff2_factor)

        return self.final_layer_norm(x)

class Conformer(nn.Module):
    """
    Stack of ConformerBlocks with an input projection, sinusoidal positional
    encoding, mean pooling over time and an MLP classification head.
    """

    def __init__(self, num_classes, d_model=144, n_layers=6, n_heads=4, d_ff=256,
                 kernel_size=31, dropout=0.1, input_dim=80):
        """
        Args:
            num_classes: Number of output logits.
            d_model: Feature dimension used inside the encoder.
            n_layers: Number of stacked ConformerBlocks.
            n_heads: Attention heads per block.
            d_ff: Hidden size of the feed-forward modules.
            kernel_size: Kernel width of the convolution modules.
            dropout: Dropout probability used throughout.
            input_dim: Number of input features per frame (e.g. mel bands).
        """
        super().__init__()

        # Frame-wise projection from the input features to d_model
        self.input_projection = nn.Linear(input_dim, d_model)

        self.positional_encoding = PositionalEncoding(d_model)

        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList(
            [ConformerBlock(d_model, d_ff, n_heads, kernel_size, dropout) for _ in range(n_layers)]
        )

        # Two-layer head applied to the time-pooled representation
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, x, mask=None):
        """
        Args:
            x: Features of shape [batch_size, input_dim, time], or the same with a
                singleton dimension at index 1 ([batch_size, 1, input_dim, time]).
            mask: Optional mask handed unchanged to every block.

        Returns:
            Logits of shape [batch_size, num_classes].
        """
        # Drop the singleton dimension of 4-D inputs
        if x.dim() == 4:
            x = x.squeeze(1)

        # The projection acts on the last axis, so bring features last:
        # [batch_size, input_dim, time] -> [batch_size, time, input_dim]
        frames = x.transpose(1, 2)

        hidden = self.input_projection(frames)
        hidden = self.positional_encoding(hidden)
        hidden = self.dropout(hidden)

        for block in self.layers:
            hidden = block(hidden, mask)

        # Global average pooling over the time axis
        pooled = torch.mean(hidden, dim=1)

        return self.classifier(pooled)

class LIDModel(nn.Module):
    """
    Language-identification model: a Conformer classifier that accepts either
    precomputed log-mel features or raw waveforms.
    """

    def __init__(self, num_languages=3, input_dim=80):
        """
        Args:
            num_languages: Number of languages (output logits).
            input_dim: Number of mel bands per frame expected by the encoder.
        """
        super(LIDModel, self).__init__()

        # Remembered so raw audio is converted with a matching number of mel bands.
        self.input_dim = input_dim

        # Conformer-based encoder
        self.conformer = Conformer(
            num_classes=num_languages,
            d_model=144,  # Dimension of model
            n_layers=6,   # Number of Conformer blocks
            n_heads=4,    # Number of attention heads
            d_ff=256,     # Feed forward dimension
            kernel_size=31,  # Kernel size for convolution module
            dropout=0.1,  # Dropout rate
            input_dim=input_dim  # Input dimension (mel spectrogram features)
        )

    def forward(self, x):
        """
        Args:
            x: Either raw audio of shape [batch_size, time] or features that the
                Conformer accepts directly ([batch_size, n_mels, time]).

        Returns:
            Logits of shape [batch_size, num_languages].
        """
        # Extract features if input is raw audio
        if x.dim() == 2:
            device = x.device
            # Convert one utterance at a time, exactly like the training loop does,
            # so each utterance is normalised with its own statistics, and build the
            # transform on the input's device so GPU batches work.
            # Result: [batch_size, n_mels, time]; the Conformer handles the transpose.
            x = torch.stack([
                extract_features(waveform, n_mels=self.input_dim, device=device)
                for waveform in x
            ])

        # Forward through conformer
        return self.conformer(x)

In [8]:
# Function to extract features (log-mel spectrograms) from audio
def extract_features(waveform, sample_rate=16000, n_mels=80, device=None):
    """
    Convert waveform to a normalized log-mel spectrogram for model input

    Args:
        waveform: Audio waveform tensor of shape [time] or [batch, time]
            (can be on CPU or GPU)
        sample_rate: Audio sample rate
        n_mels: Number of mel bands
        device: Device on which to build the mel transform. Defaults to the
            device of ``waveform``.

    Returns:
        Log-mel spectrogram with the same leading dimensions as ``waveform``
        followed by [n_mels, frames]. Every utterance is normalised to zero mean
        and unit variance using its own statistics.
    """
    # Follow the waveform unless told otherwise; a transform left on the CPU
    # cannot process a GPU waveform.
    if device is None:
        device = waveform.device

    mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=512,
        win_length=400,
        hop_length=160,  # 10ms shift
        n_mels=n_mels
    ).to(device)

    mel_spec = mel_spectrogram_transform(waveform)

    # Convert to log mel spectrogram (epsilon avoids log(0))
    log_mel = torch.log(mel_spec + 1e-9)

    # Normalize over the (mel, frame) axes only. For a single waveform this is the
    # global mean/std as before; for a batch it keeps the utterances independent
    # instead of mixing statistics across the batch.
    stat_dims = (-2, -1)
    mean = log_mel.mean(dim=stat_dims, keepdim=True)
    std = log_mel.std(dim=stat_dims, keepdim=True)
    log_mel = (log_mel - mean) / (std + 1e-9)

    return log_mel

In [9]:
# Helpers used by train_model
def _waveforms_to_features(inputs, device):
    """Convert a [batch, time] waveform batch into stacked per-utterance features.

    Inputs with any other number of dimensions are returned unchanged.
    """
    if inputs.dim() != 2:
        return inputs
    return torch.stack([extract_features(waveform, device=device) for waveform in inputs])


def _run_pass(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

    Returns:
        (mean batch loss, accuracy in percent, list of per-batch losses)
    """
    training = optimizer is not None
    model.train(training)

    running_loss = 0.0
    correct = 0
    total = 0
    batch_losses = []

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Converting inputs to spectrograms if needed
            inputs = _waveforms_to_features(inputs, device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                loss.backward()
                optimizer.step()

            # Statistics
            batch_loss = loss.item()
            running_loss += batch_loss
            batch_losses.append(batch_loss)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'acc': f"{100.*correct/total:.2f}%"
            })

    accuracy = 100. * correct / total
    mean_loss = running_loss / len(loader)
    return mean_loss, accuracy, batch_losses


def _checkpoint_payload(epoch, model, optimizer, scheduler, train_loss, train_acc, val_loss, val_acc):
    """Build the dictionary stored in every checkpoint file."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_acc': val_acc,
        'val_loss': val_loss,
        'train_acc': train_acc,
        'train_loss': train_loss,
    }


def _write_training_history(save_dir, train_losses, val_losses, train_accs, val_accs, epoch_metrics):
    """Write ``logs/training_history.json`` with plain-Python (JSON-safe) values."""

    def _plain(value):
        # numpy scalars are not JSON serialisable; unwrap them
        return value.item() if isinstance(value, np.generic) else value

    history_json = {
        'train_loss': [float(x) for x in train_losses],
        'val_loss': [float(x) for x in val_losses],
        'train_acc': [float(x) for x in train_accs],
        'val_acc': [float(x) for x in val_accs],
        'epoch_metrics': [
            {k: ([float(x) for x in v] if k == 'batch_losses' else _plain(v)) for k, v in m.items()}
            for m in epoch_metrics
        ],
    }
    # Serialise first so a failure cannot leave a truncated history file behind.
    serialized = json.dumps(history_json, indent=4)
    with open(os.path.join(save_dir, 'logs', 'training_history.json'), 'w') as f:
        f.write(serialized)


# Function for training the model with enhanced logging
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=30, device='cuda', save_dir=None, logger=None,
                start_epoch=0, resume_training=False):
    """
    training function with logging and metrics tracking

    Each epoch runs a training pass and a validation pass, then writes the metrics
    CSV, the history JSON, learning-curve and batch-loss plots, and a per-epoch
    checkpoint. A "best" checkpoint is written whenever validation accuracy
    improves and a regular checkpoint every 5 epochs.

    Args:
        model: Model to train
        train_loader: DataLoader for training set
        val_loader: DataLoader for validation set
        criterion: Loss function
        optimizer: Optimizer
        scheduler: Learning rate scheduler; stepped once per epoch with the
            validation loss
        num_epochs: Number of epochs to train
        device: Device to train on
        save_dir: Directory to save checkpoints and logs (required)
        logger: Logger instance for detailed logging
        start_epoch: Epoch to start from (for resuming training)
        resume_training: Whether this is a resumed training session

    Returns:
        Dictionary containing training history

    Raises:
        ValueError: If ``save_dir`` is not provided.
    """
    if save_dir is None:
        raise ValueError("save_dir must be provided: checkpoints, logs and plots are written there")

    if logger is None:
        logger = logging.getLogger('LID_Training')

    # Creating directories for saving various outputs
    for subdir in ('checkpoints', 'logs', 'visualizations', 'epoch_models'):
        os.makedirs(os.path.join(save_dir, subdir), exist_ok=True)

    # Keeping track of metrics
    best_val_acc = 0.0
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # For detailed per-epoch logging
    epoch_metrics = []

    # For resuming training
    if resume_training and start_epoch > 0:
        history_path = os.path.join(save_dir, 'logs', 'training_history.json')
        if os.path.exists(history_path):
            try:
                with open(history_path, 'r') as f:
                    history_data = json.load(f)

                # Read everything into temporaries first so a malformed file cannot
                # leave the history lists only partially restored.
                restored = (
                    history_data['train_loss'][:start_epoch],
                    history_data['val_loss'][:start_epoch],
                    history_data['train_acc'][:start_epoch],
                    history_data['val_acc'][:start_epoch],
                    history_data['epoch_metrics'][:start_epoch],
                )
                restored_best = max([best_val_acc] + [metric['val_acc'] for metric in restored[4]])

                train_losses, val_losses, train_accs, val_accs, epoch_metrics = restored
                best_val_acc = restored_best

                logger.info(f"Loaded training history up to epoch {start_epoch}, best val acc: {best_val_acc:.2f}%")
            except Exception as e:
                logger.warning(f"Failed to load training history: {e}")

    # Last known metrics; these stay defined even when the loop below does not run.
    train_loss = train_losses[-1] if train_losses else None
    train_acc = train_accs[-1] if train_accs else None
    val_loss = val_losses[-1] if val_losses else None
    val_acc = val_accs[-1] if val_accs else None

    if start_epoch >= num_epochs:
        logger.warning(f"start_epoch ({start_epoch}) >= num_epochs ({num_epochs}); no epochs will be run")

    # Training start time
    start_time = time.time()
    logger.info(f"Starting training for {num_epochs} epochs (from epoch {start_epoch+1})")

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()

        logger.info(f"Epoch {epoch+1}/{num_epochs}")

        # Training phase
        train_loss, train_acc, batch_losses = _run_pass(
            model, train_loader, criterion, device, f"Train Epoch {epoch+1}", optimizer=optimizer
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase (no optimizer -> eval mode, gradients disabled)
        val_loss, val_acc, _ = _run_pass(
            model, val_loader, criterion, device, f"Val Epoch {epoch+1}"
        )
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Calculates epoch time
        epoch_time = time.time() - epoch_start_time

        # Log epoch results
        logger.info(f"Epoch {epoch+1} completed in {epoch_time:.2f}s - "
                   f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                   f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Saving epoch metrics (learning rate is the one used during this epoch)
        epoch_metrics.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'epoch_time': epoch_time,
            'batch_losses': batch_losses  # Saving all batch losses for distribution analysis
        })

        # Saving metrics to CSV after each epoch
        metrics_df = pd.DataFrame(epoch_metrics)
        metrics_df.to_csv(os.path.join(save_dir, 'logs', 'training_metrics.csv'), index=False)

        # Persist the history every epoch so an interrupted run can be resumed.
        _write_training_history(save_dir, train_losses, val_losses, train_accs, val_accs, epoch_metrics)

        # Plotting and saving learning curves after each epoch
        plot_learning_curves(train_losses, val_losses, train_accs, val_accs,
                            save_path=os.path.join(save_dir, 'visualizations', f'learning_curves_epoch_{epoch+1}.png'))

        # Plotting batch loss distribution
        plt.figure(figsize=(10, 6))
        plt.hist(batch_losses, bins=30, alpha=0.7)
        plt.xlabel('Batch Loss')
        plt.ylabel('Frequency')
        plt.title(f'Batch Loss Distribution - Epoch {epoch+1}')
        plt.savefig(os.path.join(save_dir, 'visualizations', f'batch_loss_dist_epoch_{epoch+1}.png'))
        plt.close()

        # Updating the learning rate
        scheduler.step(val_loss)

        base_state = _checkpoint_payload(
            epoch + 1, model, optimizer, scheduler, train_loss, train_acc, val_loss, val_acc
        )

        # Save model after each epoch in the epoch_models folder
        epoch_model_filename = f'model_epoch_{epoch+1}_trainloss_{train_loss:.4f}_trainacc_{train_acc:.2f}_valloss_{val_loss:.4f}_valacc_{val_acc:.2f}.pt'
        epoch_model_path = os.path.join(save_dir, 'epoch_models', epoch_model_filename)

        epoch_state = dict(base_state)
        epoch_state.update({
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'epoch_metrics': epoch_metrics
        })
        torch.save(epoch_state, epoch_model_path)

        logger.info(f"Saved epoch model to {epoch_model_path}")

        # Saving the model checkpoint if it's the best so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(save_dir, 'checkpoints', f'best_model_epoch_{epoch+1}.pt')
            torch.save(base_state, checkpoint_path)
            logger.info(f"New best model saved to {checkpoint_path} with validation accuracy: {val_acc:.2f}%")

        # Saving regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(save_dir, 'checkpoints', f'model_epoch_{epoch+1}.pt')
            torch.save(base_state, checkpoint_path)
            logger.info(f"Regular checkpoint saved to {checkpoint_path}")

    # Calculating total training time
    total_time = time.time() - start_time
    logger.info(f"Training completed in {total_time/60:.2f} minutes")

    # Saving the final model
    final_path = os.path.join(save_dir, 'checkpoints', 'final_model.pt')
    torch.save(
        _checkpoint_payload(num_epochs, model, optimizer, scheduler, train_loss, train_acc, val_loss, val_acc),
        final_path
    )
    logger.info(f"Final model saved to {final_path}")

    # Creating and saving final visualizations
    create_final_visualizations(train_losses, val_losses, train_accs, val_accs, epoch_metrics, save_dir)

    # Returning training history
    history = {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accs,
        'val_acc': val_accs,
        'epoch_metrics': epoch_metrics
    }

    # Saving history as JSON
    _write_training_history(save_dir, train_losses, val_losses, train_accs, val_accs, epoch_metrics)

    return history


In [10]:
# Function for resuming training from a checkpoint
def resume_training(checkpoint_path, model, optimizer, scheduler, logger):
    """
    Resumes training from a checkpoint

    The checkpoint is expected to be a dict with the keys
    'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict',
    'epoch', 'val_acc' and 'val_loss'.

    Args:
        checkpoint_path: Path to the checkpoint file
        model: Model to load weights into
        optimizer: Optimizer to load state into
        scheduler: Learning rate scheduler to load state into
        logger: Logger instance

    Returns:
        start_epoch: Epoch value stored in the checkpoint
        model: The same model object, with the checkpoint weights loaded
        optimizer: The same optimizer object, with its state restored
        scheduler: The same scheduler object, with its state restored
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")

    # Map the stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU runtime can still be resumed on a CPU-only one.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: fall back to CPU, which is always available.
        target_device = torch.device('cpu')

    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']

    logger.info(f"Checkpoint loaded successfully. Resuming from epoch {start_epoch}")
    logger.info(f"Loaded model with val_acc: {checkpoint['val_acc']:.2f}%, val_loss: {checkpoint['val_loss']:.4f}")

    return start_epoch, model, optimizer, scheduler



In [11]:
def plot_learning_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """
    Draw loss and accuracy curves side by side.

    The left panel shows train/validation loss, the right panel shows
    train/validation accuracy. The figure is written to ``save_path`` when one
    is given and is always closed afterwards, so nothing is displayed inline.
    """
    # (subplot position, train series, val series, train label, val label, y-label, title)
    panel_specs = (
        (1, train_losses, val_losses,
         'Train Loss', 'Validation Loss', 'Loss', 'Training and Validation Loss'),
        (2, train_accs, val_accs,
         'Train Accuracy', 'Validation Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, train_series, val_series, train_label, val_label, y_label, panel_title in panel_specs:
        plt.subplot(1, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(panel_title)
        plt.grid(alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.close()

def create_final_visualizations(train_losses, val_losses, train_accs, val_accs, epoch_metrics, save_dir):
    """
    Write the end-of-training figures into ``<save_dir>/visualizations``.

    Five images are produced: the learning curves, the learning-rate schedule,
    the per-epoch training time, a 2x2 figure combining all of those, and a
    correlation heatmap over the per-epoch metrics. ``epoch_metrics`` is a list
    of per-epoch dicts that must contain 'learning_rate', 'epoch_time' and the
    four loss/accuracy keys used for the heatmap.
    """
    vis_dir = os.path.join(save_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    def target(filename):
        # Full path of an output image inside the visualizations directory
        return os.path.join(vis_dir, filename)

    def decorate(y_label, panel_title):
        # Shared axis labelling for every epoch-indexed panel
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(panel_title)
        plt.grid(alpha=0.3)

    # 1. Learning curves
    plot_learning_curves(train_losses, val_losses, train_accs, val_accs,
                         save_path=target('final_learning_curves.png'))

    # 2 + 3. One standalone figure per scalar metric (learning rate, epoch time)
    standalone_specs = (
        ('learning_rate', 'Learning Rate', 'Learning Rate Schedule', 'learning_rate_schedule.png'),
        ('epoch_time', 'Training Time (seconds)', 'Epoch Training Time', 'epoch_training_times.png'),
    )
    series_by_key = {}
    for metric_key, y_label, figure_title, filename in standalone_specs:
        plt.figure(figsize=(10, 6))
        values = [entry[metric_key] for entry in epoch_metrics]
        series_by_key[metric_key] = values
        plt.plot(range(1, len(values) + 1), values, marker='o')
        decorate(y_label, figure_title)
        plt.savefig(target(filename))
        plt.close()

    # 4. Combined metrics plot (2x2 grid)
    plt.figure(figsize=(12, 10))

    paired_panels = (
        (1, train_losses, val_losses, 'Train Loss', 'Validation Loss', 'Loss', 'Loss Curves'),
        (2, train_accs, val_accs, 'Train Accuracy', 'Validation Accuracy', 'Accuracy (%)', 'Accuracy Curves'),
    )
    for position, train_series, val_series, train_label, val_label, y_label, panel_title in paired_panels:
        plt.subplot(2, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.legend()
        decorate(y_label, panel_title)

    single_panels = (
        (3, 'learning_rate', 'Learning Rate', 'Learning Rate Schedule'),
        (4, 'epoch_time', 'Time (seconds)', 'Epoch Training Time'),
    )
    for position, metric_key, y_label, panel_title in single_panels:
        plt.subplot(2, 2, position)
        plt.plot(series_by_key[metric_key], marker='o')
        decorate(y_label, panel_title)

    plt.tight_layout()
    plt.savefig(target('combined_training_metrics.png'))
    plt.close()

    # 5. Correlation heatmap over the scalar per-epoch metrics
    heatmap_columns = ['train_loss', 'val_loss', 'train_acc', 'val_acc', 'learning_rate', 'epoch_time']
    correlations = pd.DataFrame(epoch_metrics)[heatmap_columns].corr()

    plt.figure(figsize=(10, 8))
    sns.heatmap(correlations, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
    plt.title('Correlation Between Training Metrics')
    plt.tight_layout()
    plt.savefig(target('metrics_correlation.png'))
    plt.close()


In [12]:
def test_model(model, test_loader, criterion, device, languages, save_dir=None, logger=None):
    """
    Evaluate a trained model on the test set and optionally write a report.

    The model is put in eval mode and run over ``test_loader`` without
    gradients. Batches whose inputs are 2-D are converted to features one
    waveform at a time with ``extract_features`` before the forward pass.
    When ``save_dir`` is given, plots, CSV tables and a JSON summary are
    written to ``<save_dir>/test_results``; when it is ``None`` nothing is
    written to disk.

    Args:
        model: Trained model
        test_loader: DataLoader for test dataset
        criterion: Loss function
        device: Device to run evaluation on
        languages: List of language names; the position in the list is the class index
        save_dir: Directory to save visualizations (optional)
        logger: Logger instance (defaults to the 'LID_Training' logger)

    Returns:
        test_loss: Mean of the per-batch losses (0.0 if the loader is empty)
        test_acc: Accuracy on test set in percent (0.0 if there are no samples)
        per_class_metrics: Dict mapping each language that has at least one
            test sample to {'accuracy': percent, 'samples': count}
    """
    if logger is None:
        logger = logging.getLogger('LID_Training')

    # save_dir is optional: only touch the filesystem when a directory was given
    test_dir = None
    if save_dir is not None:
        test_dir = os.path.join(save_dir, 'test_results')
        os.makedirs(test_dir, exist_ok=True)

    logger.info("Starting evaluation on test set")
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    # For per-class metrics (kept as floats so the reported values keep their type)
    num_classes = len(languages)
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes

    all_preds = []      # predicted class index per sample (confusion matrix)
    all_targets = []    # true class index per sample
    all_scores = []     # softmax probabilities per sample (ROC curves)
    test_samples = []   # one record per sample for the CSV dump

    test_start_time = time.time()

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Testing"):
            inputs, targets = inputs.to(device), targets.to(device)

            # Converting inputs to spectrograms if needed
            if inputs.dim() == 2:  # [batch, time]
                inputs = torch.stack(
                    [extract_features(waveform, device=device) for waveform in inputs]
                )

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Collecting statistics
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            probs = F.softmax(outputs, dim=1).cpu().numpy()
            batch_targets = targets.cpu().tolist()
            batch_preds = predicted.cpu().tolist()

            # Walking plain Python ints works for any batch size, including 1,
            # so no squeeze()/ndim special-casing is needed.
            for row, (label, pred) in enumerate(zip(batch_targets, batch_preds)):
                is_correct = label == pred
                class_correct[label] += float(is_correct)
                class_total[label] += 1
                test_samples.append({
                    'true_label': int(label),
                    'pred_label': int(pred),
                    'correct': bool(is_correct),
                    'probabilities': {lang: float(probs[row][j]) for j, lang in enumerate(languages)}
                })

            all_preds.extend(batch_preds)
            all_targets.extend(batch_targets)
            all_scores.extend(probs)

    test_time = time.time() - test_start_time

    # Calculating overall metrics; an empty loader must not divide by zero
    num_batches = len(test_loader)
    test_loss = test_loss / num_batches if num_batches else 0.0
    test_acc = 100. * correct / total if total else 0.0
    if total == 0:
        logger.warning("Test loader produced no samples; reported metrics are zero")

    logger.info(f"Test evaluation completed in {test_time:.2f}s")
    logger.info(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

    # Calculating per-class metrics
    per_class_metrics = {}
    for i in range(num_classes):
        if class_total[i] > 0:
            class_acc = 100 * class_correct[i] / class_total[i]
            logger.info(f'Accuracy of {languages[i]}: {class_acc:.2f}%')
            per_class_metrics[languages[i]] = {
                'accuracy': class_acc,
                'samples': class_total[i]
            }

    if test_dir is None:
        logger.info("Test evaluation completed. No save_dir given, results were not written to disk")
        return test_loss, test_acc, per_class_metrics

    # Creating and saving all test visualizations and metrics
    if total > 0:
        _save_test_artifacts(test_dir, languages, all_targets, all_preds,
                             all_scores, test_samples, logger)

    # summary report
    summary = {
        'test_acc': test_acc,
        'test_loss': test_loss,
        'per_class_metrics': per_class_metrics,
        'test_time': test_time,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(os.path.join(test_dir, 'test_summary.json'), 'w') as f:
        json.dump(summary, f, indent=4)

    logger.info(f"Test evaluation completed. Results saved to {test_dir}")

    return test_loss, test_acc, per_class_metrics


def _save_test_artifacts(test_dir, languages, all_targets, all_preds, all_scores, test_samples, logger):
    """Write the confusion matrix, classification metrics, ROC curves, confidence
    distributions, per-sample predictions and misclassification tables to ``test_dir``."""
    num_classes = len(languages)
    class_ids = list(range(num_classes))

    # Converting lists to numpy arrays for easier processing
    all_targets = np.array(all_targets)
    all_preds = np.array(all_preds)
    all_scores = np.array(all_scores)

    # 1. Confusion matrix. Passing labels keeps the matrix num_classes x num_classes
    # even when a class never appears in this split.
    cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
    row_sums = cm.sum(axis=1)[:, np.newaxis].astype('float')
    row_sums[row_sums == 0] = 1.0  # rows without samples stay all-zero instead of NaN
    cm_norm = cm.astype('float') / row_sums

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=languages, yticklabels=languages)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Normalized Confusion Matrix')
    plt.savefig(os.path.join(test_dir, 'confusion_matrix.png'))
    plt.close()

    # raw confusion matrix data
    cm_df = pd.DataFrame(cm, index=languages, columns=languages)
    cm_df.to_csv(os.path.join(test_dir, 'confusion_matrix.csv'))

    # 2. Classification report with precision, recall, F1 (one row per language)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_targets, all_preds, labels=class_ids, zero_division=0)
    metrics_df = pd.DataFrame({
        'Language': languages,
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1,
        'Support': support
    })
    metrics_df.to_csv(os.path.join(test_dir, 'classification_metrics.csv'), index=False)

    # 3. ROC curve and AUC, one-vs-rest per language
    plt.figure(figsize=(10, 8))
    for i, language in enumerate(languages):
        is_positive = (all_targets == i).astype(int)
        # A ROC curve needs both positive and negative samples for the class
        if is_positive.min() == is_positive.max():
            logger.warning(f"Skipping ROC curve for {language}: test set lacks positive or negative samples")
            continue
        fpr, tpr, _ = roc_curve(is_positive, all_scores[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'{language} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(test_dir, 'roc_curves.png'))
    plt.close()

    # 4. Distribution of prediction probabilities
    plt.figure(figsize=(12, 8))
    for i, language in enumerate(languages):
        # softmax scores of the true class for samples of this language
        class_indices = np.where(all_targets == i)[0]
        if len(class_indices) > 0:
            correct_scores = [all_scores[j][i] for j in class_indices if all_preds[j] == i]
            incorrect_scores = [all_scores[j][i] for j in class_indices if all_preds[j] != i]

            plt.subplot(1, len(languages), i+1)
            if correct_scores:
                sns.kdeplot(correct_scores, fill=True, label='Correct', alpha=0.7)
            if incorrect_scores:
                sns.kdeplot(incorrect_scores, fill=True, label='Incorrect', alpha=0.7)
            plt.title(f'{language}')
            plt.xlabel('Confidence')
            plt.ylabel('Density')
            plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(test_dir, 'prediction_distributions.png'))
    plt.close()

    # 5. Sample-level analysis
    samples_df = pd.DataFrame(test_samples)
    samples_df.to_csv(os.path.join(test_dir, 'sample_predictions.csv'), index=False)

    # 6. Misclassification analysis: count every (true, predicted) pair that disagrees
    misclass_counts = {}
    for true, pred in zip(all_targets, all_preds):
        if true != pred:
            pair = (languages[true], languages[pred])
            misclass_counts[pair] = misclass_counts.get(pair, 0) + 1

    misclass_df = pd.DataFrame([
        {'True': true, 'Predicted': pred, 'Count': count}
        for (true, pred), count in misclass_counts.items()
    ])

    if not misclass_df.empty:
        misclass_df = misclass_df.sort_values('Count', ascending=False)
        misclass_df.to_csv(os.path.join(test_dir, 'misclassification_analysis.csv'), index=False)

        # Top misclassifications plot
        plt.figure(figsize=(12, 8))
        top_n = min(10, len(misclass_df))
        top_misclass = misclass_df.head(top_n)

        sns.barplot(x='Count', y=top_misclass['True'] + ' → ' + top_misclass['Predicted'], data=top_misclass)
        plt.title(f'Top {top_n} Misclassifications')
        plt.xlabel('Count')
        plt.ylabel('Misclassification')
        plt.tight_layout()
        plt.savefig(os.path.join(test_dir, 'top_misclassifications.png'))
        plt.close()

In [13]:
# Function for splitting dataset into train, validation, and test sets
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Randomly partition a dataset into train, validation and test subsets.

    The global torch and numpy RNGs are re-seeded with ``seed`` before the
    split, so repeated calls with the same seed give the same partition.
    Train and validation sizes are truncated to integers; the test subset
    receives whatever remains.

    Args:
        dataset: Dataset to split
        train_ratio: Ratio of training set
        val_ratio: Ratio of validation set
        test_ratio: Ratio of test set
        seed: Random seed for reproducibility

    Returns:
        train_set, val_set, test_set: Split datasets
    """
    ratio_total = train_ratio + val_ratio + test_ratio
    assert abs(ratio_total - 1.0) < 1e-5, "Ratios must sum to 1"

    # Seed both RNGs so the shuffle below is reproducible
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(seed)

    n_items = len(dataset)
    n_train = int(train_ratio * n_items)
    n_val = int(val_ratio * n_items)
    # The remainder goes to the test split so the three sizes always add up to n_items
    split_lengths = [n_train, n_val, n_items - n_train - n_val]

    train_set, val_set, test_set = torch.utils.data.random_split(dataset, split_lengths)
    return train_set, val_set, test_set


# Function for saving model as ONNX for deployment
def save_model_as_onnx(model, save_path, input_shape=(1, 1, 80, 188)):
    """
    Save PyTorch model as ONNX for deployment

    A random dummy input of ``input_shape`` is traced through the model. The
    batch dimension (axis 0) of both input and output is exported as dynamic.

    Args:
        model: PyTorch model
        save_path: Path to save ONNX model
        input_shape: Input shape for ONNX export
    """
    # The dummy input has to live on the same device as the model's weights;
    # a CPU tensor fed to a model that was trained on CUDA fails during export.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: nothing to match, stay on CPU
        model_device = torch.device('cpu')

    # Creating dummy input
    dummy_input = torch.randn(input_shape, device=model_device, requires_grad=True)

    # Exporting model
    torch.onnx.export(
        model,                  # model being run
        dummy_input,            # model input
        save_path,              # where to save the model
        export_params=True,     # storing the trained parameter weights inside the model file
        opset_version=12,       # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],  # the model's input names
        output_names=['output'],  # the model's output names
        dynamic_axes={
            'input': {0: 'batch_size'},  # variable length axes
            'output': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX format at {save_path}")


In [14]:
# Sanity check: run feature extraction on a random dummy clip and print the
# resulting shape before wiring it into training.
waveform = torch.randn(1, 16000)  # Example: 1-second audio at 16kHz
# Extract features with 80 mel bins on CPU
features = extract_features(waveform, sample_rate=16000, n_mels=80, device='cpu')
# The recorded run of this cell printed torch.Size([1, 80, 101])
print(features.shape)


torch.Size([1, 80, 101])


In [15]:
def main():
    """
    Run the full language-identification experiment end to end.

    Steps: check the data directories, create a timestamped results directory
    on Drive, save the configuration, load and split the dataset, build the
    model, optionally resume from a checkpoint, train, evaluate on the test
    split, export to ONNX and zip the results directory.

    Returns:
        (history, test_acc, per_class_metrics) on success, or (None, 0, None)
        when the dataset turns out to be empty.
    """
    # Only needed here, for capturing the printed torchsummary report
    import contextlib
    import io

    # Defining paths
    DRIVE_PATH = '/content/drive/MyDrive'

    # Use the correct path where the audio data was extracted
    EXTRACTED_DATA_PATH = '/content/language_data'

    # Single source of truth for the classes, used for both the directory
    # check and the configuration
    languages = ['hindi', 'english', 'chinese']

    print(f"Using audio data path: {EXTRACTED_DATA_PATH}")
    print(f"Checking for language directories:")

    # Verify the language directories exist
    for lang in languages:
        lang_path = os.path.join(EXTRACTED_DATA_PATH, lang)
        if os.path.exists(lang_path):
            print(f"✓ Found {lang} directory with {len(os.listdir(lang_path))} items")
        else:
            print(f"✗ Missing {lang} directory!")

    SAVE_DIR = os.path.join(DRIVE_PATH, 'LID_Training', f'lid_model_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Setting random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)

    # Configuration
    config = {
        'data_dir': EXTRACTED_DATA_PATH,  # Use the correct extracted data path
        'languages': languages,  # Languages to classify
        'samples_per_lang': 15000,  # Maximum samples per language
        'segment_length': 3,  # Audio segment length in seconds
        'batch_size': 16,
        'num_epochs': 30,
        'learning_rate': 0.001,
        'weight_decay': 1e-5,
        'dropout_rate': 0.5,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'save_dir': SAVE_DIR,
        'apply_vad': True,  # Whether to apply voice activity detection
        'resume_from': None,  # Path to checkpoint file for resuming training, None to start fresh
    }

    # Creating save directory
    os.makedirs(config['save_dir'], exist_ok=True)

    # Saving configuration
    with open(os.path.join(config['save_dir'], 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    # Setting-up logging
    logger = setup_logging(config['save_dir'])
    logger.info(f"Starting Language Identification model training with config: {config}")
    logger.info(f"Using device: {config['device']}")

    # Loading dataset
    logger.info("Loading dataset...")
    full_dataset = LanguageAudioDataset(
        root_dir=config['data_dir'],
        languages=config['languages'],
        max_samples_per_lang=config['samples_per_lang'],
        segment_length=config['segment_length'],
        apply_vad=config['apply_vad']
    )

    # Check if dataset is empty
    if len(full_dataset) == 0:
        logger.error(f"Dataset is empty! Please check the path: {config['data_dir']}")
        logger.info("Available directories:")
        if os.path.exists(config['data_dir']):
            logger.info(str(os.listdir(config['data_dir'])))
        else:
            logger.info(f"Directory {config['data_dir']} does not exist!")
        return None, 0, None

    logger.info(f"Dataset loaded successfully with {len(full_dataset)} samples")

    # Splitting dataset (80/10/10) through the shared helper, which re-seeds
    # the RNGs so the split does not depend on what dataset loading consumed
    logger.info("Splitting dataset into train, validation, and test sets...")
    train_dataset, val_dataset, test_dataset = split_dataset(
        full_dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42
    )

    logger.info(f"Dataset split - Train: {len(train_dataset)}, "
               f"Validation: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Creating data loaders; only the training loader shuffles
    def make_loader(subset, shuffle):
        return DataLoader(
            subset,
            batch_size=config['batch_size'],
            shuffle=shuffle,
            num_workers=2,
            pin_memory=True
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    # Creating model
    logger.info("Creating model...")
    model = LIDModel(num_languages=len(config['languages']))
    model = model.to(config['device'])

    # Use torchsummary to print model architecture
    try:
        # For spectrograms, something like (1, 80, 101) for (channels, height, width)
        input_size = (1, 80, 101)  # Update with your actual spectrogram dimensions
        logger.info("Generating model summary with torchsummary:")

        # Run the summary once, capturing what it prints. redirect_stdout always
        # restores sys.stdout, even if summary() raises part-way through.
        summary_buffer = io.StringIO()
        with contextlib.redirect_stdout(summary_buffer):
            summary(model, input_size, device=config['device'])
        summary_text = summary_buffer.getvalue()
        print(summary_text, end='')

        # Save the model summary to a file
        summary_path = os.path.join(config['save_dir'], 'model_summary.txt')
        with open(summary_path, 'w') as f:
            f.write(summary_text)

        logger.info(f"Model summary saved to {summary_path}")
    except Exception as e:
        logger.warning(f"Failed to generate model summary with torchsummary: {e}")
        logger.info(f"Model architecture:\n{model}")

    # Defining loss function, optimizer, and scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        verbose=True
    )

    # Check if we need to resume training from a checkpoint
    start_epoch = 0
    resume_training_flag = False

    resume_path = config['resume_from']
    if resume_path is not None:
        if os.path.exists(resume_path):
            start_epoch, model, optimizer, scheduler = resume_training(
                resume_path, model, optimizer, scheduler, logger
            )
            resume_training_flag = True
        else:
            # Make a mistyped checkpoint path visible instead of silently retraining
            logger.warning(f"Checkpoint {resume_path} not found; starting training from scratch")

    # Training model
    logger.info("Starting model training...")
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=config['num_epochs'],
        device=config['device'],
        save_dir=config['save_dir'],
        logger=logger,
        start_epoch=start_epoch,
        resume_training=resume_training_flag
    )

    # Evaluating model on test set
    logger.info("Evaluating model on test set...")
    test_loss, test_acc, per_class_metrics = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
        device=config['device'],
        languages=config['languages'],
        save_dir=config['save_dir'],
        logger=logger
    )

    # Saving model in ONNX format for deployment
    logger.info("Saving model in ONNX format...")
    save_model_as_onnx(model, os.path.join(config['save_dir'], 'lid_model.onnx'))

    # Creating ZIP archive of the entire experiment. Entries are stored
    # relative to the parent of save_dir, so the archive unpacks into a single
    # results folder; DEFLATE keeps the archive smaller than the default
    # uncompressed mode.
    logger.info("Creating ZIP archive of experiment results...")
    zip_path = f"{config['save_dir']}.zip"
    archive_root = os.path.join(config['save_dir'], '..')
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(config['save_dir']):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, archive_root))

    logger.info(f"Experiment results archived to {zip_path}")
    logger.info("Language Identification model training and evaluation completed!")

    return history, test_acc, per_class_metrics

if __name__ == "__main__":
    main()

Using audio data path: /content/language_data
Checking for language directories:
✓ Found hindi directory with 1 items
✓ Found english directory with 1 items
✓ Found chinese directory with 1 items
Found 15000 files for hindi
Found 15000 files for english
Found 15000 files for chinese
Total samples: 45000
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1             [-1, 101, 144]          11,664
PositionalEncoding-2             [-1, 101, 144]               0
           Dropout-3             [-1, 101, 144]               0
            Linear-4             [-1, 101, 256]          37,120
           Dropout-5             [-1, 101, 256]               0
            Linear-6             [-1, 101, 144]          37,008
       FeedForward-7             [-1, 101, 144]               0
         LayerNorm-8             [-1, 101, 144]             288
            Linear-9             [-1, 101, 144]       

Train Epoch 1: 100%|██████████| 2250/2250 [04:40<00:00,  8.01it/s, loss=0.4884, acc=67.47%]
Val Epoch 1: 100%|██████████| 282/282 [00:30<00:00,  9.19it/s, loss=0.7795, acc=77.69%]
Train Epoch 2: 100%|██████████| 2250/2250 [04:14<00:00,  8.85it/s, loss=0.7226, acc=79.76%]
Val Epoch 2: 100%|██████████| 282/282 [00:30<00:00,  9.29it/s, loss=0.3900, acc=69.47%]
Train Epoch 3: 100%|██████████| 2250/2250 [04:12<00:00,  8.91it/s, loss=0.5515, acc=83.34%]
Val Epoch 3: 100%|██████████| 282/282 [00:29<00:00,  9.50it/s, loss=0.2319, acc=82.67%]
Train Epoch 4: 100%|██████████| 2250/2250 [04:14<00:00,  8.85it/s, loss=0.2112, acc=85.54%]
Val Epoch 4: 100%|██████████| 282/282 [00:31<00:00,  8.95it/s, loss=0.0213, acc=87.64%]
Train Epoch 5: 100%|██████████| 2250/2250 [04:18<00:00,  8.72it/s, loss=0.6636, acc=87.29%]
Val Epoch 5: 100%|██████████| 282/282 [00:30<00:00,  9.24it/s, loss=0.3380, acc=89.64%]
Train Epoch 6: 100%|██████████| 2250/2250 [04:15<00:00,  8.81it/s, loss=0.9272, acc=88.52%]
Val Epoc

NameError: name 'plot_training_history' is not defined