In [1]:
import numpy as np
import time, os, random
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

In [28]:
BATCH_SIZE = 512
NUM_EPOCHS = 10
KEY_SIZE = 128   # width of the decoder context; must equal SpellerModel's key_size
SEED = 0         # fixed so reruns reproduce initialisation, shuffling and the context below

# Seed every RNG the notebook imports so a fresh-kernel "Run All" is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Placeholder decoder context (no attention yet): one fixed random row per batch
# element. SpellerModel concatenates it with each step's input along dim 1, so a
# batch must contain exactly BATCH_SIZE transcripts when this tensor is used.
context = torch.randn(BATCH_SIZE, KEY_SIZE).to(DEVICE)

In [3]:
# Location of the dataset .npy files, relative to the notebook.
DATA_DIR = 'all'
CHAR_DIR = os.path.join(DATA_DIR, 'character')

# Features and labels hold one variable-length sequence per entry (ragged), which
# NumPy stores as object arrays; NumPy >= 1.16.3 refuses to load those unless
# allow_pickle=True. Unpickling can execute arbitrary code: only load trusted files.
X_train = np.load(os.path.join(DATA_DIR, 'train.npy'), encoding='bytes', allow_pickle=True)
Y_train = np.load(os.path.join(CHAR_DIR, 'train_labels.npy'), allow_pickle=True)
X_dev = np.load(os.path.join(DATA_DIR, 'dev.npy'), encoding='bytes', allow_pickle=True)
Y_dev = np.load(os.path.join(CHAR_DIR, 'dev_labels.npy'), allow_pickle=True)
# NOTE(review): the vocab may be a plain string array, in which case allow_pickle is a no-op.
vocab_map = np.load(os.path.join(CHAR_DIR, 'vocab.npy'), allow_pickle=True)

### Dataset class

In [4]:
class CharDataset(Dataset):
    """Map-style dataset of utterances and, optionally, their character transcripts.

    data:   indexable collection (e.g. a NumPy object array or a list) whose i-th
            entry is one utterance's feature array
    labels: optional parallel collection of integer label sequences; when given,
            items are (features, labels) pairs, otherwise just the features
    """
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __getitem__(self, i):
        features = torch.tensor(self.data[i])
        # Use `is not None`, not `!= None`: on a NumPy array `!=` compares
        # elementwise, and the resulting bool array cannot be used in an `if`.
        if self.labels is not None:
            return features, torch.tensor(self.labels[i], dtype=torch.long)
        return features

    def __len__(self):
        # len() equals shape[0] for NumPy arrays and also works for plain lists.
        return len(self.data)


def collate(seq_list):
    """
    Order a batch of (features, transcript) pairs by decreasing feature length.

    seq_list: sequence of (input, target) tensor pairs, e.g. from CharDataset
    return:   (inputs, targets) -- two Python lists in the same length-descending
              order; nothing is padded here. Equal lengths keep their original
              relative order (the sort is stable).
    """
    inputs, targets = zip(*seq_list)
    by_length_desc = sorted(
        zip(inputs, targets),
        key=lambda pair: pair[0].shape[0],
        reverse=True,
    )
    return [feat for feat, _ in by_length_desc], [lab for _, lab in by_length_desc]

