In [1]:
import argparse
import os
import math
import random
import json
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from encoder import Encoder
from attention import Attention
from decoder import Decoder
from generator import Generator

from data_loader import SpeechDataset, Padding, ToTensor

Using TensorFlow backend.


In [2]:
# Read the label vocabulary (label string -> integer id) from disk and build
# the reverse lookup (integer id -> label string) used for decoding below.
with open('labels_dict.json', 'r') as f:
    labels = json.load(f)

id2label = dict((label_id, label) for label, label_id in labels.items())

In [3]:
# Fixed lengths handed to Padding for the signal and the transcript.
SIGNAL_SEQ_LEN = 1100
TXT_SEQ_LEN = 189
# One output class per entry of the label vocabulary.
OUTPUT_DIM = len(labels)
BATCH_SIZE = 12

# Audio settings forwarded to SpeechDataset
# (units presumably seconds / Hz -- confirm in data_loader).
audio_conf = dict(
    window='hamming',
    window_size=0.02,
    window_stride=0.01,
    sampling_rate=16000,
)

# Every sample goes through Padding using the lengths configured above.
pad_transform = transforms.Compose(
    [Padding(SIGNAL_SEQ_LEN, TXT_SEQ_LEN, 'labels_dict.json')]
)

val_dataset = SpeechDataset(
    'val_manifest.csv',
    'labels_dict.json',
    audio_conf,
    transform=pad_transform,
)

In [None]:
# Pin CUDA work to GPU 1, but only when that device actually exists.
# Calling torch.cuda.set_device(1) unconditionally raises on CPU-only
# machines and on hosts with a single GPU, which would stop a
# "Restart & Run All" before the model is even built.
CUDA_DEVICE_INDEX = 1
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    torch.cuda.set_device(CUDA_DEVICE_INDEX)

In [12]:
# Model hyper-parameters.
SIGNAL_FEATURE = 161   # size of each signal frame fed to the encoder
NUM_GRU = 4
ENC_HID_DIM = 256
DEC_HID_DIM = 256
DEC_EMB_DIM = 256
DROPOUT_RATE = 0.2

# The original cell computed a CUDA-aware device and then immediately
# overwrote it with 'cpu', leaving the first assignment dead. The CPU
# override is kept (same behaviour) but made explicit: set FORCE_CPU to
# False to use CUDA when it is available.
FORCE_CPU = True
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

encoder = Encoder(seq_len=SIGNAL_SEQ_LEN, input_size=SIGNAL_FEATURE,
                  enc_hid_dim=ENC_HID_DIM, num_gru=NUM_GRU,
                  dec_hid_dim=DEC_HID_DIM, dropout_rate=DROPOUT_RATE,
                  device=device, use_pooling=False)

attention = Attention(enc_hid_dim=ENC_HID_DIM, dec_hid_dim=DEC_HID_DIM)

decoder = Decoder(output_dim=OUTPUT_DIM, emb_dim=DEC_EMB_DIM,
                  enc_hid_dim=ENC_HID_DIM, dec_hid_dim=DEC_HID_DIM,
                  dropout_rate=DROPOUT_RATE, attention=attention)

# Build the full model on the chosen device and switch it to evaluation
# mode (eval() returns the module itself, so the calls can be chained).
model = Generator(encoder, decoder, device).to(device).eval()

In [13]:
# Pick one validation example. The dataset is indexed once: the original
# called val_dataset[2] twice, running __getitem__ (and whatever loading it
# does) twice for the same sample.
SAMPLE_IDX = 2
sample = val_dataset[SAMPLE_IDX]

signal = torch.from_numpy(sample['signal']).type(torch.FloatTensor).to(device)
# Swap the two axes of the signal; the shapes printed in the next cell are
# (1100, 161) for the signal and (1, 189) for the transcript ids.
signal = signal.permute(1, 0)
txt_ids = torch.from_numpy(sample['transcript']).type(torch.LongTensor).to(device)

In [14]:
# Sanity check: show the shapes of the prepared signal and transcript ids.
(signal.shape, txt_ids.shape)

(torch.Size([1100, 161]), torch.Size([1, 189]))

In [15]:
# Restore the trained weights. map_location=device remaps the stored tensors
# onto the device this notebook runs on; without it, a checkpoint saved from
# a GPU fails to load on a CPU-only session.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load('models/rsr_gan.pt', map_location=device)
model.load_state_dict(state_dict)
# Keep the model in evaluation mode; as the last expression this also
# displays the module summary.
model.eval()

Generator(
  (encoder): Encoder(
    (rnn_stack): ModuleList(
      (0): GRU(161, 256, batch_first=True, dropout=0.2, bidirectional=True)
      (1): GRU(256, 256, batch_first=True, dropout=0.2, bidirectional=True)
      (2): GRU(256, 256, batch_first=True, dropout=0.2, bidirectional=True)
      (3): GRU(256, 256, batch_first=True, dropout=0.2, bidirectional=True)
    )
    (batch_norm_stack): ModuleList(
      (0): BatchNorm1d(281600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(281600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(281600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm1d(281600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fc): Linear(in_features=512, out_features=256, bias=True)
    (pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Decoder(
    (attention): Att

In [16]:
# Run a single forward pass for inference.
# - torch.no_grad() skips building the autograd graph, which is not needed
#   here and only costs memory.
# - unsqueeze(0) adds the leading batch dimension without hard-coding the
#   1100 x 161 signal shape, so the cell keeps working if the configured
#   sequence length or feature size changes.
with torch.no_grad():
    _, _, output = model(signal.unsqueeze(0), txt_ids)

In [17]:
# Shape of the model output (displayed below as batch x steps x classes).
output.shape

torch.Size([1, 189, 32])

In [18]:
# Shape of the per-step index of the highest-scoring class (max over dim 2).
output.max(dim=2)[1].shape

torch.Size([1, 189])

In [19]:
# Decode the prediction: take the highest-scoring class id at every step of
# the first (only) batch element and map each id back to its label string.
pred_ids = output.max(2)[1][0].cpu().numpy()
out_txt = ''.join(id2label[label_id] for label_id in pred_ids)

out_txt

' heenenoueuennouenunounonunuunuenunenouunuunuunununununuuununuunuunuuununuuuunuunuuuuuuununuuunuunununuuuunuuununununuuuuununuuunununununuununuuunuunuunuuuuuunuunununununuuuuunununuunuuuuun'

In [20]:
# Decode the reference transcript the same way, for side-by-side comparison
# with the predicted text above.
true_ids = txt_ids[0].cpu().numpy()
true_txt = ''.join(id2label[label_id] for label_id in true_ids)

true_txt

'<sos>a hundred nuns stampeded the vatican<eos>padpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpadpad'