## Imports

In [3]:
import nltk
assert(nltk.download('wordnet'))
from nltk.corpus import wordnet as wn
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as a
import time
import string

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\benak\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Data Preprocessing

Our training data will come from Wordnet, a large lexical database of english words and accompanying definitions. We use the natural language toolkit to interact conveniently with wordnet via `nltk.corpus.wordnet`.
- Source: https://wordnet.princeton.edu/

In [6]:
# yield sample definitions from word "weights"
for i,s in enumerate(wn.synsets('weights')):
    print(f'{i}: ({s.pos()}) {s.definition()}')

0: (n) the vertical force exerted by a mass as a result of gravity
1: (n) sports equipment used in calisthenic exercises and weightlifting; it is not attached to anything and is raised and lowered by use of the hands and arms
2: (n) the relative importance granted to something
3: (n) an artifact that is heavy
4: (n) an oppressive feeling of heavy force
5: (n) a system of units used to express the weight of something
6: (n) a unit used to measure weight
7: (n) (statistics) a coefficient assigned to elements of a frequency distribution in order to represent their relative importance
8: (v) weight down with a load
9: (v) present with a bias


In [7]:
# read in sample words (from Kaggle dataset: https://www.kaggle.com/datasets/ruchi798/part-of-speech-tagging/data)
wsample = []
with open('word_sample.txt', 'r') as file:
  for line in file.read().splitlines():
    wsample.append(line)

# definitions for all words in wsample
definitions = [s.definition() for w in wsample for s in wn.synsets(w)]

# remove punctuations from definitions and append ' . ' sentence-end token
trim_definitions = [''.join(d).translate(str.maketrans('', '', string.punctuation)) + ' . ' for d in definitions]
trim_defstring = ''.join(trim_definitions)    # join punctuation-free definitions into a string

# create vocab list of all individual words that appear in the sample set of definitions
vocab = list(set(sorted(trim_defstring.split())))

print(f'Definitions in sample: {len(definitions):,.0f}')
print(f'Distinct words in sampled definitions: {len(trim_defstring.split()):,.0f}')
print(f'Unique words in sampled definitions: {len(vocab):,.0f}')

Definitions in sample: 244,577
Distinct words in sampled definitions: 2,343,437
Unique words in sampled definitions: 41,151


Now we remove rare words from the definitions. These are defined as words that appear only once. They tend to be overly specific pronouns or dates.

In [8]:
counts = [trim_defstring.count(word) for word in vocab]
rare_words = {vocab[i] for i,c in enumerate(counts) if c < 2}

# remove rare words from the definitions
more_trim_defs= []
for i,d in enumerate(trim_definitions):
    if len(set(d.split()) & rare_words) == 0:
        more_trim_defs.append(d)

# remove rare words from vocab
trim_vocab = list(set(vocab) - set(rare_words))

In [9]:
len(trim_vocab), len(more_trim_defs), len(''.join(more_trim_defs).split(' '))

(34601, 239152, 2279897)

In summary we have:
- 34,601 unique words in the dataset after removing one-offs
- 239,152 definitions
- 2,279,897 distinct words in the definitions

### Prepare Training and Validation Datasets

Prepare encoding and decoding dictionaries:

In [10]:
end_char = '.'
start_char = '<s>'
pad_char = '<p>'

stoi = {s:i+1 for i,s in enumerate(trim_vocab)}    # word-to-integer mapping dictionary
stoi[end_char] = len(stoi) + 1                     # adding end character
stoi[start_char] = len(stoi) + 2                   # adding start character
stoi[pad_char] = 0                                 # adding pad character
itos = {i:s for s,i in stoi.items()}               # integer-to-word mapping dictionary

encoder = lambda s: [stoi[c] for c in s]            # encoder
decoder = lambda l: ' '.join([itos[i] for i in l])  # decoder

For now, we will naively encode the entire dataset into one tensor object.

