In [14]:
import numpy as np
import time, os
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cpu'

In [3]:
X_train = np.load('all/train.npy', encoding='bytes')
Y_train = np.load('all/character/train_labels.npy')
X_dev = np.load('all/dev.npy', encoding='bytes')
Y_dev = np.load('all/character/dev_labels.npy')
vocab_map = np.load('all/character/vocab.npy')

### dataset class

In [4]:
class CharDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels
    def __getitem__(self,i):
        if self.labels != None:
            return torch.tensor(self.data[i]), torch.tensor(self.labels[i], dtype=torch.long)
        else:
            return torch.tensor(self.data[i])
    def __len__(self):
        return self.data.shape[0]


def collate(seq_list):
    """
    return: a batch sorted by decreasing order of length of sequences
    inputs: (L_padded, batch_size, 40)
    targets: list of targets
    """
    inputs,targets = zip(*seq_list)
    lens = [seq.shape[0] for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

#### Initial setting: listener layer:2, hidden: 128/direction;  speller layer: 1, hidden: 256
#### after that: listener layer:3, hidden: 256; speller: *, hidden: 512 

In [54]:
class ListenerModel(nn.Module):
    def __init__(self, hidden_size=256, key_size=128, value_size=128, embed_size=40, nlayers=3):
        super(ListenerModel, self).__init__()
        self.nlayers = nlayers
        self.hidden_size = hidden_size
        
        # todo: add conv layer
        
        self.rnns = nn.ModuleList([
            nn.LSTM(embed_size if i==0 else hidden_size*4, hidden_size, num_layers=1, bidirectional=True) for i in range(nlayers)
        ])
        
        self.fc1 = nn.Linear(hidden_size*4, key_size)
        self.fc2 = nn.Linear(key_size, value_size)
        
    def forward(self, inputs):
        """
        key: (L, N, key_size)
        value: (L, N, value_size)
        hidden: hidden state of last layer at t = L, (h_n, c_n); h_n, (num_direction,bs,hidden)
        """
        batch_size = len(inputs)
        lens = [len(s) for s in inputs] # actual lens of all sequences (already sorted) 
        # todo: add conv layer
        padded_input = rnn.pad_sequence(inputs).to(DEVICE)  # (L_padded, N, 40)
        out_packed = rnn.pack_padded_sequence(padded_input, lens)
        for l in self.rnns:
            out_packed, hidden = l(out_packed)
            out_padded, _ = rnn.pad_packed_sequence(out_packed)
            seq_len, batch_size, dim = out_padded.size()
            if seq_len % 2 == 1:
                out_padded = out_padded[:seq_len-1,:,:]
                seq_len, batch_size, dim = out_padded.size()
                lens[:] = [i-1 for i in lens]
            print(out_padded.size())
            out_padded = out_padded.permute(1,0,2).contiguous().view(batch_size, seq_len // 2, dim*2).permute(1,0,2)
            lens[:] = [i // 2 for i in lens]
            out_packed = rnn.pack_padded_sequence(out_padded, lens)   # (L_padded, N, hidden_size)
        
        key = self.fc1(out_padded)
        value = self.fc2(key)
        return key, value, hidden

In [115]:
class SpellerModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=512, embed_size=256, key_size=128, nlayers=2):
        super(SpellerModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.nlayers = nlayers
        
        self.embed = nn.Embedding(vocab_size, embed_size)
        
        # projection os state s
        self.fc1 = nn.Linear(hidden_size, key_size)
        
        # 1st layer, input: cat(y_{i-1},context_{i-1}); h_0: s_{i-1}
        self.rnns = nn.ModuleList([
            nn.LSTMCell(embed_size+key_size if layer==0 else hidden_size, hidden_size) for layer in range(nlayers)
        ])
        
        self.scoring = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs, key, value, hidden_init):
        """
        inputs: (L2_padded, batch_size), L2_padded = padded transcript length 
        key: (L_padded, bs, key_size)
        value: (L_padded, bs, value_size)
        query_init: (bs, hidden)
        context: (batch_size, context_size)
        outs: (L2_padded, bs, vocab_size)
        hiddens: a list of (h_n, c_n), n = L2_padded
        """
        key = key.permute(1, 2, 0)
        value = value.permute(1, 0, 2)
        print(inputs.size())
        
        embed = self.embed(inputs)   # (L2_padded, batch_size, embed_size)
        hiddens = [None] * self.nlayers
        for (i,h) in enumerate(hidden_init):
            hiddens[i] = h
        outs = []
        
        for y in embed: #(N, embed_size)
            # create context
            s,_ = hiddens[0]
            query = self.fc1(s).unsqueeze(1)   # (N, 1, key_size)
            # create context
            energy = torch.bmm(query, key)  #(N, 1, L_padded)
            attention = F.log_softmax(energy, 2)
            context = torch.bmm(attention, value).squeeze()  #(N, value_size)
            inp = torch.cat((y, context), 1)    # (N, embed+value)
            h = inp
            for (l, rnn) in enumerate(self.rnns):
                h1, c1 = rnn(h, hiddens[l])
                hiddens[l] = (h1,c1)
                h = h1

            outs.append(self.scoring(h))
        
        outs = torch.stack(outs, dim=0)  #(L2_padded, N, vocab_size)
        return outs, hiddens
        
        

### Trainer class

In [183]:
class AttentionTrainer:
    def __init__ (self, listener, speller, train_loader, val_loader, max_epochs=1, run_id='exp'):
        self.listener = listener.to(DEVICE)
        self.speller = speller.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id
        
        self.optimizer1 = torch.optim.ASGD(self.listener.parameters(), lr=0.01, weight_decay=0.02)
        self.optimizer2 = torch.optim.Adam(self.speller.parameters(), lr=0.001, weight_decay=0.0001)
        self.criterion = nn.CrossEntropyLoss(ignore_index=len(vocab_map)).to(DEVICE)
    
    def train(self):
        self.listener.train()
        self.speller.train()
        self.epochs += 1
        epoch_loss = 0

        for batch_num, (inputs, targets) in enumerate(self.train_loader):
            epoch_loss += self.train_batch(inputs, targets)
       
        epoch_loss = epoch_loss / (batch_num + 1)
   
        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f'
                      % (self.epochs, self.max_epochs, epoch_loss))
        self.train_losses.append(epoch_loss)
    
    def train_batch(self, inputs, targets):
        # listener generates key/value, s_{-1},_ = hidden
        key, value, hidden = self.listener(inputs)
        a,bs,hs = hidden[0].size()
        h_1 = hidden[0].permute(1,0,2).contiguous().view(-1,a*hs)
        c_1 = hidden[1].permute(1,0,2).contiguous().view(-1,a*hs)
        hidden_init = [(h_1, c_1)]
        print('hidden init size',hidden_init[0][0].size())
        
        targets = rnn.pad_sequence(targets, padding_value=len(vocab_map)).to(DEVICE) # (L2_padded, bs)
        outs,_ = self.speller(inputs=targets[:-1,:], key=key, value=value, hidden_init=hidden_init)
        print('outs shape',outs.size())
        print('target shape', targets.size())
        
        loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:].view(-1))      
        self.optimizer1.zero_grad()
        self.optimizer2.zero_grad()
        loss.backward()
        self.optimizer1.step()
        self.optimizer2.step()
        
        return loss.item()
    
    def validate(self):
        self.listener.eval()
        self.speller.eval()
        val_loss = 0
        ls = 0
        dev_len = len(self.val_loader.dataset)
        preds = []
        trues = []
        with torch.no_grad():
            for batch_num, (inputs, targets) in enumerate(self.val_loader):
                key, value, hidden = self.listener(inputs)
                a,bs,hs = hidden[0].size()
                h_1 = hidden[0].permute(1,0,2).contiguous().view(-1,a*hs)
                c_1 = hidden[1].permute(1,0,2).contiguous().view(-1,a*hs)
                hidden_init = [(h_1, c_1)]
                
                for t in targets:
                    trues.append(t.cpu().numpy())
                    
                targets = rnn.pad_sequence(targets, padding_value=len(vocab_map)).to(DEVICE) # (L2_padded, bs)

                outs,_ = self.speller(inputs=targets[:-1,:], key=key, value=value, hidden_init=hidden_init)
            
                loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:].view(-1))
                val_loss += loss.item()
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded, vocab_size)
                print('probs0: ',probs[0])
                print(torch.max(probs[0],dim=1))
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
                print(pred.size())
                for p in pred:
                    preds.append(p.cpu().numpy())
                
            for i in range(dev_len):
                pred_i = preds[i]
                true_i = trues[i][1:-1]
