In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
import numpy as np
import time
from torch.nn.utils.rnn import *
import numpy as np
import os
from ctcdecode import CTCBeamDecoder
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
from Levenshtein import *

In [2]:

class WSJ():
    """ Load the WSJ speech dataset
        
        Ensure WSJ_PATH is path to directory containing 
        all data files (.npy) provided on Kaggle.
        
        Example usage:
            loader = WSJ()
            trainX, trainY = loader.train
            assert(trainX.shape[0] == 24590)
            
    """
  
    def __init__(self):
        self.dev_set = None
        self.train_set = None
        self.test_set = None
  
    @property
    def dev(self):
        if self.dev_set is None:
            self.dev_set = load_raw(os.environ['WSJ_PATH'], 'dev')
        return self.dev_set

    @property
    def train(self):
        if self.train_set is None:
            self.train_set = load_raw(os.environ['WSJ_PATH'], 'train')
        return self.train_set
  
    @property
    def test(self):
        if self.test_set is None:
            self.test_set = (np.load(os.path.join(os.environ['WSJ_PATH'], 'test.npy'), encoding='bytes'), None)
        return self.test_set
    
def load_raw(path, name):
    return (
        np.load(os.path.join(path, '{}.npy'.format(name)), encoding='bytes'), 
        np.load(os.path.join(path, '{}_transcripts.npy'.format(name)), encoding='bytes')
    )
os.environ['WSJ_PATH'] = './'

In [3]:
loader = WSJ()
trainX, trainY = loader.train
testX,testY = loader.test
devX, devY = loader.dev

In [4]:
letter_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q',\
             'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '-', "'", '.', '_', '+', ' ','<sos>','<eos>']

In [5]:
class Dataset(Dataset):
    def __init__(self,trainX,trainY, train=True):
        self.trainY = trainY
        self.trainX = trainX
        
        self.Y = []
        if train:
            for i in range(len(self.trainY)):
                one = []
                for j in range(len(self.trainY[i])):
                    one.append(self.trainY[i][j].decode())
                self.Y.append(one)
            self.indexY = []
            for i in range(len(self.Y)):
                one = []
                one.append(letter_list.index('<sos>'))
                for j in range(len(self.Y[i])):
                    for m in range(len(self.Y[i][j])):
                        one.append(letter_list.index(self.Y[i][j][m]))
                    one.append(letter_list.index(' '))
                one.append(letter_list.index('<eos>'))
                self.indexY.append(one)
            
    def __len__(self):
        return len(self.indexY)
    
    def __getitem__(self,index):
        if not train:
            return self.trainX[index]
        return torch.from_numpy(self.trainX[index]),torch.tensor(self.indexY[index])
                                    
            

In [6]:
def collate_train(batch_data):
    inputs,targets = zip(*batch_data)
    uttlens = [len(seq) for seq in inputs]
    translens = [len(seq) for seq in targets]
    speech = rnn.pad_sequence(inputs)
    text = rnn.pad_sequence(targets)
    return speech.long(), text.long(), torch.LongTensor(uttlens), torch.LongTensor(translens)

def collate_test(batch_data):
    uttlens = [len(seq) for seq in batch_data]
    speech = rnn.pad_sequence(batch_data)
    return speech.long(), torch.LongTensor(uttlens)

In [7]:
Speech2Text_train_Dataset = Dataset(trainX, trainY)
Speech2Text_dev_Dataset = Dataset(devX,devY)

In [8]:
Speech2Text_test_Dataset = Dataset(testX,testY,False)

In [9]:
train_loader = DataLoader(Speech2Text_train_Dataset, batch_size=64, shuffle=True, collate_fn=collate_train)
dev_loader = DataLoader(Speech2Text_dev_Dataset, batch_size=64, shuffle=True, collate_fn=collate_train)

In [10]:
test_loader = DataLoader(Speech2Text_test_Dataset, batch_size=64,collate_fn=collate_test)

