In [1]:
# We're using two different implementations. The older implementation
# had the option to switch from multi-head attention to the standard
# one.
from main import train

In [4]:
import os
import string
import numpy as np

import torch.optim as optim

class Task(object):
    """Synthetic letter-counting task producing one-hot inputs and count labels.

    Each sample is a run of "signal" letters followed by ``max_len`` random
    letters. For every signal position the label is the number of times that
    signal letter occurs among the ``max_len`` letters, one-hot encoded over
    ``max_len + 1`` classes (counts 0..max_len).

    Args:
        max_len: Number of random letters following the signal letters.
        vocab_size: Number of distinct letters; must be <= 26 because letters
            are rendered with ``string.ascii_uppercase``.
    """

    def __init__(self, max_len=10, vocab_size=3):
        super(Task, self).__init__()
        self.max_len = max_len
        self.vocab_size = vocab_size
        assert self.vocab_size <= 26, "vocab_size needs to be <= 26 since we are using letters to prettify LOL"

    def next_batch(self, batchsize=100, signal=None):
        """Draw a random batch of samples and their count labels.

        Args:
            batchsize: Number of samples to generate.
            signal: Optional single uppercase letter. When given, every sample
                carries exactly one signal position holding that letter;
                otherwise three signal letters are drawn at random per sample.

        Returns:
            Tuple ``(x, y)`` of float arrays. ``x`` has shape
            ``(batchsize, n_signals + max_len, vocab_size)`` and ``y`` has
            shape ``(batchsize, n_signals, max_len + 1)``, where ``n_signals``
            is 1 when ``signal`` is given and 3 otherwise.

        Raises:
            ValueError: If ``signal`` is not found in ``string.ascii_uppercase``.
            IndexError: If the letter's index is >= ``vocab_size``.
        """
        one_hot = np.eye(self.vocab_size)
        letters = np.arange(self.vocab_size)
        if signal is not None:
            # Fixed query: a single signal column holding the requested letter.
            signal_idx = string.ascii_uppercase.index(signal)
            signal = one_hot[np.ones((batchsize, 1), dtype=int) * signal_idx]
        else:
            signal = one_hot[np.random.choice(letters, [batchsize, 3])]
        seq = one_hot[np.random.choice(letters, [batchsize, self.max_len])]
        x = np.concatenate((signal, seq), axis=1)

        # Compare every signal letter with every sequence letter and count the
        # matches per signal position: (batch, n_signals, max_len) -> (batch, n_signals).
        signal_ids = np.expand_dims(np.argmax(signal, axis=2), axis=-1)
        seq_ids = np.expand_dims(np.argmax(seq, axis=2), axis=1)
        counts = np.sum(signal_ids == seq_ids, axis=2)
        y = np.eye(self.max_len + 1)[counts]
        return x, y

    def prettify(self, samples):
        """Convert one-hot samples back into uppercase letters.

        The number of time steps is taken from the input when its layout makes
        it unambiguous, so batches produced with ``signal=...`` (which have
        ``max_len + 1`` steps rather than ``max_len + 3``) decode correctly.

        Args:
            samples: One-hot array holding a single sample ``(steps, vocab_size)``
                or a batch ``(batch, steps, vocab_size)``. Any other layout is
                reshaped assuming ``max_len + 3`` steps per sample.

        Returns:
            Array of single-character strings with shape ``(batch, steps, 1)``.
        """
        samples = np.asarray(samples)
        seq_len = self.max_len + 3  # default layout: 3 signal letters + max_len letters
        if samples.ndim >= 2 and samples.shape[-1] == self.vocab_size:
            if samples.ndim >= 3:
                seq_len = samples.shape[-2]
            elif samples.shape[0] % seq_len != 0:
                # A lone sample whose length differs from the default layout.
                seq_len = samples.shape[0]
        samples = samples.reshape(-1, seq_len, self.vocab_size)
        idx = np.expand_dims(np.argmax(samples, axis=2), axis=2)
        dictionary = np.array(list(string.ascii_uppercase))
        return dictionary[idx]

In [5]:
# Build the task with its defaults (max_len=10, vocab_size=3) and draw one
# batch using the default batchsize of 100.
task = Task()
samples, labels = task.next_batch()