In [5]:
class LSTMCell_(nn.Module):
    """An ``nn.LSTMCell`` whose initial (h, c) state is a learned parameter.

    At step ``t == 0`` the caller-supplied state is ignored and the learned
    ``h_0`` / ``c_0`` rows are broadcast over the batch; at later steps the
    given ``hidden_states`` are passed straight to the wrapped cell.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell_, self).__init__()
        self.lstmcell = nn.LSTMCell(input_size, hidden_size, bias)
        # Learned initial state, one row that forward() expands across the batch.
        # Created on the default device; the owning model's .to(device) moves it
        # together with every other parameter (no dependency on a global DEVICE).
        self.h_0 = nn.Parameter(torch.randn(1, hidden_size))
        self.c_0 = nn.Parameter(torch.randn(1, hidden_size))
        # NOTE(review): redundant -- h_0/c_0 are already registered above -- and the
        # attribute name shadows nn.Module.parameters in _modules. It is kept only
        # because state_dicts saved from this class contain its "parameters.0" /
        # "parameters.1" keys; removing it breaks strict load_state_dict on them.
        self.parameters = nn.ParameterList([self.h_0, self.c_0])

    def forward(self, inputs, t, hidden_states):
        """Advance the cell by one step.

        inputs:        (batch, input_size) tensor
        t:             step index; at 0 the learned initial state is used
        hidden_states: (h, c) from the previous step; ignored when t == 0
        returns:       (h, c), each of shape (batch, hidden_size)
        """
        if t == 0:
            batch = inputs.size(0)
            hidden_states = (self.h_0.expand(batch, -1), self.c_0.expand(batch, -1))
        return self.lstmcell(inputs, hidden_states)
        

#### Initial setting: listener with 2 layers and hidden size 128 per direction; speller with 1 layer and hidden size 256
#### Later setting: listener with 3 layers and hidden size 256; speller with * layers and hidden size 512

In [63]:
class SpellerModel(nn.Module):
    """Character decoder ("speller"): a stack of LSTM cells with a learned start state.

    Every step is conditioned on a context tensor of width ``key_size``. It is
    concatenated to the embedded previous token at the input and to the top
    layer's output before scoring.
    """

    def __init__(self, vocab_size, hidden_size=512, embed_size=256, key_size=128, nlayers=2):
        super(SpellerModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.nlayers = nlayers
        self.vocab_size = vocab_size

        self.embed = nn.Embedding(vocab_size, embed_size)

        # Projection of the decoder state s. forward() does not use it yet. It is
        # kept so the parameter set, and therefore saved checkpoints, are unchanged.
        self.fc1 = nn.Linear(hidden_size, key_size)

        # Layer 0 reads cat(y_{i-1}, context); deeper layers read the layer below.
        self.rnns = nn.ModuleList([
            LSTMCell_(embed_size + key_size if layer == 0 else hidden_size, hidden_size)
            for layer in range(nlayers)
        ])

        # Output scores over the vocabulary from cat(top hidden state, context).
        self.scoring = nn.Linear(hidden_size + key_size, vocab_size)

    def forward(self, inputs, teacher=False, hidden_init=None, context=None):
        """Teacher-forced decoding over a padded batch of token ids.

        inputs:  (L2_padded, batch_size) LongTensor, L2_padded = padded transcript length
        teacher, hidden_init: currently unused; accepted for interface compatibility
        context: optional (batch_size, key_size) tensor on the same device as the
                 model. When omitted, the notebook-level global ``context`` tensor
                 is used, which only fits batches with exactly as many rows as it has.
        returns: outs, (L2_padded, batch_size, vocab_size) unnormalised scores, and
                 hiddens, a list with one (h, c) pair per layer from the last step
        """
        if context is None:
            # Backward-compatible fallback to the global placeholder tensor
            # (the parameter name shadows it, hence the globals() lookup).
            context = globals()["context"]

        # Unpacking also enforces that inputs is 2-D (time, batch).
        steps, _batch_size = inputs.size()

        embed = self.embed(inputs)  # (L2_padded, batch_size, embed_size)
        hiddens = [None] * self.nlayers
        outs = []

        for i in range(steps):
            h = torch.cat((embed[i], context), 1)  # (batch_size, embed_size + key_size)
            for layer, cell in enumerate(self.rnns):
                h, c = cell(h, i, hiddens[layer])
                hiddens[layer] = (h, c)
            outs.append(self.scoring(torch.cat((h, context), 1)))

        outs = torch.stack(outs, dim=0)  # (L2_padded, batch_size, vocab_size)
        return outs, hiddens
        
        

### Trainer class

In [67]:
class AttentionTrainer:
    """Training / validation driver for the speller with padding-masked cross-entropy.

    Loader batches are ``(inputs, targets)`` as returned by ``collate``. ``targets``
    is a list of 1-D LongTensor transcripts. The speller is run with teacher
    forcing: it reads ``targets[:-1]`` and is scored against ``targets[1:]``.
    ``inputs`` (the acoustic features) are not consumed yet.
    """

    def __init__(self, speller, train_loader, val_loader, max_epochs=1, run_id='exp'):
        self.speller = speller.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id

        self.optimizer2 = torch.optim.Adam(self.speller.parameters(), lr=0.0001, weight_decay=1e-5)
        # reduction='none' yields one loss per token so padding can be masked out.
        # Do not use the deprecated boolean `reduce=` flag: any non-empty string
        # (even 'false') is truthy there and selects mean reduction.
        self.criterion = nn.CrossEntropyLoss(reduction='none').to(DEVICE)
        self.scheduler = ReduceLROnPlateau(self.optimizer2, factor=0.1, patience=3, mode='min')

    def _masked_loss(self, targets):
        """Teacher-forced forward pass on one batch.

        targets: list of 1-D LongTensors, one transcript per utterance
        returns: (loss, outs, batch_size). ``loss`` is the per-token cross-entropy
                 summed over non-padding positions. ``outs`` are the speller scores,
                 shape (L2_padded - 1, batch_size, vocab_size).
        """
        lens = torch.tensor([len(t) for t in targets], device=DEVICE)
        padded = rnn.pad_sequence(targets, padding_value=0).to(DEVICE)  # (L2_padded, bs)
        inp = padded[:-1, :]   # decoder input: every token except the last
        gold = padded[1:, :]   # target at each step: the next token
        # Step j predicts padded[j + 1], which is a real token only while
        # j < len - 1; later steps predict padding and are masked out.
        step_idx = torch.arange(inp.size(0), device=DEVICE).unsqueeze(1)  # (L2_padded-1, 1)
        mask = (step_idx < (lens - 1).unsqueeze(0)).float()               # (L2_padded-1, bs)

        outs, _ = self.speller(inputs=inp)
        token_loss = self.criterion(outs.view(-1, outs.size(2)), gold.reshape(-1))
        loss = torch.sum(mask.reshape(-1) * token_loss)
        return loss, outs, len(targets)

    def train(self):
        """Run one epoch over ``train_loader`` and record the mean per-utterance loss."""
        self.speller.train()
        self.epochs += 1
        epoch_loss = 0.0
        num_batches = 0

        for inputs, targets in self.train_loader:
            epoch_loss += self.train_batch(inputs, targets)
            num_batches += 1

        # max(..., 1) avoids dividing by zero on an empty loader.
        epoch_loss = epoch_loss / max(num_batches, 1)

        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f'
              % (self.epochs, self.max_epochs, epoch_loss))
        self.train_losses.append(epoch_loss)

    def train_batch(self, inputs, targets):
        """Take one optimisation step. Return the summed masked loss divided by the batch size."""
        loss, _, batch_size = self._masked_loss(targets)
        self.optimizer2.zero_grad()
        loss.backward()
        self.optimizer2.step()
        return loss.item() / batch_size

    def validate(self, print_every=400):
        """Evaluate on ``val_loader``.

        Returns (mean per-utterance loss over batches, mean Levenshtein distance
        per utterance). Predictions are greedy argmax decodes of the teacher-forced
        outputs, truncated at the first index-0 token. The prediction/reference
        pair of every ``print_every``-th utterance is printed; pass 0 to disable.
        """
        self.speller.eval()
        val_loss = 0.0
        num_batches = 0
        preds = []
        trues = []
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                trues.extend(t.cpu().numpy() for t in targets)
                loss, outs, batch_size = self._masked_loss(targets)
                val_loss += loss.item() / batch_size
                num_batches += 1
                # argmax of the logits equals argmax of their softmax.
                pred = torch.argmax(outs, dim=2).t()  # (bs, L2_padded - 1)
                preds.extend(p.cpu().numpy() for p in pred)

        # Average over the utterances actually evaluated. Do not assume full
        # batches of a global batch size; this holds whatever drop_last is.
        total_dist = 0
        for i, (pred_i, true_seq) in enumerate(zip(preds, trues)):
            true_i = true_seq[1:-1]   # exclude <sos> and <eos>
            pred_tokens = pred_i.tolist()
            if 0 in pred_tokens:
                pred_tokens = pred_tokens[:pred_tokens.index(0)]

            pred_str = "".join(vocab_map[o] for o in pred_tokens)
            true_str = "".join(vocab_map[l] for l in true_i)
            if print_every and i % print_every == 0:
                print('pred:', pred_str)
                print('true:', true_str)
            total_dist += L.distance(pred_str, true_str)

        mean_loss = val_loss / max(num_batches, 1)
        self.val_losses.append(mean_loss)
        return mean_loss, total_dist / max(len(trues), 1)

    def save_model(self):
        """Save the speller's state_dict to experiments/<run_id>/speller-<epochs>.pkl."""
        run_dir = os.path.join('experiments', self.run_id)
        # Create the directory if needed (e.g. the default run_id='exp').
        os.makedirs(run_dir, exist_ok=True)
        path2 = os.path.join(run_dir, 'speller-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.speller.state_dict()}, path2)
        

In [49]:
trainset = CharDataset(X_train, Y_train)
devset = CharDataset(X_dev, Y_dev)

# Both loaders sort each batch by length via `collate` and drop the final
# partial batch, because the speller's global placeholder context has exactly
# BATCH_SIZE rows.
train_loader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate,
)
dev_loader = DataLoader(
    devset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    collate_fn=collate,
)