#                 print(pred_i)
                if 0 in pred_i:
                    pred_i = pred_i[:pred_i.index(0)]

                pred = "".join(vocab_map[o] for o in pred_i)
                true = "".join(vocab_map[l] for l in true_i)
                print('pred:', pred)
                print('true:', true)

                ls += L.distance(pred, true)
                
            return val_loss, ls
    
    def save_model(self):
        path1 = os.path.join('experiments', self.run_id, 'listenr-{}.pkl'.format(self.epochs))
        path2 = os.path.join('experiments', self.run_id, 'speller-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.listener.state_dict()}, path1)
        torch.save({'state_dict': self.speller.state_dict()}, path2)
        

In [174]:
BATCH_SIZE = 2
NUM_EPOCHS = 1

In [175]:
trainset = CharDataset(X_train[:4], Y_train[:4])
devset = CharDataset(X_dev[:2], Y_dev[:2])
train_loader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate)
dev_loader = DataLoader(devset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate)

In [15]:
run_id = str(int(time.time()))
if not os.path.exists('./experiments'):
    os.mkdir('./experiments')
os.mkdir('./experiments/%s' % run_id)
print("Saving models, predictions, and generated words to ./experiments/%s" % run_id)

Saving models, predictions, and generated words to ./experiments/1542142062