In [6]:
# Draw a fresh random batch; this overwrites `samples` and `labels` from the
# previous cell.
samples, labels = task.next_batch()

In [6]:
# Decode the first sample's one-hot rows back into letters.
task.prettify(samples[0])

array([[['B'],
        ['C'],
        ['A'],
        ['A'],
        ['A'],
        ['A'],
        ['B'],
        ['B'],
        ['C'],
        ['B'],
        ['B'],
        ['B'],
        ['A']]], dtype='<U1')

In [7]:
# Class index of each one-hot label row of the first sample, i.e. the count
# associated with each signal position.
np.argmax(labels[0], axis=1)

array([5, 1, 4])

In [8]:
# (batchsize, max_len + 3, vocab_size)
samples.shape

(100, 13, 3)

In [9]:
# Raw one-hot encoding of the first sample: one row per input position.
samples[0]

array([[0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.]])

In [10]:
# (batchsize, 3, max_len + 1)
labels.shape

(100, 3, 11)

In [11]:
# Raw one-hot labels of the first sample: one row per signal position.
labels[0]

array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

# Without Multi-head Attention

In [13]:
# Baseline run with multi-head attention switched off.
# NOTE(review): read top to bottom, `train` here is the one imported from
# `main` in the first cell; the `train` defined further down in this notebook
# never reads its `use_multihead` argument. The execution counts are out of
# order, so confirm which definition actually produced the log below.
train(print_every=50, use_multihead=False)

Iteration 50 - Loss 1.7411472797393799
Iteration 100 - Loss 1.769609808921814
Iteration 150 - Loss 1.6528491973876953
Iteration 200 - Loss 1.6441739797592163
Iteration 250 - Loss 1.4802706241607666
Iteration 300 - Loss 1.4717421531677246
Iteration 350 - Loss 1.31699538230896
Iteration 400 - Loss 1.028668761253357
Iteration 450 - Loss 0.8676766157150269
Iteration 500 - Loss 0.8158026337623596
Iteration 550 - Loss 0.732761561870575
Iteration 600 - Loss 0.6769591569900513
Iteration 650 - Loss 0.7156023383140564
Iteration 700 - Loss 0.6905409097671509
Iteration 750 - Loss 0.7155475616455078
Iteration 800 - Loss 0.6389250755310059
Iteration 850 - Loss 0.636111319065094
Iteration 900 - Loss 0.7251088619232178
Iteration 950 - Loss 0.6552238464355469
Iteration 1000 - Loss 0.6814892292022705
Iteration 1050 - Loss 0.623063325881958
Iteration 1100 - Loss 0.6318008899688721
Iteration 1150 - Loss 0.6472669243812561
Iteration 1200 - Loss 0.653508186340332
Iteration 1250 - Loss 0.6918123364448547
Ite

# With Multi-head Attention

In [7]:
# Make the parent directory importable so `all_models` can be found.
import sys
sys.path.append("../")

# NOTE(review): `torch`, `F` and `MultiHeadAttentionModel` are used by the
# cells below but are never imported explicitly in this notebook; presumably
# they arrive through this star import -- confirm and import them by name.
from all_models import *

