In [1]:
import numpy as np
import time, os
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [2]:
X_train = np.load('all/train.npy', encoding='bytes')
Y_train = np.load('all/character/train_labels.npy')
X_dev = np.load('all/dev.npy', encoding='bytes')
Y_dev = np.load('all/character/dev_labels.npy')
vocab_map = np.load('all/character/vocab.npy')

### dataset class

In [3]:
class CharDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels
    def __getitem__(self,i):
        if self.labels != None:
            return torch.tensor(self.data[i]), torch.tensor(self.labels[i], dtype=torch.long)
        else:
            return torch.tensor(self.data[i])
    def __len__(self):
        return self.data.shape[0]


def collate(seq_list):
    """
    return: a batch sorted by decreasing order of length of sequences
    inputs: (L_padded, batch_size, 40)
    targets: list of targets
    """
    inputs,targets = zip(*seq_list)
    lens = [seq.shape[0] for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

#### Initial setting: listener layer:2, hidden: 128/direction;  speller layer: 1, hidden: 256
#### after that: listener layer:3, hidden: 256; speller: *, hidden: 512 

In [4]:
class ListenerModel(nn.Module):
    def __init__(self, hidden_size=256, key_size=128, value_size=128, embed_size=40, nlayers=4): # todo: nlayers=4
        
        super(ListenerModel, self).__init__()
        self.nlayers = nlayers
        self.hidden_size = hidden_size
        
        # todo: add conv layer
        
        self.rnns = nn.ModuleList([
            nn.LSTM(embed_size if i==0 else hidden_size*4, hidden_size, num_layers=1, bidirectional=True) for i in range(nlayers)
        ])
        
        self.fc1 = nn.Linear(hidden_size*2, key_size)
        self.fc2 = nn.Linear(hidden_size*2, value_size)
        self.weight_init()
    
    def weight_init(self):
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
        self.fc1.bias.data.fill_(0.01)
        self.fc2.bias.data.fill_(0.01)
        
        
    def forward(self, inputs):
        """
        key: (L, N, key_size)
        value: (L, N, value_size)
        hidden: hidden state of last layer at t = L, (h_n, c_n); h_n, (num_direction,bs,hidden)
        """
        batch_size = len(inputs)
        lens = [len(s) for s in inputs] # actual lens of all sequences (already sorted) 
        # todo: add conv layer
        padded_input = rnn.pad_sequence(inputs).to(DEVICE)  # (L_padded, N, 40)      
        out_packed = rnn.pack_padded_sequence(padded_input, lens)
        for (i,l) in enumerate(self.rnns):
            out_packed, hidden = l(out_packed)
            out_padded, lens = rnn.pad_packed_sequence(out_packed)
            seq_len, batch_size, dim = out_padded.size()
            
            if i < self.nlayers-1:
                if seq_len % 2 == 1:
                    out_padded = out_padded[:seq_len-1,:,:]
                    seq_len, batch_size, dim = out_padded.size()
                out_padded = out_padded.permute(1,0,2).contiguous().view(batch_size, seq_len // 2, dim*2).permute(1,0,2)
                lens = [i // 2 for i in lens]
                out_packed = rnn.pack_padded_sequence(out_padded, lens)   # (L_padded, N, hidden_size)
        
        key = self.fc1(out_padded)
        value = self.fc2(out_padded)
        return key, value, hidden, lens

In [5]:
class LSTMCell_(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell_, self).__init__()
        self.lstmcell = nn.LSTMCell(input_size, hidden_size, bias)
        self.h_0 = nn.Parameter(torch.randn(1,hidden_size).to(DEVICE))
        self.c_0 = nn.Parameter(torch.randn(1,hidden_size).to(DEVICE))
        self.parameters = nn.ParameterList([self.h_0,self.c_0])
    
    def forward(self, inputs, t, hidden_states):
        if t == 0:
            hidden_states = (self.h_0.expand(inputs.size(0),-1), self.c_0.expand(inputs.size(0),-1))
        hidden_state = self.lstmcell(inputs, hidden_states)
        return hidden_state
        

In [6]:
class SpellerModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=512, embed_size=256, key_size=128, nlayers=2):
        super(SpellerModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.nlayers = nlayers
        self.vocab_size = vocab_size
        
        self.embed = nn.Embedding(vocab_size, embed_size)
        
        # projection os state s
        self.fc1 = nn.Linear(hidden_size, key_size)
        
        self.rnns = nn.ModuleList([
            LSTMCell_(embed_size+key_size if layer==0 else hidden_size, hidden_size) for layer in range(nlayers)
        ])
        

        self.scoring = nn.Linear(hidden_size+key_size, vocab_size)
    
    def forward(self, inputs, audio_lens, key, value, teacher=False, hidden_init=None):
        """
        inputs: (L2_padded, batch_size), L2_padded = padded transcript length 
        audio_lens: list of actual length of audio sequences
        key: (L_padded, bs, key_size)
        value: (L_padded, bs, value_size)
        query_init: (bs, hidden)
        context: (batch_size, context_size)
        outs: (L2_padded, bs, vocab_size)
        hiddens: a list of (h_n, c_n), n = L2_padded
        """
        L2_padded = inputs.size(0)     
        L_padded, N, _ = value.size()
        key = key.permute(1, 2, 0)        # -> (N,key_size, L_padded)
        value = value.permute(1, 0, 2)   # -> (N, L, value_size)
        
        embed = self.embed(inputs)   # (L2_padded, batch_size, embed_size)
        hiddens = [None] * self.nlayers
        outs = []
        
        # create mask for padded audio utterance
        mask = torch.ones(N, 1, L_padded).to(DEVICE) 
        for i in range(len(audio_lens)):
            mask[i,:,audio_lens[i]:] = 0
            
        # initial query
        s = self.rnns[-1].h_0.expand(N,-1)  # (N, key_size)
        query = self.fc1(s).unsqueeze(1)   # (N, 1, key_size)
        
        for i in range(embed.size(0)): 
            if (i!=0) and teacher:
                x = outs[-1]  # (N, vocab_size)
                x = F.softmax(x, dim=1) # (N, vocab_size)
                x = torch.argmax(x, dim=1) # (N, 1)
                y = self.embed(x)  # (N, embed_size)
            else:
                y = embed[i]    #(N, embed_size)
            # create context
            energy = torch.bmm(query, key)      #(N, 1, L_padded)
            attention = F.softmax(energy, 2)    #(N, 1, L_padded)
            attention = torch.mul(mask, attention)  
            # normalize
            attention = attention / (torch.sum(attention, dim=1)+1e-12).unsqueeze(1)
            context = torch.bmm(attention, value).squeeze(1)  #(N, value_size)
            
            inp = torch.cat((y, context), 1)    # (N, embed+value)
            h = inp
            for (l, rnn) in enumerate(self.rnns):
                h1, c1 = rnn(h, i, hiddens[l])
                hiddens[l] = (h1,c1)
                h = h1            
            outs.append(self.scoring(torch.cat((h,context),1)))
            query = self.fc1(h).unsqueeze(1)
        
        outs = torch.stack(outs, dim=0)  #(L2_padded, N, vocab_size)
        return outs, hiddens
        
        

### Trainer class

In [17]:
class AttentionTrainer:
    def __init__ (self, listener, speller, train_loader, val_loader, max_epochs=1, run_id='exp'):
        self.listener = listener.to(DEVICE)
        self.speller = speller.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id
        
        self.optimizer1 = torch.optim.Adam(self.speller.parameters(), lr=0.001, weight_decay=1e-5)
        self.optimizer2 = torch.optim.Adam(self.speller.parameters(), lr=0.001, weight_decay=1e-5)
        self.criterion = nn.CrossEntropyLoss(reduce='false').to(DEVICE)
        self.scheduler = ReduceLROnPlateau(self.optimizer2, factor = 0.1, patience = 3, mode = 'min')
    
    def train(self):
        self.listener.train()
        self.speller.train()
        self.epochs += 1
        epoch_loss = 0

        for batch_num, (inputs, targets) in enumerate(self.train_loader):
            epoch_loss += self.train_batch(inputs, targets)
       
        epoch_loss = epoch_loss / (batch_num + 1)
   
        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f'
                      % (self.epochs, self.max_epochs, epoch_loss))
        self.train_losses.append(epoch_loss)
    
    def train_batch(self, inputs, targets):
        # listener generates key/value, s_{-1},_ = hidden
        key, value, hidden, seq_lens = self.listener(inputs)
        
        trans_lens = [len(t) for t in targets]
        targets = rnn.pad_sequence(targets, padding_value=0).to(DEVICE) # (L2_padded, bs) 
        inp = targets[:-1,:]
        mask = torch.ones(inp.size()).to(DEVICE)  # (L2_padded, bs)
        for i in range(len(trans_lens)):
            mask[trans_lens[i]-1:,i] = 0
        
        teacher=False
#         if (np.random.randn() > 0.9):
#             teacher = True
        outs,_ = self.speller(inputs=inp, audio_lens=seq_lens, key=key, value=value, teacher=teacher)
        loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:,:].view(-1))    
        loss = torch.sum(torch.mul(mask.view(-1), loss))   
        
        self.optimizer1.zero_grad()
        self.optimizer2.zero_grad()
        loss.backward()
        self.optimizer1.step()
        self.optimizer2.step()
        
        return loss.item() / len(trans_lens)
    
    def validate(self):
        self.listener.eval()
        self.speller.eval()
        val_loss = 0
        ls = 0
        dev_len = len(self.val_loader.dataset)
        preds = []
        trues = []
        with torch.no_grad():
            for batch_num, (inputs, targets) in enumerate(self.val_loader):
                key, value, _, seq_lens = self.listener(inputs)

                for t in targets:
                    trues.append(t.cpu().numpy())
                    
                trans_lens = [len(t) for t in targets]
                targets = rnn.pad_sequence(targets, padding_value=0).to(DEVICE) # (L2_padded, bs) 
                inp = targets[:-1,:]
                mask = torch.ones(inp.size()).to(DEVICE)  # (L2_padded, bs)
                for i in range(len(trans_lens)):
                    mask[trans_lens[i]-1:,i] = 0
                outs,_ = self.speller(inputs=inp, audio_lens=seq_lens, key=key, value=value)
                loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:,:].view(-1))    
                loss = torch.sum(torch.mul(mask.view(-1), loss))   
                val_loss += (loss.item() / len(trans_lens)
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded, vocab_size)
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
                for p in pred:
                    preds.append(p.cpu().numpy())
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded-1, vocab_size)
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
                for p in pred:
                    preds.append(p.cpu().numpy())
                
            for i in range(dev_len):
                pred_i = preds[i]
                true_i = trues[i][1:-1]   # trues include <sos> and <eos>
                if 0 in pred_i:
                    pred_i = pred_i[:pred_i.tolist().index(0)]

                pred = "".join(vocab_map[o] for o in pred_i)
                true = "".join(vocab_map[l] for l in true_i)
                if i % 100 == 0:
                    print('pred:', pred)
                    print('true:', true)

                ls += L.distance(pred, true)
                
            return val_loss / (batch_num + 1), ls / dev_len
    
    def save_model(self):
        path1 = os.path.join('attention', self.run_id, 'listener-{}.pkl'.format(self.epochs))
        path2 = os.path.join('attention', self.run_id, 'speller-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.listener.state_dict()}, path1)
        torch.save({'state_dict': self.speller.state_dict()}, path2)
        

In [8]:
BATCH_SIZE = 8
NUM_EPOCHS = 3

In [9]:
trainset = CharDataset(X_train[:80], Y_train[:80])
devset = CharDataset(X_dev[:40], Y_dev[:40])
train_loader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate)
dev_loader = DataLoader(devset, batch_size=BATCH_SIZE, collate_fn = collate)

