In [1]:
import os
import unicodedata

import pandas as pd
import torch
import torch.optim as optim
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

In [2]:


# ---- Data locations ----
CSV_PATH = 'geo/train.csv'
DATA_DIR = 'geo/clips'

# ---- Audio front-end ----
SAMPLE_RATE = 16000  # Target rate in Hz; clips recorded at another rate are resampled to this
N_MELS = 80
HOP_LENGTH = 160  # 10 ms between frames at 16 kHz
WIN_LENGTH = 400  # 25 ms analysis window at 16 kHz

# ---- Character vocabulary ----
# Lowercase Esperanto letters (incl. ĉ ĝ ĥ ĵ ŝ ŭ) plus space get ids from 3 upward;
# ids 0-2 are reserved for the special tokens. Extend the string for extra symbols.
CHAR_MAP = {ch: idx for idx, ch in enumerate('abcĉdefgĝhĥijĵklmnoprsŝtuŭvz ', start=3)}
CHAR_MAP.update({'<PAD>': 0, '<SOS>': 1, '<EOS>': 2})
VOCAB_SIZE = len(CHAR_MAP)
INV_CHAR_MAP = {idx: ch for ch, idx in CHAR_MAP.items()}

class AudioDataset(Dataset):
    """Log-Mel features and character-id targets for the rows of a transcript CSV.

    The CSV must contain a ``file`` column (path to an audio clip) and a
    ``transcript`` column. Item ``idx`` is ``(mel_spec, target)``:

    * ``mel_spec``: float tensor ``(frames, N_MELS)`` of log-Mel energies,
      computed at ``SAMPLE_RATE`` (clips at other rates are resampled).
    * ``target``: long tensor ``[<SOS>, char ids..., <EOS>]``. Characters
      that are not in ``CHAR_MAP`` are dropped.

    Args:
        csv_path: Path to the CSV file.
        data_dir: Optional directory that ``file`` entries are joined onto
            (absolute entries are left unchanged by ``os.path.join``).
            ``None`` (default) uses the paths exactly as written in the CSV.
    """

    def __init__(self, csv_path, data_dir=None):
        self.df = pd.read_csv(csv_path)
        self.data_dir = data_dir
        # The transform is fixed (no learned state), so build it once instead of per item.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE, n_mels=N_MELS, hop_length=HOP_LENGTH, win_length=WIN_LENGTH, n_fft=WIN_LENGTH
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = row['file']
        if self.data_dir is not None:
            audio_path = os.path.join(self.data_dir, audio_path)
        # NFC so decomposed accents (e.g. 'c' + U+0302) match the precomposed keys in
        # CHAR_MAP instead of silently losing the combining mark in the filter below.
        transcript = unicodedata.normalize('NFC', row['transcript'].strip().lower())

        waveform, sr = torchaudio.load(audio_path)  # (channels, samples)
        # Down-mix multi-channel clips: with >1 channel, squeeze(0) below would be a
        # no-op and the transpose would produce a 3-D tensor instead of (frames, n_mels).
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

        mel_spec = torch.log(self.mel_transform(waveform) + 1e-9)  # log-Mel; epsilon avoids log(0)
        mel_spec = mel_spec.squeeze(0).transpose(0, 1)  # (1, n_mels, frames) -> (frames, n_mels)

        target = [CHAR_MAP['<SOS>']] + [CHAR_MAP[c] for c in transcript if c in CHAR_MAP] + [CHAR_MAP['<EOS>']]
        return mel_spec, torch.tensor(target, dtype=torch.long)

def collate_fn(batch):
    """Merge a list of ``(mel_spec, target)`` samples into padded batch tensors.

    Returns:
        ``(mels_padded, targets_padded, mel_lens, target_lens)``:
        features zero-padded along their first (time) axis, targets padded
        with the ``<PAD>`` id, and the unpadded length of every feature and
        target sequence.
    """
    features, labels = map(list, zip(*batch))
    feature_lengths = torch.tensor(list(map(len, features)))
    label_lengths = torch.tensor(list(map(len, labels)))
    return (
        pad_sequence(features, batch_first=True, padding_value=0),
        pad_sequence(labels, batch_first=True, padding_value=CHAR_MAP['<PAD>']),
        feature_lengths,
        label_lengths,
    )