In [10]:
def train(max_len=10,
          vocab_size=3,
          hidden=64,
          pos_enc=True,
          enc_layers=2,
          use_multihead=True,
          heads=4,
          batchsize=100,
          steps=4000,
          print_every=50,
          savepath='models/'):
    """Train a ``MultiHeadAttentionModel`` on the counting ``Task`` and save it.

    A fresh random batch is drawn at every step, the model is optimised with
    Adam (lr=1e-3) against a cross-entropy loss, and the learning rate is
    managed by ``ReduceLROnPlateau``. The final weights are written to
    ``<savepath>/ckpt.pt``.

    Args:
        max_len: Passed to ``Task``.
        vocab_size: Passed to ``Task``.
        hidden: Accepted for interface compatibility; not used here.
        pos_enc: Accepted for interface compatibility; not used here.
        enc_layers: Accepted for interface compatibility; not used here
            (the model is always built with ``num_enc_layers=4``).
        use_multihead: Accepted for interface compatibility; not used here.
        heads: Accepted for interface compatibility; not used here.
        batchsize: Number of samples drawn per optimisation step.
        steps: Number of optimisation steps. With ``steps <= 0`` no training
            happens and the freshly initialised model is saved.
        print_every: Print the loss every this many steps.
        savepath: Directory that receives ``ckpt.pt``; created if missing.

    Returns:
        List with one 0-d float32 NumPy array per step holding that step's loss.
    """
    os.makedirs(savepath, exist_ok=True)
    # NOTE(review): the architecture arguments above are ignored and the depth
    # is fixed at 4; wiring them through needs the model's constructor
    # signature, which is not visible from this notebook.
    model = MultiHeadAttentionModel(num_enc_layers=4)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): `verbose` appears to be deprecated in recent PyTorch
    # releases -- confirm against the installed version before upgrading.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=500, verbose=True)
    task = Task(max_len=max_len, vocab_size=vocab_size)

    loss_hist = []
    loss = None  # stays None when the loop body never runs (steps <= 0)
    for i in range(steps):
        minibatch_x, minibatch_y = task.next_batch(batchsize=batchsize)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            minibatch_x = torch.Tensor(minibatch_x)
            minibatch_y = torch.Tensor(minibatch_y)
            out, _, _ = model(minibatch_x)
            # cross_entropy wants the class dimension second, hence the
            # transpose; the one-hot labels are turned into class indices.
            loss = F.cross_entropy(
                out.transpose(1, 2),
                minibatch_y.argmax(dim=2))
            loss.backward()
            optimizer.step()
            lr_scheduler.step(loss)
        if (i + 1) % print_every == 0:
            print("Iteration {} - Loss {}".format(i + 1, loss))
        loss_hist.append(loss.detach().numpy())

    # Only report a final loss when at least one step ran; previously this
    # line raised NameError for steps <= 0.
    if loss is not None:
        print("Iteration {} - Loss {}".format(i + 1, loss))
    print("Training complete!")
    torch.save(model.state_dict(), os.path.join(savepath, 'ckpt.pt'))
    return loss_hist


def test(max_len=10,
         vocab_size=3,
         pos_enc=True,
         use_multihead=True,
         savepath='models/',
         plot=True):
    """Load the saved checkpoint and print a prediction with its attention.

    One random sample is drawn from ``Task``, run through a
    ``MultiHeadAttentionModel`` restored from ``<savepath>/ckpt.pt``, and the
    input, the predicted class per output step, the attention weights and
    (optionally) the positional-encoding norms are printed.

    Args:
        max_len: Passed to ``Task``.
        vocab_size: Passed to ``Task``.
        pos_enc: When true, also print the L2 norm of the third model output
            along axis 2.
        use_multihead: Selects how the attention output is printed: per head
            when true, otherwise as a single attention map.
        savepath: Directory containing ``ckpt.pt``.
        plot: Accepted for interface compatibility; not used here.

    Returns:
        None. All results are printed.
    """

    def _rounded(values):
        # Round every entry to three decimals for compact printing.
        return [float("{:.3f}".format(value)) for value in values]

    model = MultiHeadAttentionModel(num_enc_layers=4)
    model.load_state_dict(torch.load(os.path.join(savepath, 'ckpt.pt')))
    task = Task(max_len=max_len, vocab_size=vocab_size)

    # Only the inputs are needed here; the labels are discarded.
    samples, _ = task.next_batch(batchsize=1)
    print("\nInput: \n{}".format(task.prettify(samples)))
    model.eval()
    with torch.no_grad():
        predictions, attention, input_pos_enc = model(torch.Tensor(samples))
    predictions = predictions.detach().numpy()
    predictions = predictions.argmax(axis=2)

    print("\nPrediction: \n{}".format(predictions))
    print("\nEncoder-Decoder Attention: ")
    if use_multihead:
        # attention[0] is iterated as heads -> output steps -> input weights.
        for h, head in enumerate(attention[0]):
            print("Head #{}".format(h))
            for i, output_step in enumerate(head):
                print("\tAttention of Output step {} to Input steps".format(i))
                print("\t{}".format(_rounded(output_step)))
    else:
        for i, output_step in enumerate(attention[0]):
            print("Output step {} attended mainly to Input steps: {}".format(i, np.where(output_step >= np.max(output_step))[0]))
            print(_rounded(output_step))
    if pos_enc:
        input_pos_enc = input_pos_enc.detach().numpy()
        print("\nL2-Norm of Input Positional Encoding:")
        print(_rounded(np.linalg.norm(input_pos_enc, ord=2, axis=2)[0]))


