In [8]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, Resample

import glob
import os

# Define the Feature Extractor
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=11, stride=2, padding=5)
        self.bn1 = nn.BatchNorm1d(32)
        self.hswish = nn.Hardswish()
        self.se_module = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(32, 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(8, 32, kernel_size=1),
            nn.Hardsigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.hswish(x)
        x = self.se_module(x) * x
        return x

# Define the Encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.Hardswish()
        )
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(),
                nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU()
            ) for _ in range(3)
        ])
        self.out_layer = nn.Conv1d(64, 128, kernel_size=1)

    def forward(self, x):
        x = self.downsample(x)
        for conv in self.conv_layers:
            x = conv(x) + x  # Residual connection
        x = self.out_layer(x)
        return x

# Define the CTC Projector
class CTCProjector(nn.Module):
    def __init__(self, num_classes):
        super(CTCProjector, self).__init__()
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.permute(2, 0, 1)  # (batch, channels, time) -> (time, batch, channels)
        x = self.fc(x)
        return x

# Define the ASR Model
class ASRModel(nn.Module):
    def __init__(self, num_classes):
        super(ASRModel, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.encoder = Encoder()
        self.ctc_projector = CTCProjector(num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.encoder(x)
        x = self.ctc_projector(x)
        return x

# Dataset and DataLoader
class LibriSpeechDataset(Dataset):
    def __init__(self, root_dir, sample_rate=16000, n_mels=128):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.file_paths = []
        self.labels = {}
        self.resample = Resample(orig_freq=sample_rate, new_freq=sample_rate)
        self.mel_spectrogram = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        self.load_dataset()

    def load_dataset(self):
        for transcript_path in glob.glob(os.path.join(self.root_dir, '**', '*.trans.txt'), recursive=True):
            with open(transcript_path, 'r') as f:
                for line in f:
                    parts = line.strip().split(' ', 1)
                    file_id = parts[0]
                    transcript = parts[1]
                    audio_path = os.path.join(os.path.dirname(transcript_path), file_id + '.flac')
                    if os.path.exists(audio_path):
                        self.file_paths.append(audio_path)
                        self.labels[audio_path] = transcript
                    else:
                        print(f"Audio file {audio_path} not found, skipping.")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        try:
            waveform, sample_rate = torchaudio.load(file_path)
            if sample_rate != self.sample_rate:
                waveform = self.resample(waveform)
            mel_spec = self.mel_spectrogram(waveform).squeeze(0)
            label = self.labels[file_path]
            label = self.text_to_labels(label)
            return mel_spec, label
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")
            return None, None

    def text_to_labels(self, text):
        return torch.tensor([ord(char) - ord('a') + 1 for char in text.lower() if char.isalpha()])

# Define the collate function
def collate_fn(batch):
    batch = [item for item in batch if item[0] is not None]
    if len(batch) == 0:
        return None, None, None, None
    
    mel_specs, labels = zip(*batch)
    
    # Determine the maximum length of mel spectrograms
    max_length = max(mel_spec.size(1) for mel_spec in mel_specs)
    
    # Truncate or pad mel spectrograms to the maximum length
    padded_mel_specs = []
    for mel_spec in mel_specs:
        if (mel_spec.size(1) > max_length):
            mel_spec = mel_spec[:, :max_length]
        else:
            pad_size = max_length - mel_spec.size(1)
            mel_spec = F.pad(mel_spec, (0, pad_size), "constant", 0)
        padded_mel_specs.append(mel_spec)
    
    mel_specs = torch.stack(padded_mel_specs, dim=0)
    
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.int32)
    labels = torch.cat(labels)
    
    # Compute the actual lengths of the mel spectrograms after padding
    mel_spec_lengths = torch.tensor([mel_spec.size(1) for mel_spec in padded_mel_specs], dtype=torch.int32)
    
    return mel_specs, labels, mel_spec_lengths, label_lengths

# DataLoader
train_dataset = LibriSpeechDataset(root_dir="../audio_datasets/train/LibriSpeech/")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

val_dataset = LibriSpeechDataset(root_dir="../audio_datasets/val/LibriSpeech/")
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Training and evaluation functions with debugging
# Training function
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    for mel_specs, labels, mel_spec_lengths, label_lengths in train_loader:
        mel_specs, labels = mel_specs.to(device), labels.to(device)
        mel_spec_lengths, label_lengths = mel_spec_lengths.to(device), label_lengths.to(device)
        
        optimizer.zero_grad()
        outputs = model(mel_specs)
        outputs = outputs.log_softmax(2)
        outputs = outputs.permute(1, 0, 2)  # (batch, time, class) -> (time, batch, class)
        
        # Calculate the length of the sequence over time for each sample in the batch
        input_lengths = mel_spec_lengths // 4  # Assuming downsampling factor of 4
        
        # Debugging prints
        print(f'outputs.shape: {outputs.shape}')
        print(f'labels.shape: {labels.shape}')
        print(f'input_lengths: {input_lengths}')
        print(f'label_lengths: {label_lengths}')
        
        loss = criterion(outputs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(train_loader)


# Evaluation function
def evaluate(model, val_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            if batch[0] is None:
                continue  # Skip batches where all items are invalid
            
            mel_specs, labels, mel_spec_lengths, label_lengths = batch
            mel_specs, labels = mel_specs.to(device), labels.to(device)
            mel_spec_lengths, label_lengths = mel_spec_lengths.to(device), label_lengths.to(device)
            
            outputs = model(mel_specs)
            outputs = outputs.log_softmax(2)
            outputs = outputs.permute(1, 0, 2)  # (batch, time, class) -> (time, batch, class)
            
            # Calculate the length of the sequence over time for each sample in the batch
            input_lengths = mel_spec_lengths
            
            # Debugging prints
            print(f'outputs.shape: {outputs.shape}')
            print(f'labels.shape: {labels.shape}')
            print(f'input_lengths: {input_lengths}')
            print(f'label_lengths: {label_lengths}')

            loss = criterion(outputs, labels, input_lengths, label_lengths)
            epoch_loss += loss.item()
    return epoch_loss / len(val_loader)

# Model Configuration
num_classes = len(" abcdefghijklmnopqrstuvwxyz'") + 1  # Adjust based on the dataset used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ASRModel(num_classes).to(device)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')


outputs.shape: torch.Size([32, 331, 29])
labels.shape: torch.Size([4745])
input_lengths: tensor([330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330,
        330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330, 330,
        330, 330, 330, 330], device='cuda:0', dtype=torch.int32)
label_lengths: tensor([190, 142,  31, 152, 124, 151, 171, 179, 190, 189,  57, 174, 165, 178,
        148, 144, 187, 134, 147,  62, 171, 131, 172, 218, 168, 190, 142, 105,
         75, 117, 150, 191], device='cuda:0', dtype=torch.int32)


RuntimeError: input_lengths must be of size batch_size