In [29]:
import torch
import torchaudio
import torch.nn as nn
from IPython.display import Audio
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from timit_utils import *

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:

class FrameClassifier(nn.Module):
    """
    Frame-wise phone classifier for the Timit dataset using RNNs with 29 phonemes
    """
    def __init__(self, num_phonemes=40, hidden_size=64):
        super(FrameClassifier, self).__init__()
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_phonemes)
        self.softmax = nn.Softmax(dim=1)
        self.num_phonemes = num_phonemes
    
    def forward(self, x):
        # Assuming x is of shape [batch_size, seq_len, input_size]
        out, _ = self.rnn(x)
        # Take the output corresponding to the last input of each sequence
        out = out[range(out.shape[0]), -1, :] # [B, T, H] -> [B, H]
        out = self.fc(out)
        return out

In [58]:
# training loop
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
dataset = Timit('timit')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
model = FrameClassifier()
optimizer = Adam(model.parameters())
criterion = CrossEntropyLoss()
epochs = 1
num_steps = len(dataloader)*epochs
epoch=0
with tqdm(range(num_steps)) as pbar:
    for step in pbar:
        data_path, wav, transcript, x, y = next(iter(dataloader))
        x = x.squeeze(0)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output.squeeze(0), y.squeeze(0))
        loss.backward()
        optimizer.step()
        # Report
        if step % 5 ==0 :
            loss = loss.detach().cpu()
            pbar.set_description(f"epoch={epoch}, step={step}, loss={loss:.1f}")

        if (step+1) % len(dataloader) == 0:
            epoch += 1

  0%|          | 0/160 [00:00<?, ?it/s]

epoch=0, step=155, loss=3.1: 100%|██████████| 160/160 [01:51<00:00,  1.43it/s]


In [32]:
output.argmax(0)

tensor([10, 25,  9,  0,  0,  4,  0,  2, 16,  2, 25, 12, 25,  0,  9, 25, 25, 26,
        21, 25, 22, 25,  2, 25, 13,  2, 25, 24, 13, 11, 25, 25, 25,  3, 22, 19,
         2, 25, 25, 25])

In [59]:
print([dataset.idx_to_phonemes[i.item()] for i in output.argmax(0)])

['k', 'aa', 's', 'k', 'dh', 'k', 'aa', 'hh', 'er', 'k', 'n', 'er', 'k', 'sh', 'k', 'h#', 'ah', 'k', 'n', 'k', 'k', 'l', 'y', 'k', 'k', 'k', 'er', 'aa', 'ah', 'h#', 'k', 'k', 'aa', 'k', 'ih', 'k', 'k', 'k', 'sh', 'k']


In [52]:
y.squeeze(0)

tensor([ 6, 13, 31, 35,  3,  4,  8, 13,  5,  6,  4, 15, 20,  0, 10,  4,  6, 21,
        20,  6,  0,  8,  2,  6, 22, 39, 35,  6])

In [56]:
print([dataset.idx_to_phonemes[i.item()] for i in y.squeeze(0)])


['h#', 'd', 'uh', 'z', 'hh', 'ih', 'n', 'd', 'uw', 'h#', 'ih', 'dx', 'iy', 'aa', 'l', 'ih', 'h#', 'jh', 'iy', 'h#', 'aa', 'n', 'er', 'h#', 'k', 'aw', 'z', 'h#']