# Model components
class SinusoidalPosEmb(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of positions.

    Each element ``p`` of the input is mapped to
    ``[sin(p * f_0), ..., sin(p * f_{h-1}), cos(p * f_0), ..., cos(p * f_{h-1})]``
    with ``h = dim // 2`` frequencies decaying geometrically from 1 to 1/10000.

    Args:
        dim: Size of the output embedding. Odd sizes get a trailing zero
            feature so the output always has exactly ``dim`` features.
    """

    def __init__(self, dim=256):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed positions.

        Args:
            x: Tensor of positions (any shape, integer or float).

        Returns:
            Float tensor of shape ``(*x.shape, dim)``.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero (and NaN output) when dim < 4.
        log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[..., None] * freqs  # (*x.shape, half_dim)
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2:
            # sin/cos pairs only fill 2 * half_dim features; pad the last one with zeros.
            emb = nn.functional.pad(emb, (0, 1))
        return emb

class AudioEncoder(nn.Module):
    """Convolutional subsampling followed by a Transformer encoder.

    Two stride-2 ``Conv1d`` layers shrink the time axis (to ``ceil(len / 4)``
    with the default settings) and lift ``input_dim`` features to
    ``hidden_dim``. Sinusoidal position embeddings are then added, and the
    sequence runs through ``num_layers`` ``TransformerEncoderLayer``s. A
    key-padding mask is derived from the per-utterance input lengths.
    """

    def __init__(self, input_dim=80, hidden_dim=256, num_layers=6, num_heads=4):
        super().__init__()
        self.conv_sub = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dim_feedforward=1024, dropout=0.1, batch_first=True)
        self.transformer_enc = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pos_emb = SinusoidalPosEmb(hidden_dim)

    def _subsampled_lengths(self, lengths):
        """Map input frame counts to the frame counts produced by ``conv_sub``.

        Applies the ``Conv1d`` output-length formula for every conv layer, so it
        stays correct if kernel size, stride, padding or dilation are changed.
        """
        for layer in self.conv_sub:
            if isinstance(layer, nn.Conv1d):
                lengths = (
                    lengths
                    + 2 * layer.padding[0]
                    - layer.dilation[0] * (layer.kernel_size[0] - 1)
                    - 1
                ) // layer.stride[0] + 1
        return lengths

    def forward(self, x, lengths):
        """Encode a padded batch of feature frames.

        Args:
            x: Float tensor ``(batch, seq_len, input_dim)``.
            lengths: Integer tensor ``(batch,)`` with the number of valid
                (unpadded) frames per item; may live on any device.

        Returns:
            ``(enc_out, mask)``: ``enc_out`` is ``(batch, sub_len, hidden_dim)``,
            and ``mask`` is a bool ``(batch, sub_len)`` tensor that is True at
            padded positions.
        """
        x = self.conv_sub(x.transpose(1, 2))  # Conv1d expects (batch, channels, time)
        x = x.transpose(1, 2)  # back to (batch, sub_len, hidden_dim)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, sub_len), broadcasts over batch
        x = x + self.pos_emb(pos)
        # The convs keep ceil(len / 4) frames, not len // 4. The floor version masked
        # a real trailing frame and fully masked inputs shorter than 4 frames, which
        # can make the attention softmax yield NaN.
        sub_lengths = self._subsampled_lengths(lengths.to(x.device))
        mask = torch.arange(seq_len, device=x.device)[None, :] >= sub_lengths[:, None]
        x = self.transformer_enc(x, src_key_padding_mask=mask)
        return x, mask

class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Every projection maps ``hidden_dim -> hidden_dim``. Scores are multiplied
    by ``hidden_dim ** -0.5`` before the softmax over key positions.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.scale = hidden_dim ** -0.5

    def forward(self, query, keys, values, mask=None):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            query: ``(..., T_q, hidden_dim)``.
            keys, values: ``(..., T_k, hidden_dim)``.
            mask: Optional bool tensor, True where a key must be ignored. A new
                axis is inserted at dim 1 so a ``(batch, T_k)`` mask broadcasts
                over the query positions.

        Returns:
            Context vectors ``(..., T_q, hidden_dim)``.
        """
        projected_keys = self.key_proj(keys)
        logits = self.query_proj(query) @ projected_keys.transpose(-2, -1) * self.scale
        if mask is not None:
            logits = logits.masked_fill(mask.unsqueeze(1), -1e9)
        weights = logits.softmax(dim=-1)
        return weights @ self.value_proj(values)

class RecurrentDecoder(nn.Module):
    """LSTM decoder with single-head attention over the encoder output.

    The LSTM reads teacher-forced token embeddings (with sinusoidal positions
    added). Each LSTM output state then attends over the projected encoder
    memory, and ``[lstm_state; context]`` is mapped to vocabulary logits.
    """

    def __init__(self, embed_dim=256, hidden_dim=512, enc_dim=256, vocab_size=VOCAB_SIZE, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=CHAR_MAP['<PAD>'])
        # LSTM dropout only acts between stacked layers; passing it with a single
        # layer has no effect and just triggers a PyTorch warning.
        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True,
                           dropout=0.1 if num_layers > 1 else 0.0)
        # Project encoder features (enc_dim) to the decoder hidden size so they can be attended.
        self.memory_proj = nn.Linear(enc_dim, hidden_dim)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)  # input is [rnn_out; context]
        self.pos_emb = SinusoidalPosEmb(embed_dim)

    def forward(self, tgt, memory, tgt_lengths, memory_mask):
        """Compute next-token logits for a teacher-forced target batch.

        Args:
            tgt: Long tensor ``(batch, tgt_len)`` of input token ids.
            memory: Encoder output ``(batch, src_len, enc_dim)``.
            tgt_lengths: ``(batch,)`` number of real tokens per row; copied to
                CPU for packing.
            memory_mask: Bool ``(batch, src_len)``, True at padded encoder frames;
                forwarded to ``Attention``.

        Returns:
            Logits ``(batch, tgt_len, vocab_size)``.
        """
        tgt_seq_len = tgt.size(1)
        tgt_emb = self.embedding(tgt)
        pos = torch.arange(0, tgt_seq_len, device=tgt.device).unsqueeze(0).repeat(tgt.size(0), 1)
        tgt_emb = tgt_emb + self.pos_emb(pos)
        # Packing keeps the LSTM from running over padded steps.
        packed_tgt = pack_padded_sequence(tgt_emb, tgt_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.rnn(packed_tgt)
        # total_length restores the full tgt_len. Without it the output is only
        # max(tgt_lengths) steps long and would not line up with labels shaped like tgt.
        rnn_out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=tgt_seq_len)

        memory_proj = self.memory_proj(memory)
        # Luong-style: the LSTM outputs are the queries; projected memory is key and value.
        context = self.attention(rnn_out, memory_proj, memory_proj, memory_mask)

        combined = torch.cat((rnn_out, context), dim=-1)
        return self.fc(combined)