In [10]:
run_id = str(int(time.time()))
if not os.path.exists('./attention'):
    os.mkdir('./attention')
os.mkdir('./attention/%s' % run_id)
print("Saving models, predictions, and generated words to ./attention/%s" % run_id)

Saving models, predictions, and generated words to ./attention/1542253115


In [18]:
listener = ListenerModel()
speller = SpellerModel(vocab_size=len(vocab_map))
# l = torch.load('attention/1542253115/listener-5.pkl')
# listener.load_state_dict(l['state_dict'])
# s = torch.load('attention/1542253115/speller-5.pkl')
# speller.load_state_dict(s['state_dict'])
# checkpoint = torch.load('experiments/1542219735/speller-13.pkl')
# speller.load_state_dict(checkpoint['state_dict'])
trainer = AttentionTrainer(listener=listener, speller=speller, 
                           train_loader=train_loader, val_loader=dev_loader, 
                           max_epochs=NUM_EPOCHS, run_id=run_id)

In [12]:
best_dist = 1e30

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    trainer.train()

    val_loss, dist = trainer.validate()
    print('Val loss: {}, Val Levenshtein distance: {}'.format(val_loss, dist))
    if dist < best_dist:
        best_dist = dist
        print("Saving model for epoch {}, with Levenshtein distance: {}".format(epoch+1, best_dist))
        trainer.save_model()
    elapsed_time = time.time() - start_time
    print('Time elapsed: ', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

  


[TRAIN]  Epoch [1/5]   Loss: 5128.7017
Val loss: 4821.005181206598, Val Levenshtein distance: 73.00904159132007
Saving model for epoch 1, with Levenshtein distance: 73.00904159132007
Time elapsed:  00:05:02
[TRAIN]  Epoch [2/5]   Loss: 4888.3238
Val loss: 4871.135932074652, Val Levenshtein distance: 72.63200723327306
Saving model for epoch 2, with Levenshtein distance: 72.63200723327306
Time elapsed:  00:05:03
[TRAIN]  Epoch [3/5]   Loss: 4773.7500
Val loss: 4957.209526909723, Val Levenshtein distance: 72.91862567811935
Time elapsed:  00:05:02
[TRAIN]  Epoch [4/5]   Loss: 4689.9039
Val loss: 4992.079644097223, Val Levenshtein distance: 72.72784810126582
Time elapsed:  00:05:02
[TRAIN]  Epoch [5/5]   Loss: 4551.6683
Val loss: 5035.129069010417, Val Levenshtein distance: 73.17721518987342
Time elapsed:  00:05:02


In [19]:
trainer.validate()

  


pred: tnteady comma the srirmiceutical rarafacturers  rrsociateon aas betuisted tnsirmy sive yyphen yay yxerndivn oo the nompirtsoeriod
true: already comma the pharmaceutical manufacturers' association has requested a forty five hyphen day extension to the comment period comma saying it needs more time for consideration period
pred: tonsectiout punne ioa petg  af teare  ttetsirsng ttvm aonital mands 
true: connecticut joins the ranks of states sponsoring seed capital funds
pred: tor the first time in years the republicans also captured both houses of congress
true: at least one person was killed and more than two hundred and fifty people were injured in the blasts
pred: the satoestaam  hhadenional beaelopment oonhnicues uich as larus oroupsas oneeewe fhth thngumer  
true: the south carolina educational radio network has won national broadcasting awards
pred: tovl  arocesiakits ahre on tart on ers d to rvlow thph at estmrs ah tryticupate dn the parket sa manit ng the r soases ff t y doo

(5035.129069010417, 73.17721518987342)

In [14]:
# trainer.save_model()