In [1]:
import numpy as np
import time, os
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [2]:
X_train = np.load('all/train.npy', encoding='bytes')
Y_train = np.load('all/character/train_labels.npy')
X_dev = np.load('all/dev.npy', encoding='bytes')
Y_dev = np.load('all/character/dev_labels.npy')
vocab_map = np.load('all/character/vocab.npy')

### dataset class

In [3]:
class CharDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels
    def __getitem__(self,i):
        if self.labels != None:
            return torch.tensor(self.data[i]), torch.tensor(self.labels[i], dtype=torch.long)
        else:
            return torch.tensor(self.data[i])
    def __len__(self):
        return self.data.shape[0]


def collate(seq_list):
    """
    return: a batch sorted by decreasing order of length of sequences
    inputs: (L_padded, batch_size, 40)
    targets: list of targets
    """
    inputs,targets = zip(*seq_list)
    lens = [seq.shape[0] for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

#### Initial setting: listener layer:2, hidden: 128/direction;  speller layer: 1, hidden: 256
#### after that: listener layer:3, hidden: 256; speller: *, hidden: 512 

In [4]:
class ListenerModel(nn.Module):
    def __init__(self, hidden_size=256, key_size=128, value_size=128, embed_size=40, nlayers=3):
        super(ListenerModel, self).__init__()
        self.nlayers = nlayers
        self.hidden_size = hidden_size
        
        # todo: add conv layer
        
        self.rnns = nn.ModuleList([
            nn.LSTM(embed_size if i==0 else hidden_size*4, hidden_size, num_layers=1, bidirectional=True) for i in range(nlayers)
        ])
        
        self.fc1 = nn.Linear(hidden_size*4, key_size)
        self.fc2 = nn.Linear(key_size, value_size)
        self.weight_init()
    
    def weight_init(self):
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
        self.fc1.bias.data.fill_(0.01)
        self.fc2.bias.data.fill_(0.01)
        
        
    def forward(self, inputs):
        """
        key: (L, N, key_size)
        value: (L, N, value_size)
        hidden: hidden state of last layer at t = L, (h_n, c_n); h_n, (num_direction,bs,hidden)
        """
        batch_size = len(inputs)
        lens = [len(s) for s in inputs] # actual lens of all sequences (already sorted) 
        # todo: add conv layer
        padded_input = rnn.pad_sequence(inputs).to(DEVICE)  # (L_padded, N, 40)
        out_packed = rnn.pack_padded_sequence(padded_input, lens)
        for l in self.rnns:
            out_packed, hidden = l(out_packed)
            out_padded, _ = rnn.pad_packed_sequence(out_packed)
            seq_len, batch_size, dim = out_padded.size()
            if seq_len % 2 == 1:
                out_padded = out_padded[:seq_len-1,:,:]
                seq_len, batch_size, dim = out_padded.size()
                lens[:] = [i-1 for i in lens]
            out_padded = out_padded.permute(1,0,2).contiguous().view(batch_size, seq_len // 2, dim*2).permute(1,0,2)
            lens[:] = [i // 2 for i in lens]
            out_packed = rnn.pack_padded_sequence(out_padded, lens)   # (L_padded, N, hidden_size)
        
        key = self.fc1(out_padded)
        value = self.fc2(key)
        return key, value, hidden

In [9]:
class SpellerModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=512, embed_size=256, key_size=128, nlayers=2):
        super(SpellerModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.nlayers = nlayers
        
        self.embed = nn.Embedding(vocab_size, embed_size)
        
        # projection os state s
        self.fc1 = nn.Linear(hidden_size, key_size)
        
        # 1st layer, input: cat(y_{i-1},context_{i-1}); h_0: s_{i-1}
        self.rnns = nn.ModuleList([
            nn.LSTMCell(embed_size+key_size if layer==0 else hidden_size, hidden_size) for layer in range(nlayers)
        ])
        
        self.scoring = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs, key, value, hidden_init=None):
        """
        inputs: (L2_padded, batch_size), L2_padded = padded transcript length 
        key: (L_padded, bs, key_size)
        value: (L_padded, bs, value_size)
        query_init: (bs, hidden)
        context: (batch_size, context_size)
        outs: (L2_padded, bs, vocab_size)
        hiddens: a list of (h_n, c_n), n = L2_padded
        """
        key = key.permute(1, 2, 0)
        value = value.permute(1, 0, 2)
#         print(inputs.size())
        
        embed = self.embed(inputs)   # (L2_padded, batch_size, embed_size)
        hiddens = [None] * self.nlayers
        for (i,h) in enumerate(hidden_init):
            hiddens[i] = h
        outs = []
        
        for y in embed: #(N, embed_size)
            # create context
            s,_ = hiddens[0]
            query = self.fc1(s).unsqueeze(1)   # (N, 1, key_size)
            # create context
            energy = torch.bmm(query, key)  #(N, 1, L_padded)
            attention = F.log_softmax(energy, 2)
            context = torch.bmm(attention, value).squeeze()  #(N, value_size)
            
            inp = torch.cat((y, context), 1)    # (N, embed+value)
            h = inp
            for (l, rnn) in enumerate(self.rnns):
                h1, c1 = rnn(h, hiddens[l])
                hiddens[l] = (h1,c1)
                h = h1

            outs.append(self.scoring(h))
        
        outs = torch.stack(outs, dim=0)  #(L2_padded, N, vocab_size)
        return outs, hiddens
        
        

### Trainer class

In [44]:
class AttentionTrainer:
    def __init__ (self, listener, speller, train_loader, val_loader, max_epochs=1, run_id='exp'):
        self.listener = listener.to(DEVICE)
        self.speller = speller.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id
        
        self.optimizer1 = torch.optim.Adam(self.speller.parameters(), lr=0.0005, weight_decay=1e-5)
        self.optimizer2 = torch.optim.Adam(self.speller.parameters(), lr=0.0005, weight_decay=1e-5)
        self.criterion = nn.CrossEntropyLoss(reduce='false').to(DEVICE)
        self.scheduler = ReduceLROnPlateau(self.optimizer2, factor = 0.1, patience = 3, mode = 'min')
    
    def train(self):
        self.listener.train()
        self.speller.train()
        self.epochs += 1
        epoch_loss = 0

        for batch_num, (inputs, targets) in enumerate(self.train_loader):
            epoch_loss += self.train_batch(inputs, targets)
       
        epoch_loss = epoch_loss / (batch_num + 1)
   
        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f'
                      % (self.epochs, self.max_epochs, epoch_loss))
        self.train_losses.append(epoch_loss)
    
    def train_batch(self, inputs, targets):
        # listener generates key/value, s_{-1},_ = hidden
        key, value, hidden = self.listener(inputs)
        a,bs,hs = hidden[0].size()
        h_1 = hidden[0].permute(1,0,2).contiguous().view(-1,a*hs)
        c_1 = hidden[1].permute(1,0,2).contiguous().view(-1,a*hs)
        hidden_init = [(h_1, c_1)]
        
        lens = [len(t) for t in targets]
        targets = rnn.pad_sequence(targets, padding_value=0).to(DEVICE) # (L2_padded, bs) 
        inp = targets[:-1,:]
        mask = torch.ones(inp.size()).to(DEVICE)  # (L2_padded, bs)
        for i in range(len(lens)):
            mask[lens[i]-1:,i] = 0
        outs,_ = self.speller(inputs=inp, key=key, value=value, hidden_init=hidden_init)
        loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:,:].view(-1))    
        loss = torch.sum(torch.mul(mask.view(-1), loss))   
        
        self.optimizer1.zero_grad()
        self.optimizer2.zero_grad()
        loss.backward()
        self.optimizer1.step()
        self.optimizer2.step()
        
        return loss.item()
    
    def validate(self):
        self.listener.eval()
        self.speller.eval()
        val_loss = 0
        ls = 0
        dev_len = len(self.val_loader.dataset)
        preds = []
        trues = []
        with torch.no_grad():
            for batch_num, (inputs, targets) in enumerate(self.val_loader):
                key, value, hidden = self.listener(inputs)
                a,bs,hs = hidden[0].size()
                h_1 = hidden[0].permute(1,0,2).contiguous().view(-1,a*hs)
                c_1 = hidden[1].permute(1,0,2).contiguous().view(-1,a*hs)
                hidden_init = [(h_1, c_1)]
                
                for t in targets:
                    trues.append(t.cpu().numpy())
                    
                lens = [len(t) for t in targets]
                targets = rnn.pad_sequence(targets, padding_value=0).to(DEVICE) # (L2_padded, bs) 
                inp = targets[:-1,:]
                mask = torch.ones(inp.size()).to(DEVICE)  # (L2_padded, bs)
                for i in range(len(lens)):
                    mask[lens[i]-1:,i] = 0
                outs,_ = self.speller(inputs=inp, key=key, value=value, hidden_init=hidden_init)
                loss = self.criterion(outs.view(-1, outs.size(2)), targets[1:,:].view(-1))    
                loss = torch.sum(torch.mul(mask.view(-1), loss))   
                val_loss += loss.item()
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded, vocab_size)
#                 print('probs0: ',probs[0])
#                 print(torch.max(probs[0],dim=1))
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
#                 print(pred.size())
                for p in pred:
                    preds.append(p.cpu().numpy())
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded-1, vocab_size)
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
                for p in pred:
                    preds.append(p.cpu().numpy())
                
            for i in range(dev_len):
                pred_i = preds[i]
                true_i = trues[i][1:-1]   # trues include <sos> and <eos>
                if 0 in pred_i:
                    pred_i = pred_i[:pred_i.tolist().index(0)]

                pred = "".join(vocab_map[o] for o in pred_i)
                true = "".join(vocab_map[l] for l in true_i)
                if i % 500 == 0:
                    print('pred:', pred)
                    print('true:', true)

                ls += L.distance(pred, true)
                
            return val_loss / (batch_num + 1), ls / dev_len
    
    def save_model(self):
        path1 = os.path.join('attention', self.run_id, 'listener-{}.pkl'.format(self.epochs))
        path2 = os.path.join('attention', self.run_id, 'speller-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.listener.state_dict()}, path1)
        torch.save({'state_dict': self.speller.state_dict()}, path2)
        

In [45]:
BATCH_SIZE = 256
NUM_EPOCHS = 5

In [46]:
trainset = CharDataset(X_train[6400:12800], Y_train[6400:12800])
devset = CharDataset(X_dev, Y_dev)
train_loader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate)
dev_loader = DataLoader(devset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate)