In [11]:
data = [encoder(d.split()) for d in more_trim_defs]
max_length = max([len(d) for d in data])
xdat = [encoder([start_char]) + d[:-1] for d in data]
ydat = [d for d in data]

# right pad all definitions to max length
for d in xdat: d += [0] * (max_length - len(d) + 1)
for d in ydat: d += [0] * (max_length - len(d) + 1)

xdat = torch.tensor(xdat)
ydat = torch.tensor(ydat)

In [12]:
n = int(0.8*len(data))
Xt, Yt = xdat[:n], ydat[:n]     # 80% training data
Xv, Yv = xdat[n:], ydat[n:]     # 20% validation data

## Initialize Hyperparameters

Set seed for reproducibility:

In [13]:
torch.manual_seed(42);  # stay calm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = len(stoi) + 1
batch_size = 128
block_size = max_length + 1
n_emb = 256           # embedding dimesions
n_head = 8            # number of heads per multihead stack (head_size = n_emb // n_head)
n_blocks = 8          # number of decoder blocks
dropout = 0.2         # probability of zeroing-out neuron in dropouts

print(device)

cuda


## Batch Function

In [14]:
def minibatch(xdat, ydat):
    # 1D tensor of random ints between bls and len(xdat) of length bts
    idx = torch.randint(len(xdat) - block_size, (batch_size,))

    # index into x and y tensors using random ints
    x, y = xdat[idx], ydat[idx]
    return x.to(device), y.to(device)

xb, yb = minibatch(Xt, Yt)
xb.shape, yb.shape

(torch.Size([128, 71]), torch.Size([128, 71]))

## Self-Attention

### Breaking Down Self-Attention

Self attention is the crux of this whole thing. So, I want to investigate it in more detail before putting together the model's self-attention head.

Attention may be thought of as simply producing a weighted average of previous tokens in the training data. "Previous" meaning the previous tokens within the context length before each token in the training data. The weights applied to the embedded value of each previous token should be trainable *and* should be informed by the current token's embedded value. For example, if the current token is a noun then we may want the weights applied to preceding adjective to be greater than the weights appliead to preceding nouns. However, if the token is a verb then the weights applied to previous adjectives should be low while the weights applied to previous nouns should be higher.

Let's first look at how we can compute these rolling weighted averages over a certain context length in a vectorized format:

In [146]:
B, T, C = 4, 8, 32  # batch, time (context length), channels

weights = torch.ones(T, T)     # initialize tensor of ones (T, T)
weights = torch.tril(weights)  # zero-out the upper triangle of the weights matrix
weights = weights.masked_fill(weights == 0, float('-inf'))    # fills all 0 elements in tril with '-inf'
weights = F.softmax(weights, dim=1)    # softmax transform  over the rows of weights

weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

Through matrix operations and a softmax transformation we've produced a matrix of uniform weights. Note that as each row sums to one, they may be used as weightsghted averages. This weights matrix would apply uniform weights across previous tokens up to eight tokens in the past. Note also that the zeros in the upper triangle ensure that *future tokens* do not influence the weights for the current token.

Now, we need to provide a means of producing *trainable* non-uniform weights so that different tokens within the context can provide different weights to the current token. This is achieved by implementing two new matrices: a *key* and a *query* matrix. These will be initialized a linear layers through pytorch so that we can backpropogate through them to tune their weights. A token's query interacts with all the other keys through a dot-product. This dot-product replaces the simple rolling average that we constructed before (although it is the same in form and similar in function).

This allows us to train the queries to respond differently to different keys. Certain 'alignments' between a query and a key will be promoted in training, leading to improved conditional predictions.

In [170]:
x = torch.randn(B, T, C)  # initialize some random input data

# single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, 16)
q = query(x)  # (B, T, 16)

weights = q @ k.transpose(-2, -1)  # transpose of the last two dimensions of k (B, T, 16) --> (B, 16, T) --> q @ k (B, T, T)
weights = weights * head_size**-0.5
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))   # zero-out future positions
weights = F.softmax(weights, dim=-1)    # softmax over the last dimension of weights