In [11]:
class Listener(nn.Module):
    def __init__(self, input_dim, hidden_dim, value_size,key_size,layers):
        super(Listener, self).__init__()
        self.rnn = nn.LSTM(input_size=input_dim,hidden_size=hidden_dim//2,num_layers=1,bidirectional=True)
        self.lstmList = nn.Sequential(nn.LSTM(input_size=input_dim,hidden_size=hidden_dim,num_layers=1,bidirectional=True),
                                     nn.LSTM(input_size=input_dim*4,hidden_size=hidden_dim,num_layers=1,bidirectional=True),
                                     nn.LSTM(input_size=input_dim*4,hidden_size=hidden_dim,num_layers=1,bidirectional=True))
        
        self.layers = layers
        self.key_network = nn.Linear(hidden_dim*4, value_size)
        self.value_network = nn.Linear(hidden_dim*4, key_size)
  
    def forward(self,x,lens):
      
        rnn_inp = x

        for l in range(self.layers):
            print(rnn_inp.type())
            
            outputs,_ = self.lstmList[l](rnn_inp)
            if len(outputs)%2 == 1:
                outputs = outputs[:-1,:,:]
            T,N,H= outputs.size()
            rnn_inp = outputs.transpose(0,1).reshape(N,T//2,H*2).transpose(0,1)
            
            
       
        keys = self.key_network(rnn_inp)
        value = self.value_network(rnn_inp)

        return keys, value

    

In [12]:
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
    
    def forward(self, key,value,query,lengths):
        energy = torch.bmm(key,query.unsqueeze(2)).squeeze(2) #(N,T)
        
        mask = torch.arange(context.size(1)).unsqueeze(0) >= (lengths//8).unsqueeze(1) #(N,T) boolean mask
        # Set attention logits at padding positions to negative infinity.
        energy.masked_fill_(mask, -1e9)
        
        attention = nn.functional.softmax(energy,dim = 1)
        
        context = torch.bmm(energy.unsqueeze(1),value).squezze(1)
        return context, attention


In [13]:
class Speller:
    def __init__(self, vocab_size, hidden_dim,value_size,key_size,isAttended = True):
        super(Speller, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
    
        self.lstm1 = nn.LSTMCell(input_size=hidden_dim+value_size, hidden_size=hidden_dim)
        self.lstm2 = nn.LSTMCell(input_size=hidden_dim, hidden_size=key_size)
        self.isAttended = isAttended
        if(isAttended):
            self.attention = Attention()
        self.character_prob = nn.Linear(key_size+value_size,vocab_size)

    def forward(self, key, values,lengths ,text=None, train=True):
        batch_size = key.shape[1]
        if(train):
            max_len =  text.shape[1]
            embeddings = self.embedding(text)
        else:
            max_len = 250
    
        predictions = []
        hidden_states = [None, None]
        prediction = torch.zeros(batch_size,1).to(DEVICE)
        for i in range(max_len):
            
            
            if(train):
                char_embed = embeddings[:,i,:]
            else:
                char_embed = self.embedding(prediction.argmax(dim=-1))
      

            if self.isAttended:
                context1, attention1 = self.attention(key.transpose(0,1),value.transpose(0,1), char_embed,lengths)
            
            inp = torch.cat([char_embed,context[i,:,:]], dim=1)
            hidden_states[0] = self.lstm1(inp,hidden_states[0])
    
            inp_2 = hidden_states[0][0]
            hidden_states[1] = self.lstm2(inp_2,hidden_states[1])
            output = hidden_states[1][0]
            if self.isAttended:
                context2, attention2 = self.attention(key.transpose(0,1),value.transpose(0,1),output,lengths)
            prediction = self.character_prob(torch.cat([output, context[i,:,:]], dim=1))
            predictions.append(prediction.unsqueeze(1))

        return torch.cat(predictions, dim=1)
    

In [14]:
class Seq2Seq(nn.Module):
    def __init__(self,input_dim,vocab_size,hidden_dim,value_size=128, key_size=128,Elayers = 3,isAttended=False):
        super(Seq2Seq,self).__init__()

        self.encoder = Listener(input_dim, hidden_dim,value_size,key_size,Elayers)
        self.decoder = Speller(vocab_size, hidden_dim,value_size,key_size)
        
    def forward(self,speech_input, speech_len, text_input=None,train=True):
        
        key, value = self.encoder(speech_input, speech_len)
        if(train):
            predictions = self.decoder(key, value, text_input)
        else:
            predictions = self.decoder(key, value, text=None, train=False)
        return predictions

In [15]:
def train(model,train_loader, num_epochs, criterion, optimizer):
    for epochs in range(num_epochs):
        loss_sum = 0
        since = time.time()
        for (batch_num, collate_output) in enumerate(train_loader):
            with torch.autograd.set_detect_anomaly(True):
        
                speech_input, text_input, speech_len, text_len = collate_output
                speech_input = speech_input.to(DEVICE)
                text_input = text_input.to(DEVICE)
                speech_len = speech_len.to(DEVICE)
                text_len = text_len.to(DEVICE)

                predictions = model(speech_input, speech_len ,text_input)
                mask = torch.zeros(text_input.size()).to(DEVICE)

                for length in text_len:
                    mask[:,:length] = 1
        
                mask = mask.view(-1).to(DEVICE)
        

                predictions = predictions.contiguous().view(-1, predictions.size(-1))
                text_input = text_input.contiguous().view(-1)

                loss = criterion(predictions, text_input)
                masked_loss = torch.sum(loss*mask)

                masked_loss.backward()

                torch.nn.utils.clip_grad_norm(model.parameters(), 2)
                optimizer.step()

                current_loss = float(masked_loss.item())/int(torch.sum(mask).item())

                if  batch_num % 25 == 1:
                    print('train_loss', current_loss)
                    
        
#         for (batch_num, collate_output) in enumerate(valid_loader):
#             with torch.autograd.set_detect_anomaly(True):
        
#                 speech_input, text_input, speech_len, text_len = collate_output
                    
        
            

In [16]:
model = Seq2Seq(40,len(letter_list),hidden_dim = 128)

In [17]:
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(reduce=None).to(DEVICE)
train(model,train_loader, 1, criterion, optimizer)

torch.cuda.LongTensor


RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'

In [None]:
import numpy as np
import matplotlib.pyplot as plt

plt.plot(loss_history)
plt.show()

for i in range(X.size(1)):
    fig, ax = plt.subplots(figsize=(X_lens[i] * 0.5, (Y_lens[i] -  1) * 0.5))
    ax.imshow(all_attentions[i, :Y_lens[i] - 1, :X_lens[i]].numpy())
    ax.set_xticks(np.arange(X_lens[i]))
    ax.set_yticks(np.arange(Y_lens[i] - 1))
    ax.set_xticklabels(list(data[i][0]))
    ax.set_yticklabels(data[i][1].split() + ['</s>'])
    ax.set_ylim(Y_lens[i]-1.5, -0.5)
    plt.show()