In [47]:
run_id = str(int(time.time()))
if not os.path.exists('./attention'):
    os.mkdir('./attention')
os.mkdir('./attention/%s' % run_id)
print("Saving models, predictions, and generated words to ./attention/%s" % run_id)

Saving models, predictions, and generated words to ./attention/1542229876


In [48]:
listener = ListenerModel()
speller = SpellerModel(vocab_size=len(vocab_map))
l = torch.load('attention/1542229337/listener-3.pkl')
listener.load_state_dict(l['state_dict'])
s = torch.load('attention/1542229337/speller-3.pkl')
speller.load_state_dict(s['state_dict'])
# checkpoint = torch.load('experiments/1542219735/speller-13.pkl')
# speller.load_state_dict(checkpoint['state_dict'])
trainer = AttentionTrainer(listener=listener, speller=speller, 
                           train_loader=train_loader, val_loader=dev_loader, 
                           max_epochs=NUM_EPOCHS, run_id=run_id)

In [49]:
best_dist = 1e30

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    trainer.train()

    val_loss, dist = trainer.validate()
    print('Val loss: {}, Val Levenshtein distance: {}'.format(val_loss, dist))
    if dist < best_dist:
        best_dist = dist
        print("Saving model for epoch {}, with Levenshtein distance: {}".format(epoch+1, best_dist))
        trainer.save_model()
    elapsed_time = time.time() - start_time
    print('Time elapsed: ', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

  


[TRAIN]  Epoch [1/5]   Loss: 31497.1464
pred: thet tor n n nen aen nn teteae tn  ne  then aen e  thr aon   torn  aon ao  ne tn e   e   aoreantnnieneaoreontneaen tn  aoreontoen n  toreontntnne tor n  aoreontn ao  n   aorn ar an toenen  toren   aor nn 
true: this paradoxical notion simply assumes that higher tax rates would not reduce investment comma output comma employment comma profits comma equity values comma or related sources of private savings period
pred: the tereeaeneeae tn   ahrteen ne to   een  
true: washington apparently got the word period
pred: th e aoreenteen ae   aeen  ahen ahre tor   nn 
true: he's shaking up the management of french companies
Val loss: 23508.386328125, Val Levenshtein distance: 79.29204339963833
Saving model for epoch 1, with Levenshtein distance: 79.29204339963833
Time elapsed:  00:01:28
[TRAIN]  Epoch [2/5]   Loss: 30121.9456
pred: tn tn  toe n   che tomee   an then tere  ne  toe ne tomeentere  ne  tn  ene co e te ee  to  n  nd aoe e to   
true: as 