### Imports

In [1]:
import torch
import torch.nn as nn

from tqdm import tqdm

import numpy as np

from utils.cnn_gru import Decoder
from utils.dataset import Dataset

### Hyperparams

In [2]:
device = torch.device("cuda")

EPOCHS = 1
lr = 3e-4
weight_decay = 0

### Decoder initialization

In [5]:
decoder = Decoder(in_feature_dim=12, conv_kernel=2, conv_stride=2, hidden_dim=32, num_layers=3, vocab_size=43)
decoder = decoder.to(device)

# batch_size, temporal_dim, feature_dim
test_input = torch.randn((1, 479, 12)).to(device)

output = decoder(test_input)

output.shape

torch.Size([1, 239, 43])

### Dataloader initialization

In [6]:
torch.manual_seed(0)

dataset = Dataset("./data/cin_us_faet0/")

sample = next(iter(dataset))

sample[1]

Num samples: 460
Loading phoneme data...
Vocab length: 42


tensor([ 1,  1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  3, 12,  3, 13, 14,
         6,  5, 15, 16, 17, 14,  5, 18, 14,  3,  6, 19, 13, 20,  5, 13, 21,  5,
         6, 11, 22,  6,  5, 14, 11, 12, 19,  7, 10, 23, 24, 12,  2,  1])

In [7]:
train_set, val_set = torch.utils.data.random_split(dataset, [0.8, 0.2])

batch_size = 1

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)

### Optimizer + loss initialization

In [8]:
optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

# ctc blank label is 0
criterion = nn.CTCLoss()

### Training loop

In [9]:
num_steps = 20

for epoch in range(EPOCHS):
    # decoder train
    decoder.train()

    for ema_data, phone_sequences, seq_lengths in tqdm(train_loader):
        ema_data = ema_data.to(device)
        phone_sequences = phone_sequences.to(device)

        output = decoder(ema_data)

        # for ctc loss
        output = output.permute(1, 0, 2)

        loss = criterion(output, phone_sequences, torch.full((output.shape[1],), output.shape[0]), seq_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # decoder evaluation
    decoder.eval()

    train_losses = []
    for _, (ema_data, phone_sequences, seq_lengths) in enumerate(tqdm(train_loader)):
        ema_data = ema_data.to(device)
        phone_sequences = phone_sequences.to(device)

        output = decoder(ema_data)

        # for ctc loss
        output = output.permute(1, 0, 2)

        loss = criterion(output, phone_sequences, torch.full((output.shape[1],), output.shape[0]), seq_lengths)

        train_losses += [loss.item()]

        if _ == num_steps:
            print("Train loss: " + str(np.mean(np.array(train_losses))))
            break

    val_losses = []
    for ema_data, phone_sequences, seq_lengths in tqdm(val_loader):
        ema_data = ema_data.to(device)
        phone_sequences = phone_sequences.to(device)

        output = decoder(ema_data)

        # for ctc loss
        output = output.permute(1, 0, 2)

        loss = criterion(output, phone_sequences, torch.full((output.shape[1],), output.shape[0]), seq_lengths)

        val_losses += [loss.item()]
    print("Val loss: " + str(np.mean(np.array(val_losses))))

  0%|          | 0/368 [00:00<?, ?it/s]

100%|██████████| 368/368 [00:08<00:00, 45.08it/s]
  5%|▌         | 20/368 [00:00<00:00, 394.40it/s]


Train loss: 3.5007530621119907


100%|██████████| 92/92 [00:00<00:00, 367.50it/s]

Val loss: 3.4973837240882544