In [50]:
# One directory per run, named by the Unix timestamp at which it was created.
run_id = str(int(time.time()))
# makedirs creates ./experiments if it is missing and, like os.mkdir, raises if
# this run's directory already exists.
os.makedirs(os.path.join('.', 'experiments', run_id))
print("Saving models, predictions, and generated words to ./experiments/%s" % run_id)

Saving models, predictions, and generated words to ./experiments/1542249795


In [68]:
# Warm-start from an earlier run's checkpoint; set to None to train from scratch.
INIT_CHECKPOINT = os.path.join('experiments', '1542249216', 'speller-10.pkl')

speller = SpellerModel(vocab_size=len(vocab_map))
if INIT_CHECKPOINT is not None:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    checkpoint = torch.load(INIT_CHECKPOINT, map_location=DEVICE)
    speller.load_state_dict(checkpoint['state_dict'])
trainer = AttentionTrainer(speller=speller,
                           train_loader=train_loader, val_loader=dev_loader,
                           max_epochs=NUM_EPOCHS, run_id=run_id)


In [69]:
def _hms(seconds):
    """Render a duration in seconds as an HH:MM:SS string."""
    return time.strftime("%H:%M:%S", time.gmtime(seconds))


best_dist = 1e30  # sentinel: the first epoch's distance always counts as an improvement

for epoch in range(NUM_EPOCHS):
    start_time = time.time()

    trainer.train()
    val_loss, dist = trainer.validate()

    # Learning-rate decay follows the dev loss...
    trainer.scheduler.step(val_loss)
    print('Val loss: {}, Val Levenshtein distance: {}'.format(val_loss, dist))

    # ...while checkpointing follows the dev edit distance (lower is better).
    if dist < best_dist:
        best_dist = dist
        print("Saving model for epoch {}, with Levenshtein distance: {}".format(epoch + 1, best_dist))
        trainer.save_model()

    elapsed_time = time.time() - start_time
    print('Time elapsed: ', _hms(elapsed_time))

  


ValueError: Expected input batch_size (110080) to match target batch_size (109568).

In [33]:
# Manually save the trainer's current speller weights to
# experiments/<run_id>/speller-<epochs>.pkl. Unlike the training loop, this does
# not check whether the dev distance improved.
trainer.save_model()