In [None]:
# Encoder/decoder transformer, like in the paper

# Single Encoder block

# Embed input into a length 512 vector
# Add positional encoding to the input
# Alternating sin and cosine functions across the embedding dimension (512)
# i = position in the embedding dimension
# PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Sum positional encoding and input
# apply dropout with chance .1

# Self-attention
# Each token in the input sequence is a query
# Every token in the sequence (including the query token itself) is a key
# Project query and key into a new dimensional space with a linear transform (this is multi-head attention)
# We do the following 16 times for multi-head attention
# Dot the projected query with the projected key matrix
# QK
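# Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
# d_k is the key dimension (the per-head size of the projected queries and keys); dividing by sqrt(d_k) keeps the dot products from growing with that size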
# Keys and values are the same in this particular case
# Basically, compute an attention score (scalar) between each query and each key, then scale the value by that number
# So less relevant other tokens are minimized
# Concat the results of the attention equation
# Run through another linear layer to reproject
# Each attention head outputs 1/16th of the input embedding len

# After doing multi-head attention (16)
# Apply dropout with chance .1
# Add the input to the layer (original embedded sequences) and the output of attention
# Run layer normalization (unclear which layer norm to use)

# Run a feed forward network
# Add input to the layer to the output of the ff network
# Normalize again

# Single decoder block

# Shift outputs right, to start with start token
# Do positional encoding
# Do dropout with chance .1
# Mask outputs, so queries can only see keys that came before the query
# Run multi-head attention
# Add and norm

# Run multi-head attention again, but this time v,k is from encoder stack, and q is from decoder stack
# Apply dropout with chance .1
# When doing add and norm, add in the decoder stack input

# Feed forward

# At top of stack, do another linear layer and softmax

# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf
# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf

In [53]:
import numpy as np
import torch
from torch import nn
import functorch
import sys
import os
import math
# Make the shared data helpers two directories up importable from this notebook.
sys.path.append(os.path.abspath("../../data"))

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 1  # batch size handed to the dataset generator
SP_VOCAB_SIZE = 1000  # vocabulary size handed to the dataset wrapper
TRAIN_SIZE = 500  # size of the training split

In [54]:
from text_data import CNNDatasetWrapper

class Wrapper(CNNDatasetWrapper):
    """CNNDatasetWrapper configured for a small, short-sequence experiment."""

    # Three split sizes; the second is 10% of TRAIN_SIZE, rounded down.
    # NOTE(review): presumably ordered train / validation / test — confirm
    # against CNNDatasetWrapper.
    split_lengths = [TRAIN_SIZE, math.floor(0.1 * TRAIN_SIZE), 100]
    # Lengths for the input and target sequences.
    # NOTE(review): presumably measured in tokens — confirm against CNNDatasetWrapper.
    x_length = 15
    target_length = 15

# Build the dataset wrapper for the configured vocabulary size and device.
wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)

# Generate the splits once, then keep handles to the two used in this notebook.
datasets = wrapper.generate_datasets(BATCH_SIZE)
train, valid = datasets["train"], datasets["validation"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 42.60it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn_dailymail
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piec

In [55]:
# Take the first training example as a fixed sample for the cells below.
# NOTE(review): the three fields look like (input tokens, target tokens,
# right-shifted target tokens) — confirm against CNNDatasetWrapper.
x, y, prev_y = train.dataset[0]

In [56]:
# Lookup table mapping each token id to a 512-dimensional vector.
embed = nn.Embedding(num_embeddings=wrapper.vocab_size, embedding_dim=512)

In [82]:
def positional_encodings(seq_len, embed_len):
    """Build the sinusoidal positional-encoding table.

    PE(pos, 2i)   = sin(pos / 10000^(2i / embed_len))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / embed_len))

    Each (even, odd) pair of embedding dimensions shares one frequency, so a
    position is encoded as a sin/cos pair per frequency.

    Args:
        seq_len: Number of positions (rows) to encode.
        embed_len: Embedding width (columns).

    Returns:
        A float32 numpy array of shape (seq_len, embed_len).
    """
    # Column vector of positions so it broadcasts against the per-dimension frequencies.
    positions = np.arange(seq_len, dtype=np.float64)[:, np.newaxis]
    dims = np.arange(embed_len)
    # Round every odd index down to its even partner: dims 2i and 2i+1 both use 2i.
    exponents = (dims - dims % 2) / embed_len
    angles = positions / np.power(10000.0, exponents)

    encodings = np.zeros((seq_len, embed_len), dtype=np.float32)
    encodings[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions
    encodings[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions
    return encodings

In [83]:
# Embed the sample tokens, then add one positional-encoding row per token.
x_embed = embed(x)
pos_encoding = torch.from_numpy(positional_encodings(x_embed.size(0), 512))
network_in = torch.add(x_embed, pos_encoding)

In [219]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over a single (unbatched) sequence.

    Every head owns a query, key and value projection from ``input_units`` down
    to ``input_units // attention_heads`` features. The per-head results are
    concatenated along the feature axis.

    NOTE: the concatenated heads are returned as-is; no final output projection
    is applied in this module.

    Args:
        input_units: Feature size of every query/context token.
        attention_heads: Number of heads; must divide ``input_units`` evenly.
        mask: When True, apply a causal mask so the query at position ``i``
            only attends to context positions ``<= i``.

    Raises:
        ValueError: If ``input_units`` is not divisible by ``attention_heads``.
    """

    def __init__(self, input_units, attention_heads, mask=False):
        super().__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.head_units = input_units // attention_heads
        self.mask = mask
        if self.head_units * self.attention_heads != self.input_units:
            raise ValueError("Invalid input units and heads combo")

        # Uniform init in [-k, k] with k = sqrt(1 / head_units).
        k = math.sqrt(1 / self.head_units)
        self.query_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)
        self.key_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)
        self.value_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)

    def forward(self, queries, context):
        """Attend from ``queries`` over ``context``.

        Args:
            queries: Tensor of shape (query_len, input_units).
            context: Tensor of shape (context_len, input_units); keys and values
                are both projected from it.

        Returns:
            Tensor of shape (query_len, input_units) holding the concatenated
            per-head weighted values.
        """
        # (len, input) @ (heads, input, head) broadcasts to (heads, len, head),
        # so the input does not need to be copied once per head.
        projected_queries = torch.matmul(queries, self.query_weights)
        keys = torch.matmul(context, self.key_weights)
        values = torch.matmul(context, self.value_weights)

        # Scaled dot-product scores, shape (heads, query_len, context_len).
        scores = torch.bmm(projected_queries, keys.transpose(1, 2)) / math.sqrt(self.head_units)
        if self.mask:
            # Hide keys that come after the query. The fill value must be -inf
            # (not 0) because the softmax below turns -inf into exactly zero weight.
            query_len, context_len = scores.shape[-2:]
            query_pos = torch.arange(query_len, device=scores.device).unsqueeze(1)
            key_pos = torch.arange(context_len, device=scores.device)
            scores = scores.masked_fill(key_pos > query_pos, float("-inf"))

        # Softmax over the key axis so each query's weights sum to 1.
        attention = torch.softmax(scores, dim=2)
        weighted_values = torch.bmm(attention, values)
        # (heads, query_len, head) -> (query_len, heads, head) -> (query_len, heads * head).
        # The length comes from the result itself, never from a notebook-level variable.
        return weighted_values.transpose(0, 1).reshape(weighted_values.shape[1], -1)

In [220]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then a two-layer feed-forward network.

    Both sub-layers are wrapped in dropout and a residual connection. Layer
    normalization is not applied yet (see the TODOs in ``forward``).

    Args:
        input_units: Feature size of the block's input and output.
        attention_heads: Number of heads for the self-attention sub-layer.
        hidden_units: Width of the feed-forward network's inner layer.
    """

    def __init__(self, input_units, attention_heads, hidden_units=2048):
        super().__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.mha = MultiHeadAttention(input_units, attention_heads)
        self.dropout = nn.Dropout(p=0.1)
        self.linear1 = nn.Linear(input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, input_units)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        # Sub-layer 1: self-attention (x supplies queries, keys and values) + residual.
        attended = self.mha(x, x)
        attn_output = x + self.dropout(attended)
        # TODO: Add layer normalization on top of attn_output

        # Sub-layer 2: position-wise feed-forward + residual.
        hidden = self.relu(self.linear1(attn_output))
        feed_forward = self.dropout(self.linear2(hidden))
        # TODO: Add layer normalization on top of the block output
        return attn_output + feed_forward

In [221]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, encoder attention, feed-forward.

    Each of the three sub-layers is wrapped in dropout and a residual
    connection around that sub-layer's own input. Layer normalization is not
    applied yet (see the TODOs in ``forward``).

    Args:
        input_units: Feature size of the block's input and output.
        attention_heads: Number of heads for both attention sub-layers.
        hidden_units: Width of the feed-forward network's inner layer.
    """

    def __init__(self, input_units, attention_heads, hidden_units=2048):
        super().__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.in_attn = MultiHeadAttention(self.input_units, self.attention_heads, mask=True)
        self.context_attn = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropout = nn.Dropout(.1)
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()

    def forward(self, queries, context):
        """Run the block.

        Args:
            queries: Decoder-side input sequence for this block.
            context: Encoder output; supplies the keys and values of the second
                attention sub-layer.

        Returns:
            The block output, same shape as ``queries``.
        """
        # Sub-layer 1: masked self-attention. The residual adds this block's own
        # input, not any notebook-level tensor.
        self_attended = self.dropout(self.in_attn(queries, queries))
        attn_output = queries + self_attended
        # TODO: Add layer normalization on top of attn_output

        # Sub-layer 2: the output of sub-layer 1 queries the encoder context.
        cross_attended = self.dropout(self.context_attn(attn_output, context))
        decoder_output = attn_output + cross_attended
        # TODO: Add layer normalization

        # Sub-layer 3: feed-forward. The residual is taken around decoder_output
        # so the encoder-attention result stays on the skip path.
        reprojected = self.dropout(self.linear2(self.relu(self.linear1(decoder_output))))
        block_output = decoder_output + reprojected
        # TODO: Add layer normalization on top of block_output
        return block_output

In [222]:
class Transformer(nn.Module):
    """Encoder stack: token embedding + sinusoidal positions + encoder blocks.

    Args:
        input_units: Number of rows in the token embedding table (the caller in
            this notebook passes the vocabulary size).
        hidden_units: Embedding width, also the feature size of every block.
        attention_heads: Number of attention heads in each encoder block.
        blocks: Number of stacked encoder blocks.
    """

    def __init__(self, input_units, hidden_units, attention_heads, blocks=2):
        super().__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.attention_heads = attention_heads
        self.blocks = blocks

        self.input_embedding = nn.Embedding(input_units, hidden_units)
        self.dropout = nn.Dropout(.1)
        self.encoders = nn.ModuleList(EncoderBlock(hidden_units, attention_heads) for _ in range(self.blocks))

    def forward(self, x):
        """Encode a sequence of token ids.

        Args:
            x: Integer tensor of token ids for one sequence.

        Returns:
            The output of the last encoder block, one ``hidden_units``-wide row
            per input token.
        """
        embedded = self.input_embedding(x)
        # Size the positional table from this input's own length so the module
        # works for any sequence, independent of notebook-level variables.
        positions = self.encoding(embedded.shape[-2], self.hidden_units)
        pos_encoding = torch.from_numpy(positions).to(embedded.device)
        block_output = self.dropout(embedded + pos_encoding)

        # Only the most recent output is needed, so thread it through the stack
        # instead of accumulating every intermediate result.
        for encoder in self.encoders:
            block_output = encoder(block_output)
        return block_output

    def encoding(self, seq_len, embed_len):
        """Build the sinusoidal positional-encoding table.

        PE(pos, 2i)   = sin(pos / 10000^(2i / embed_len))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / embed_len))

        Args:
            seq_len: Number of positions (rows) to encode.
            embed_len: Embedding width (columns).

        Returns:
            A float32 numpy array of shape (seq_len, embed_len).
        """
        positions = np.arange(seq_len, dtype=np.float64)[:, np.newaxis]
        dims = np.arange(embed_len)
        # Round every odd index down to its even partner: dims 2i and 2i+1 both use 2i.
        exponents = (dims - dims % 2) / embed_len
        angles = positions / np.power(10000.0, exponents)

        encodings = np.zeros((seq_len, embed_len), dtype=np.float32)
        encodings[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions
        encodings[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions
        return encodings

In [223]:
# Smoke test: build a 512-wide, 8-head model and run the sample sequence through it.
tf = Transformer(input_units=wrapper.vocab_size, hidden_units=512, attention_heads=8)
tf(x)

tensor([[ 5.2130, -2.7786, -4.7501,  ...,  0.9715,  8.7992,  1.6567],
        [ 5.4114, -0.0718, -1.0579,  ...,  5.5354, -0.7851,  5.4303],
        [12.2055,  2.7835, -4.2664,  ..., -0.0511,  4.1057,  0.2862],
        ...,
        [ 4.1110, -0.3995, -1.9013,  ...,  0.3918, -2.6225,  4.9806],
        [15.6233, -3.2491, -7.8968,  ..., -0.4187,  7.9496,  3.3613],
        [17.4059,  1.3717, -8.8419,  ...,  0.1119,  8.1356,  2.6065]],
       grad_fn=<SelectBackward0>)

In [110]:
# Scratch: a single attention head written out by hand.
# Each projection maps the 512-wide input down to 32 features.
query_embed = nn.Linear(in_features=512, out_features=32)
key_embed = nn.Linear(in_features=512, out_features=32)
values_embed = nn.Linear(in_features=512, out_features=32)

queries, keys, values = query_embed(network_in), key_embed(network_in), values_embed(network_in)

# Scaled dot-product scores; the softmax over dim=1 makes each query's row sum to 1.
attention = torch.softmax(torch.matmul(queries, keys.T) / np.sqrt(keys.shape[1]), dim=1)
# Collapse the query axis into a single row holding the total weight each key received.
summed_attention = attention.sum(dim=0).view(1, -1)
weighted_values = torch.matmul(summed_attention, values)

In [159]:
# Scratch: batched multi-head projection with 8 heads of 64 features each.
query_weights = nn.Parameter(torch.rand(8, 512, 64))

# One copy of the input per head so it lines up with the per-head weights in bmm.
expanded_x = network_in.repeat(8, 1, 1)
# NOTE(review): queries, keys and values are all projected with query_weights here.
queries = torch.bmm(expanded_x, query_weights)
keys = torch.bmm(expanded_x, query_weights)
values = torch.bmm(expanded_x, query_weights)

# Softmax over the key axis (dim=2), so each query's attention weights sum to 1.
attention = torch.softmax(
    torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(64),
    dim=2,
)