weights[0]


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6431, 0.3569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3625, 0.1669, 0.4706, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2802, 0.2541, 0.2505, 0.2153, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1296, 0.3602, 0.1551, 0.1921, 0.1631, 0.0000, 0.0000, 0.0000],
        [0.1554, 0.2850, 0.0848, 0.1227, 0.1966, 0.1555, 0.0000, 0.0000],
        [0.0807, 0.1930, 0.2052, 0.0929, 0.1514, 0.1391, 0.1377, 0.0000],
        [0.1535, 0.0653, 0.2159, 0.1004, 0.0793, 0.2435, 0.0539, 0.0883]],
       grad_fn=<SelectBackward0>)

In [171]:
(weights @ x)[0][0]

tensor([-0.6608,  0.4298,  0.1429,  1.3202,  1.0733,  1.0446,  0.9047, -0.2830,
        -0.0591,  1.3761,  0.9947, -1.7909, -0.2628,  1.9571,  0.0146, -0.3971,
        -0.2107,  0.9406,  0.3464, -1.3773, -0.8786, -0.3244,  0.4048,  0.1354,
         0.3485,  0.0094,  1.5291, -0.6088, -0.9659,  1.5619, -0.5781,  0.9003],
       grad_fn=<SelectBackward0>)

Typically `x` will not be multiplied with the weights matrix directly, but through a *value* linear layer that is unique to the specific attention head. One benefit of this is that it aligns the shape of the weighted attention output with the head size rather than relying on the shape of the input data. E.g.: 

In [172]:
value = nn.Linear(C, head_size, bias=False)

print('Pre linear layer: ', (weights @ x).shape)
print('Post linear layer: ', (weights @ value(x)).shape)

Pre linear layer:  torch.Size([4, 8, 32])
Post linear layer:  torch.Size([4, 8, 16])


Mathematically, the value layer is more significant. The value matrix applies another set of trainable weights to the input data such that the resulting matrix of $xV$ has contextual information from every other element in $x$ for each elementy in $x$. Masking the upper triangle with zeros via the weights matrix removes the contextual information from future elements. For example, this enables more meaningful weights to be applied not just to a noun like "wood" but to a noun in the context of its sentence, e.g. "mossy dark wood". This context is the vectorized representation of wood with values that encode information about the words (tokens) that precede it.

Finally, there is a simple standardization step that is included in the *"Attention is all You Need"* (2017) paper. That is, to scale the weights by the inverse of the square of the head size prior to applying the softmax transformation. From the paper:
$$\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$$ 
Where $Q$ is the query matrix, $K$ is the key matrix, $V$ is the value matrix, and $d_k$ is the head size.

This is done to standardize the weights before applying the softmax, promoting stability. Let's look briefly at the variance of the weights before and after this transformation for `k` and `q` both with unit variance:

In [175]:
# k = torch.randn(B, T, head_size)
# q = torch.randn(B, T, head_size)

print('k var', k.var().item())
print('q var', q.var().item())

weights = q @ k.transpose(-2, -1)
print(weights.var().item())

print((weights * head_size**-0.5).var())

k var 1.0249735116958618
q var 0.9907991290092468
16.46785545349121
1.0292409658432007


Basically, the $\frac{1}{\sqrt{d_k}}$ modification ensures that the scale of the variance of the weights matrix is on the same order of magnitude as the scale of the variances for the query and key matrices. This is important for the softmax transformation because the softmax will exaggerate the probabilities associated with large positive values. At initialization, if we happen to have very large positive values, then a feedback-loop may be created wherein the softmax transformed weights matrix converges to a one-hot-encoding.

### Further Notes on Attention

**Attention as a Communication Mechanism** \
All of these tokens are like nodes in a directed acyclic graph. Each node aggregates the information from every other node that points towards it (not the future nodes) through a weightsghted sum with data-dependent weights.