In [184]:
listener = ListenerModel()
speller = SpellerModel(vocab_size=len(vocab_map)+1)
trainer = AttentionTrainer(listener=listener, speller=speller, 
                           train_loader=train_loader, val_loader=dev_loader, 
                           max_epochs=NUM_EPOCHS, run_id=run_id)


In [185]:
best_dist = 1e30

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    trainer.train()

    val_loss, dist = trainer.validate()
    print('Val loss: {}, Val Levenshtein distance: {}'.format(val_loss, dist))
    if dist < best_dist:
        best_dist = dist
        print("Saving model for epoch {}, with Levenshtein distance: {}".format(epoch+1, best_dist))
        trainer.save_model()
    elapsed_time = time.time() - start_time
    print('Time elapsed: ', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

  


torch.Size([476, 2, 512])
torch.Size([238, 2, 512])
torch.Size([118, 2, 512])
hidden init size torch.Size([2, 512])
torch.Size([78, 2])
outs shape torch.Size([78, 2, 33])
target shape torch.Size([79, 2])
torch.Size([524, 2, 512])
torch.Size([262, 2, 512])
torch.Size([130, 2, 512])
hidden init size torch.Size([2, 512])
torch.Size([79, 2])
outs shape torch.Size([79, 2, 33])
target shape torch.Size([80, 2])
[TRAIN]  Epoch [1/1]   Loss: 3.3841
torch.Size([466, 2, 512])
torch.Size([232, 2, 512])
torch.Size([116, 2, 512])
torch.Size([75, 2])
probs0:  tensor([[0.0272, 0.0347, 0.0285,  ..., 0.0278, 0.0265, 0.0270],
        [0.0233, 0.0385, 0.0248,  ..., 0.0239, 0.0220, 0.0231],
        [0.0204, 0.0407, 0.0221,  ..., 0.0209, 0.0187, 0.0203],
        ...,
        [0.0157, 0.0450, 0.0166,  ..., 0.0158, 0.0129, 0.0146],
        [0.0157, 0.0450, 0.0166,  ..., 0.0158, 0.0130, 0.0146],
        [0.0157, 0.0450, 0.0165,  ..., 0.0158, 0.0130, 0.0147]])
(tensor([0.0521, 0.0851, 0.1179, 0.1425, 0.1604, 0.

In [182]:
print(vocab_map[30])

.