class ASRModel(nn.Module):
    """Attention-based encoder-decoder speech recognizer.

    ``AudioEncoder`` turns padded log-Mel features into a subsampled memory
    sequence plus a padding mask. ``RecurrentDecoder`` reads teacher-forced
    target tokens while attending over that memory and emits per-step logits.
    """

    def __init__(self):
        super().__init__()
        self.encoder = AudioEncoder()
        # Decoder defaults (enc_dim=256) match the encoder's default hidden_dim.
        self.decoder = RecurrentDecoder()

    def forward(self, src, tgt, src_lengths, tgt_lengths):
        """Return decoder logits for ``tgt`` conditioned on the audio ``src``."""
        memory, memory_mask = self.encoder(src, src_lengths)
        return self.decoder(tgt, memory, tgt_lengths, memory_mask)


In [3]:
# Build the data pipeline, model, optimizer and loss for training.
# Seed torch's global RNG first so weight initialisation and the DataLoader's
# shuffle order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

dataset = AudioDataset(CSV_PATH)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

model = ASRModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# <PAD> label positions carry no signal, so they are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=CHAR_MAP['<PAD>'])
print(model)

ASRModel(
  (encoder): AudioEncoder(
    (conv_sub): Sequential(
      (0): Conv1d(80, 256, kernel_size=(3,), stride=(2,), padding=(1,))
      (1): ReLU()
      (2): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,))
      (3): ReLU()
    )
    (transformer_enc): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=1024, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
 

In [None]:
num_epochs = 50  # Tune to dataset size / convergence
num_batches = len(dataloader)  # Constant across epochs; used for the mean batch loss
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for mels, targets, mel_lens, target_lens in dataloader:
        # Teacher forcing: the decoder reads tokens [0, L-1) and is scored on tokens [1, L).
        input_tgt, label = targets[:, :-1], targets[:, 1:]
        input_tgt_lens = target_lens - 1

        optimizer.zero_grad()
        logits = model(mels, input_tgt, mel_lens, input_tgt_lens)
        # Flatten (batch, tgt_len, vocab) and (batch, tgt_len) so every token is one CE row.
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), label.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / num_batches}')

torch.save(model.state_dict(), 'asr_model.pth')