**On Spatial Information** \
The data-dependent weightsghted averages that drive this attention mechanism provide no information about the position of a node (token). The value of the weightsghted average doesn't change if we change the ordering of the values and their weights. This is why we encode position information into a positional embedding space.

**On Batch Dependence** \
Elements in each batch sample are completely independent of elements in other batch samples.

**"Encoder" Attention Block** \
Encoder attention blocks remove the constraints on temporal information sharing (the tril() statement) and allow all nodes to talk to eachother. In this case the graph could be cyclic.

The attention block that we have created is called a "Decoder" because it has the triangular masking. Decoder attention blocks are widely seen in autoregressive applications like this one.

**"Self Attention" vs "Cross Attention"** \
Self attention just means that the keys, queries, and values all come from the same source - `x`. Cross attention provides keys and values from different sources, like encoder blocks.

### Self Attention Head

Finally, the single self-attention head used in the model:

In [15]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))  # (T, T)
            # torch.tril() is not a parameter, so we have to use register_buffer to assign it to the module

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        
        # compute attention scores ("affinities")
        weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5   # (B, T, T); scaled by 1/sqrt(hs)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)        
        
        # aggregate values by weights
        v = self.value(x)
        out = weights @ v
        return out    

## Multi-Headed Attention

*Note:* This summary involves a lot of discussion about embedding and embedding spaces without explaining those concepts. It's really more for organizing my thoughts than anything and isn't meant to be a rigorous explanation.

Let's summarize some things first. Within an attention head *i* we have a distinct pair of key $K^{(i)}$ and query $Q^{(i)}$ matrices. Which produce a distinct attention pattern, a weights matrix $W_{K,Q}^{(i)}$. This attention matrix is multiplied by a sequence of contextualized value vectors unique to the head, $V^{(i)}$. The resulting matrix, $W_V^{(i)}$ may be thought of as a set of column vectors. Each column vector applies a transformation to its corresponding column in the embedded input data, moving it in the embedding space. 

Now, when we parrallelize this process across mulitple attention heads, we yield multiple transformation matrices. These transformation matrices may move the input data in different ways. For example, for a vector of the token "ball", one transformation may point in directions of "sports", "activities", "toys", and so on, while another points in directions of "events", "dining", "galas"... The multiple heads can attenuate on different contextual meanings of tokens more specifically than one head alone could. In as sense, we're splitting our attention. ;)

Essentially, we are enabling the model to better learn the many distinct ways in which context may change meaning.

The transformations from the multiple heads can then be aggregated to yield final predictions. In the paper, the results of each attention head are simply concatenated together.

### Multi-Head Self-Attention Stack

In [16]:
class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for h in range(num_heads)])    # create list of heads
        self.proj = nn.Linear(head_size * num_heads, n_emb)    # linear transformation of the output from the head stack
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)    # feed forward through heads and concatenate output
        out = self.dropout(self.proj(out))                     # pass output through linear layer and apply dropout
        return out

## Feed-Forward Layer

After passing through the multi-headed attention stack we will feed forward through a fairly simple MLP. This replicates the MLP architecture in the *"Attention is all You Need"* (2017) paper.

In [17]:
class FeedForward(nn.Module):
    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),     # mult 4 bc the paper does a 4x channel expansion in the feedforward
            nn.ReLU(),
            nn.Linear(4 * n_emb, n_emb),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

## Decoder Block

Now, we'll wrap up the multi-head attention stack and the feed-forward layer into a decoder block. This blocks can be conveniently chained together to augment the model's length. Each block is fundamentally just the masked multi-headed attention stack and the feed-forward layer from the paper.