In [12]:
# Train with the notebook-local `train` defined above. Its return value (the
# per-step loss history) is not assigned, so the notebook echoes the whole
# list as the cell output.
train(print_every=50, use_multihead=True, steps=4000)

Iteration 50 - Loss 1.7648378610610962
Iteration 100 - Loss 1.7817336320877075
Iteration 150 - Loss 1.7836177349090576
Iteration 200 - Loss 1.8010951280593872
Iteration 250 - Loss 1.6743237972259521
Iteration 300 - Loss 1.569726824760437
Iteration 350 - Loss 1.4176099300384521
Iteration 400 - Loss 1.1010675430297852
Iteration 450 - Loss 1.1300634145736694
Iteration 500 - Loss 0.8104339838027954
Iteration 550 - Loss 0.7003355622291565
Iteration 600 - Loss 0.6669805645942688
Iteration 650 - Loss 0.6036309003829956
Iteration 700 - Loss 0.6167134046554565
Iteration 750 - Loss 0.632865846157074
Iteration 800 - Loss 0.623589277267456
Iteration 850 - Loss 0.6303231120109558
Iteration 900 - Loss 0.5455670952796936
Iteration 950 - Loss 0.57892906665802
Iteration 1000 - Loss 0.553711473941803
Iteration 1050 - Loss 0.5698127746582031
Iteration 1100 - Loss 0.45943984389305115
Iteration 1150 - Loss 0.27888017892837524
Iteration 1200 - Loss 0.06676359474658966
Iteration 1250 - Loss 0.007871564477682

[array(2.3895428, dtype=float32),
 array(2.0231292, dtype=float32),
 array(1.921224, dtype=float32),
 array(1.980945, dtype=float32),
 array(1.936603, dtype=float32),
 array(1.880893, dtype=float32),
 array(1.9955623, dtype=float32),
 array(1.8550473, dtype=float32),
 array(2.0106003, dtype=float32),
 array(1.8659568, dtype=float32),
 array(1.9117118, dtype=float32),
 array(1.8294381, dtype=float32),
 array(1.8551674, dtype=float32),
 array(1.9098252, dtype=float32),
 array(1.88913, dtype=float32),
 array(1.7563269, dtype=float32),
 array(1.850315, dtype=float32),
 array(1.8064514, dtype=float32),
 array(1.8102348, dtype=float32),
 array(1.8777775, dtype=float32),
 array(1.796097, dtype=float32),
 array(1.8167, dtype=float32),
 array(1.7623725, dtype=float32),
 array(1.7742174, dtype=float32),
 array(1.833455, dtype=float32),
 array(1.8132381, dtype=float32),
 array(1.8387612, dtype=float32),
 array(1.8730369, dtype=float32),
 array(1.7420616, dtype=float32),
 array(1.8201734, dtype=fl

In [None]:
# Reload the checkpoint from the default `models/` directory and print one
# sample's prediction, attention weights and positional-encoding norms.
test()


Input: 
[[['A']
  ['B']
  ['B']
  ['C']
  ['C']
  ['C']
  ['A']
  ['B']
  ['B']
  ['B']
  ['C']
  ['B']
  ['A']]]

Prediction: 
[[2 4 4]]

Encoder-Decoder Attention: 
Head #0
	Attention of Output step 0 to Input steps
	[0.051, 0.076, 0.088, 0.077, 0.077, 0.077, 0.072, 0.083, 0.084, 0.082, 0.077, 0.082, 0.072]
	Attention of Output step 1 to Input steps
	[0.11, 0.079, 0.071, 0.073, 0.073, 0.073, 0.077, 0.074, 0.073, 0.074, 0.073, 0.074, 0.077]
	Attention of Output step 2 to Input steps
	[0.066, 0.072, 0.075, 0.08, 0.08, 0.08, 0.083, 0.075, 0.075, 0.075, 0.08, 0.075, 0.083]

L2-Norm of Input Positional Encoding:
[1.956, 1.425, 2.293, 0.401, 0.383, 0.394, 0.387, 0.402, 0.404, 0.387, 0.4, 0.386, 0.391]
