In [34]:
import numpy as np
import time, os, random
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import Levenshtein as L
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [44]:
BATCH_SIZE = 512
NUM_EPOCHS = 10
context = torch.randn(BATCH_SIZE, 128).to(DEVICE)

In [36]:
X_train = np.load('all/train.npy', encoding='bytes')
Y_train = np.load('all/character/train_labels.npy')
X_dev = np.load('all/dev.npy', encoding='bytes')
Y_dev = np.load('all/character/dev_labels.npy')
vocab_map = np.load('all/character/vocab.npy')

### dataset class

In [37]:
class CharDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels
    def __getitem__(self,i):
        if self.labels != None:
            return torch.tensor(self.data[i]), torch.tensor(self.labels[i], dtype=torch.long)
        else:
            return torch.tensor(self.data[i])
    def __len__(self):
        return self.data.shape[0]


def collate(seq_list):
    """
    return: a batch sorted by decreasing order of length of sequences
    inputs: (L_padded, batch_size, 40)
    targets: list of targets
    """
    inputs,targets = zip(*seq_list)
    lens = [seq.shape[0] for seq in inputs]
    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
    inputs = [inputs[i] for i in seq_order]
    targets = [targets[i] for i in seq_order]
    return inputs,targets

#### Initial setting: listener layer:2, hidden: 128/direction;  speller layer: 1, hidden: 256
#### after that: listener layer:3, hidden: 256; speller: *, hidden: 512 