In [18]:
class Block(nn.Module):
    def __init__(self, n_emb, n_head):
        super().__init__()
        head_size = n_emb // n_head
        self.sa = MultiHead(n_head, head_size)  # self-attention stack
        self.ffwd = FeedForward(n_emb)          # feed-forward layer
        self.ln1 = nn.LayerNorm(n_emb)   # layer normalization for self-attention stack
        self.ln2 = nn.LayerNorm(n_emb)   # layer normalization for feed forward
        
    def forward(self, x):
        # x = self.ln1(x + self.sa(x))     # residual self-attention stack connection
        # x = self.ln2(x + self.ffwd(x))   # residual feed-forward connection
        x = x + self.sa(self.ln1(x))     # residual self-attention stack connection
        x = x + self.ffwd(self.ln2(x))   # residual feed-forward connection
        return x

## Full Model

Putting it all together:
Note: The paper applies dropout to the embeddings before forwarding through the model; I ignore that for now.

In [19]:
class DabbleBot(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb)      # token embedding
        self.position_embedding_table = nn.Embedding(block_size, n_emb)   # positional embedding
        self.blocks = nn.Sequential(*[Block(n_emb, n_head=n_head) for b in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_emb)                # final layer norm
        self.lm_head = nn.Linear(n_emb, vocab_size)    # output linear layer
        
        
    def forward(self, input, targets=None):
        B, T = input.shape
        
        # idx and targets are both (B, T) tensor of integers                    dimension tracking:
        tok_emb = self.token_embedding_table(input)                               # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))   # (T,C)
        x = tok_emb + pos_emb                                                     # (B,T,C)
        x = self.blocks(x)                                                        # (B,T,C)
        x = self.ln_f(x)                                                          # (B,T,C)
        logits = self.lm_head(x)                                                  # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets, ignore_index=0)    # ignore index of pad_char
            
        return logits, loss
    
    
    def generate(self, idx, samples):    # idx is (B, T) array of indices in the current context
        
        self.eval()
        sample = []
        
        for s in range(samples):
            ctx = idx
            
            while True:
                ctx_cond = ctx[:, -block_size:]
                logits, loss = self(ctx_cond)

                # focus only on the last time step
                logits = logits[:, -1, :] # becomes (B, C)

                probs = F.softmax(logits, dim=-1) # (B, C)
                ctx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

                # append sampled index to the running sequence
                ctx = torch.cat((ctx, ctx_next), dim=1) # (B, T+1)
                
                if ctx_next.item() == stoi[end_char] or ctx.shape[1] > 50:
                    break
            sample.append(decoder(ctx.tolist()[0]))
        
        self.train()
        
        return sample

### Evaluation Function

In [20]:
@torch.no_grad()
def estimate_loss(model, iters):
    out = {}
    data = {'train': (Xt, Yt), 'val': (Xv, Yv)}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(iters)
        for k in range(iters):
            X, Y = minibatch(*data[split])
            logits, loss = model.forward(input=X, targets=Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

## Model Training

### Initializations

In [22]:
torch.manual_seed(42)

model = DabbleBot()
model = model.to(device)

learning_rate = 1e-3
max_iters = 10000
eval_iters = 500
tloss, vloss = [], []

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # paper uses Adam optimizer

### Parameters

In [24]:
param_count = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        # print(name, param.data.shape.numel())
        param_count += param.data.shape.numel()
f"{param_count:,.0f} total parameters"

'24,082,476 total parameters'

### Training Loop

In [396]:
for iter in range(max_iters):
    start = time.time()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_iters == 0:
        
        losses = estimate_loss(model, eval_iters)
        tloss.append(losses['train'])
        vloss.append(losses['val'])
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} \
              | ETA: {run_time / 60 * (max_iters-iter):.2f} min")

    # sample a batch of data
    xb, yb = minibatch(Xt, Yt)

    # evaluate the loss
    logits, loss = model.forward(input=xb, targets=yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    run_time = time.time() - start

check in
losses checked
waaaaa
step 0: train loss 10.5989, val loss 10.5344               | ETA: 558.79 min
check in
losses checked
waaaaa
step 500: train loss 4.9537, val loss 5.7206               | ETA: 544.95 min


KeyboardInterrupt: 