In [38]:
class SpellerModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=512, embed_size=256, key_size=128, nlayers=2):
        super(SpellerModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.nlayers = nlayers
        
        self.embed = nn.Embedding(vocab_size, embed_size)
        
        # projection os state s
        self.fc1 = nn.Linear(hidden_size, key_size)
        
        # 1st layer, input: cat(y_{i-1},context_{i-1}); h_0: s_{i-1}
        self.rnns = nn.ModuleList([
            nn.LSTMCell(embed_size+key_size if layer==0 else hidden_size, hidden_size) for layer in range(nlayers)
        ])
        
        self.scoring = nn.Linear(hidden_size, vocab_size)
        self.weight_init()
        
    def weight_init(self):
        torch.nn.init.xavier_uniform(self.fc1.weight)
        torch.nn.init.xavier_uniform(self.scoring.weight)
        self.fc1.bias.data.fill_(0.01)
        self.scoring.bias.data.fill_(0.01)
        
    
    def forward(self, inputs, hidden_init=None):
        """
        inputs: (L2_padded, batch_size), L2_padded = padded transcript length 
        key: (L_padded, bs, key_size)
        value: (L_padded, bs, value_size)
        query_init: (bs, hidden)
        context: (batch_size, context_size)
        outs: (L2_padded, bs, vocab_size)
        hiddens: a list of (h_n, c_n), n = L2_padded
        """
        inputs = inputs.to(DEVICE)
        embed = self.embed(inputs)   # (L2_padded, batch_size, embed_size)
        hiddens = [None] * self.nlayers

        # try teacher forcing 
#         if slef.training:
#             p = np.random.random_sample()
        
        outs = []
        for y in embed: #(N, embed_size)
#             if slef.training:
#                 if(np.random.random_sample()>0.9):
#                     y = 
                
            inp = torch.cat((y, context), dim=1)    # (N, embed+value)
            h = inp
            for (l, rnn) in enumerate(self.rnns):
                h1, c1 = rnn(h, hiddens[l])
                hiddens[l] = (h1,c1)
                h = h1

            outs.append(self.scoring(h))
        
        outs = torch.stack(outs, dim=0)  #(L2_padded, N, vocab_size)
        return outs, hiddens
        
        

### Trainer class

In [39]:
class AttentionTrainer:
    def __init__ (self, speller, train_loader, val_loader, max_epochs=1, run_id='exp'):
        self.speller = speller.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_losses = []
        self.val_losses = []
        self.epochs = 0
        self.max_epochs = max_epochs
        self.run_id = run_id
        
        self.optimizer2 = torch.optim.Adam(self.speller.parameters(), lr=0.0001, weight_decay=1e-5)
        self.criterion = nn.CrossEntropyLoss(ignore_index=len(vocab_map)).to(DEVICE)
        self.scheduler = ReduceLROnPlateau(self.optimizer2, factor = 0.1, patience = 3, mode = 'min')
    
    def train(self):
        self.speller.train()
        self.epochs += 1
        epoch_loss = 0

        for batch_num, (inputs, targets) in enumerate(self.train_loader):
            epoch_loss += self.train_batch(inputs, targets)
       
        epoch_loss = epoch_loss / (batch_num + 1)
   
        print('[TRAIN]  Epoch [%d/%d]   Loss: %.4f'
                      % (self.epochs, self.max_epochs, epoch_loss))
        self.train_losses.append(epoch_loss)
    
    def train_batch(self, inputs, targets):
        targets = rnn.pad_sequence(targets, padding_value=len(vocab_map)).to(DEVICE) # (L2_padded, bs)     
        outs,_ = self.speller(inputs=targets)      
        loss = self.criterion(outs[:-1,:].view(-1, outs.size(2)), targets[1:].view(-1))    
        
#         inputs = inputs.to(DEVICE)
#         targets = targets.to(DEVICE)
#         outputs = self.speller(inputs=inputs) # 3D
#         loss = self.criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs
        
        self.optimizer2.zero_grad()
        loss.backward()
        self.optimizer2.step()
        
        return loss.item()
    
    def validate(self):
        self.speller.eval()
        val_loss = 0
        ls = 0
        dev_len = BATCH_SIZE * (len(self.val_loader.dataset) // BATCH_SIZE)
        preds = []
        trues = []
        with torch.no_grad():
            for batch_num, (inputs, targets) in enumerate(self.val_loader):
                for t in targets:
                    trues.append(t.cpu().numpy())
                
                targets = rnn.pad_sequence(targets, padding_value=len(vocab_map)).to(DEVICE) # (L2_padded, bs)     
                outs,_ = self.speller(inputs=targets)  
                loss = self.criterion(outs[:-1,:].view(-1, outs.size(2)), targets[1:].view(-1))    
#                 inputs = inputs.to(DEVICE)
#                 targets = targets.to(DEVICE)
#                 outputs = self.speller(inputs) # 3D
#                 loss = self.criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs
                
                val_loss += loss.item()
                
                probs = F.softmax(outs.permute(1,0,2), dim=2) # (N, L2_padded, vocab_size)
                pred = torch.argmax(probs, dim=2) # (N, L2_padded)
                for p in pred:
                    preds.append(p.cpu().numpy())
                
            for i in range(dev_len):
                pred_i = preds[i]
                true_i = trues[i][1:-1]
                if 0 in pred_i:
                    pred_i = pred_i[:pred_i.tolist().index(0)]

                pred = "".join(vocab_map[o] for o in pred_i)
                true = "".join(vocab_map[l] for l in true_i)
#                 if i % 2000 == 0:
#                     print('pred:', pred)
#                     print('true:', true)

                ls += L.distance(pred, true)
                
            return val_loss, ls / dev_len
    
    def save_model(self):
        path2 = os.path.join('experiments', self.run_id, 'speller-{}.pkl'.format(self.epochs))
        torch.save({'state_dict': self.speller.state_dict()}, path2)
        

In [40]:
trainset = CharDataset(X_train, Y_train)
devset = CharDataset(X_dev, Y_dev)


# random.shuffle(X_train)
# sequence = np.concatenate(([a for a in X_train]))
# trainset = TextDataset(sequence)
# s = np.concatenate(([a for a in X_dev]))
# devset = TextDataset(s)

train_loader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, collate_fn = collate, drop_last=True)
dev_loader = DataLoader(devset, batch_size=BATCH_SIZE, collate_fn = collate, drop_last=True)

In [41]:
run_id = str(int(time.time()))
if not os.path.exists('./experiments'):
    os.mkdir('./experiments')
os.mkdir('./experiments/%s' % run_id)
print("Saving models, predictions, and generated words to ./experiments/%s" % run_id)

Saving models, predictions, and generated words to ./experiments/1542170607


In [45]:
speller = SpellerModel(vocab_size=len(vocab_map)+1)
checkpoint = torch.load('experiments/1542168372/speller-4.pkl')
speller.load_state_dict(checkpoint['state_dict'])
trainer = AttentionTrainer(speller=speller, 
                           train_loader=train_loader, val_loader=dev_loader, 
                           max_epochs=NUM_EPOCHS, run_id=run_id)




In [None]:
best_dist = 1e30

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    trainer.train()

    val_loss, dist = trainer.validate()
    trainer.scheduler.step(val_loss)
    print('Val loss: {}, Val Levenshtein distance: {}'.format(val_loss, dist))
    if dist < best_dist:
        best_dist = dist
        print("Saving model for epoch {}, with Levenshtein distance: {}".format(epoch+1, best_dist))
        trainer.save_model()
    elapsed_time = time.time() - start_time
    print('Time elapsed: ', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

  


[TRAIN]  Epoch [1/10]   Loss: 0.9246
pred: t  afe haoake  ohe forrent cffthes pile hmes goedte aomma aame hmes graeayd tevl oyphen tesurd ng saale cagt<unk>paren heses an s antiously roneh ng teght paren comma tne hndxetilly phesgs tn teclh cy aeipn ng aeriod
true: as one breasts the current of this sometimes creamy comma sometimes awkward self hyphen regarding style left<unk>parentheses it's obviously catching right<unk>paren comma one inevitably thinks of death by drowning period
Val loss: 1.6798783540725708, Val Levenshtein distance: 70.8681640625
Saving model for epoch 1, with Levenshtein distance: 70.8681640625
Time elapsed:  00:00:39
[TRAIN]  Epoch [2/10]   Loss: 0.9165
pred: t  afe haoake  ohe forrent cffthes cale hmes goedte aomma aame hmes g aeayd aevl oyphen seaurd ng taale cagt<unk>paren heses an s antiously woneh ng teght paren pomma tne hndxetille phesgs tn teclh cy teipn ng aeriod
true: as one breasts the current of this sometimes creamy comma sometimes awkward self hyphe