In [1]:
import nltk
from nltk.corpus import wordnet as wn
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as a
import time
import string
import random

In [2]:
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\benak\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
words = open('C:\\Users\\benak\\Documents\\More Documents\\words.txt', 'r').read().splitlines()
words[:10]

['aardvark',
 'aardwolf',
 'aaron',
 'aback',
 'abacus',
 'abaft',
 'abalone',
 'abandon',
 'abandoned',
 'abandonment']

In [5]:
# sample definitions for the word: weight
syns = wn.synsets('weight')
num = 0
for s in syns:
    num += 1
    print(f'{num}: ', f'({s.pos()})', s.definition())

1:  (n) the vertical force exerted by a mass as a result of gravity
2:  (n) sports equipment used in calisthenic exercises and weightlifting; it is not attached to anything and is raised and lowered by use of the hands and arms
3:  (n) the relative importance granted to something
4:  (n) an artifact that is heavy
5:  (n) an oppressive feeling of heavy force
6:  (n) a system of units used to express the weight of something
7:  (n) a unit used to measure weight
8:  (n) (statistics) a coefficient assigned to elements of a frequency distribution in order to represent their relative importance
9:  (v) weight down with a load
10:  (v) present with a bias


## Create Words Sample

In [6]:
wsample = []
undefined = []
ix = torch.randperm(len(words))
for ix in ix: wsample.append(words[ix])
for w in wsample:
    if len(wn.synsets(w)) < 1: 
        undefined.append(w)      # create set of words without a wordnet definition 
        wsample.remove(w)        # remove undefined words from the sample
        
print(len(undefined))
print(len(wsample))

4112
53998


## Structure Dictionary

In [7]:
# definitions for all words in wsample
definitions = [s.definition() for w in wsample for s in wn.synsets(w)]

# remove punctuations from definitions and append ' . ' sentence-end token
trim_definitions = [''.join(d).translate(str.maketrans('', '', string.punctuation)) + ' . ' for d in definitions]
trim_defstring = ''.join(trim_definitions)    # join punctuation-free definitions into a string

# create vocab list of all individual words that appear in the sample set of definitions
vocab = sorted(list((dict.fromkeys(trim_defstring.split()))))

In [8]:
print('Definitions in sample: \n ', len(definitions))
print('Distinct words in sampled definitions: \n ', len(trim_defstring.split()))
print('Unique words in sampled definitions: \n  Vocab set:', len(vocab))

Definitions in sample: 
  161831
Distinct words in sampled definitions: 
  1482928
Unique words in sampled definitions: 
  Vocab set: 29593


In [9]:
# create a list of rare words in the sample vocab set (words appearing only once in the sample of definitions)
counts = []
for word in vocab:
    counts += [trim_defstring.count(word)]

idk = []
for i in range(len(counts)):
    if counts[i] < 2: idk.append(i)
rare_words = [vocab[ix] for ix in idk]
len(rare_words)

4260

In [10]:
# remove rare words from the sample definitions
trimmer_defs = trim_defstring.split()
for w in rare_words:
    trimmer_defs.remove(w)
trimmer_defs = ' '.join(trimmer_defs).split('.')
trimmer_defs = [d for d in trimmer_defs]
trimmer_defs[:10]

['elegance by virtue of being fashionable ',
 ' elegant and stylish ',
 ' the location of something surrounded by other things ',
 ' not thin of a specific thickness or of relatively great extent from one surface to the opposite usually in the smallest of the three solid dimensions ',
 ' having component parts closely crowded together ',
 ' relatively dense in consistency ',
 ' spoken as if with a thick tongue ',
 ' having a short and solid form or stature ',
 ' hard to pass through because of dense growth ',
 ' of darkness very intense ']

In [11]:
# remove all rare words from the vocab list
trim_vocab = sorted(list((dict.fromkeys(' '.join(trimmer_defs).split()))))
print('Vocab set excl rare words:', len(trim_vocab))

Vocab set excl rare words: 25332


In [12]:
len(''.join(trimmer_defs).split())

1316837

In [13]:
end_char = '.'
start_char = '<s>'
pad_char = '<p>'

stoi = {s:i+1 for i,s in enumerate(trim_vocab)}    # word-to-integer mapping dictionary
stoi[end_char] = len(stoi) + 1                     # adding end character
stoi[start_char] = len(stoi) + 2                   # adding start character
stoi[pad_char] = 0                                 # adding pad character
itos = {i:s for s,i in stoi.items()}               # integer-to-word mapping dictionary1

enc = lambda s: [stoi[c] for c in s]            # encoder
dec = lambda l: ' '.join([itos[i] for i in l])  # decoder

## Data Population Summaries:
- Total Vocab Size: 29,593
- Vocab Size Excluding Rare Words (1 offs): 25,333
- Total Definitions: 161,831
- Total Words: 1,482,928

## Build Datasets
Naively encode the entire dataset into one tensor object - for now.

In [14]:
data = [enc(d.split()) for d in trimmer_defs]
max_length = max([len(d) for d in data])
xdat = [enc([start_char]) + d for d in data]
ydat = [d + enc([end_char]) for d in data]

# right pad all definitions to max length
for d in xdat: d += [0] * (max_length - len(d) + 1)
for d in ydat: d += [0] * (max_length - len(d) + 1)

xdat = torch.tensor(xdat)
ydat = torch.tensor(ydat)

xdat[0], ydat[0]

(tensor([25335,  9170,  5169, 24556, 16153,  4383, 10194,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0]),
 tensor([ 9170,  5169, 24556, 16153,  4383, 10194, 25333,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,

In [15]:
n = int(0.8*len(data))
Xt, Yt = xdat[:n], ydat[:n]     # 80% training data
Xv, Yv = xdat[n:], ydat[n:]     # 20% validation data

## Hyperparameters

In [33]:
# device = 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vs = len(stoi) + 1
bts = 64       # batch size
bls = max_length + 1      # block size
n_emb = 256     # embedding dimesions
n_head = 8     # number of heads per multihead stack (head_size = n_emb // n_head)
n_layer = 8    # number of decoder blocks
dropout = 0.2  # probability of zeroing-out neuron in dropouts
learning_rate = 1e-3

print(device)

cuda


## Batch Function

In [34]:
torch.manual_seed(42)

def minibatch(split):
    if split == 'train':
        xdat, ydat = Xt, Yt
    else:
        xdat, ydat = Xv, Yv
    ix = torch.randint(len(xdat) - bls, (bts,))
    
    x = xdat[ix]
    y = ydat[ix]
    x, y = x.to(device), y.to(device)
    return x, y    

xb, yb = minibatch('train')

## Self-Attention Head

In [35]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(bls, bls)))  # (T, T)
            # torch.tril() is not a parameter, so we have to use register_buffer to assign it to the module

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5   # (B, T, T); scaled by 1/sqrt(hs)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)        
        
        # aggregate values by weights
        v = self.value(x)
        out = wei @ v
        return out    

## MultiHead Self-Attention Stack

In [36]:
class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])    # create list of heads
        self.proj = nn.Linear(head_size * num_heads, n_emb)    # linear transformation of the output from the head stack
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)    # feed forward through heads and concatenate output
        out = self.dropout(self.proj(out))                     # pass output through linear layer and dropout
        return out

## FeedForward Layer

In [37]:
class FeedForward(nn.Module):
    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),     # mult 4 bc the paper does a 4x channel expansion in the feedforward
            nn.ReLU(),
            nn.Linear(4 * n_emb, n_emb),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

## Decoder Block

In [38]:
class Block(nn.Module):
    def __init__(self, n_emb, n_head):
        super().__init__()
        head_size = n_emb // n_head
        self.sa = MultiHead(n_head, head_size)    # self-attention stack
        self.ffwd = FeedForward(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)   # layer normalization for self-attention stack
        self.ln2 = nn.LayerNorm(n_emb)   # layer normalization for feed forward
        
    def forward(self, x):
        x = x + self.sa(self.ln1(x))     # residual self-attention stack connection
        x = x + self.ffwd(self.ln2(x))   # residual feed-forward connection
        return x

## Model

In [39]:
class DefModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vs, n_emb)
        self.position_embedding_table = nn.Embedding(bls, n_emb)
        self.blocks = nn.Sequential(*[Block(n_emb, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_emb)   # final layer norm
        self.lm_head = nn.Linear(n_emb, vs)    # output linear layer
        
        
    def forward(self, input, targets=None):
        B, T = input.shape
        
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(input)                               # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))   # (T,C)
        x = tok_emb + pos_emb                                                     # (B,T,C)
        x = self.blocks(x)                                                        # (B,T,C)
        x = self.ln_f(x)                                                          # (B,T,C)
        logits = self.lm_head(x)                                                  # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets, ignore_index=0)    # ignore index of pad_char
            
        return logits, loss
    
    
    def generate(self, idx, samples):    # idx is (B, T) array of indices in the current context
        
        model.eval()
        sample = []
        
        for _ in range(samples):
            ctx = idx
            
            while True:
                ctx_cond = ctx[:, -bls:]
                logits, loss = self(ctx_cond)

                # focus only on the last time step
                logits = logits[:, -1, :] # becomes (B, C)

                probs = F.softmax(logits, dim=-1) # (B, C)
                ctx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

                # append sampled index to the running sequence
                ctx = torch.cat((ctx, ctx_next), dim=1) # (B, T+1)
                
                if ctx_next.item() == stoi[end_char] or ctx.shape[1] > 50:
                    break
            sample.append(dec(ctx.tolist()[0]))
        
        model.train()
        
        return sample

## Evaluation Function

In [40]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = minibatch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

## Initializations

In [46]:
torch.manual_seed(1337)

model = DefModel()
model = model.to(device)

In [47]:
param_count = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        # print(name, param.data.shape.numel())
        param_count += param.data.shape.numel()
f"{param_count:,.0f} total parameters"

'19,327,224 total parameters'

## Model Training

In [48]:
import time

max_iters = 10000
eval_iters = 500
tloss, vloss = [], []

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    start = time.time()
    
    # every once in a while evaluate the loss on train and val sets
    if (iter == 1) or (iter > 0 and iter % eval_iters == 0):
        losses = estimate_loss()
        tloss.append(losses['train'])
        vloss.append(losses['val'])
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} \
              | ETA: {run_time / 60 * (max_iters-iter):.2f} min")

    # sample a batch of data
    xb, yb = minibatch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    run_time = time.time() - start

step 1: train loss 9.2254, val loss 9.2226               | ETA: 270.70 min


KeyboardInterrupt: 

So, the training loss and validation loss diverge. We clearly suffer from overfitting. The dataset still has many low frequency words - there are certainly words included in the validation set that are not included in the training set. As we become better at producing predictions for words (and definitions) in the training set, the impact of these unseen words within the validation set becomes more disruptive.

## Generate Samples from the Model

In [327]:
context = torch.tensor([[stoi[start_char]]], device=device)
model.generate(context, samples=20)

['<s> someone who flees from an uncongenial situation .',
 '<s> issue commands or orders for .',
 '<s> United States manufacturer of automobiles who pioneered mass production .',
 '<s> a joyful occasion for special festivities to mark some happy event .',
 '<s> stop amount .',
 '<s> make similar in sound .',
 '<s> hot or cold alcoholic mixed drink containing a beaten egg .',
 '<s> lift and laborious because of restraint or sensation .',
 '<s> transfer too much .',
 '<s> be owned by be in the possession of .',
 '<s> a movement downward .',
 '<s> running lengthwise .',
 '<s> a disposition to exhibit uncontrolled anger .',
 '<s> a humorous anecdote or remark intended to provoke laughter .',
 '<s> state or say further .',
 '<s> a poem consisting of 3 stanzas and an envoy .',
 '<s> attentively .',
 '<s> protection from harm .',
 '<s> a woman who works the right for a game .',
 '<s> plot a map of land .']

### Full Model Sample

In [375]:
context = torch.tensor([[stoi[start_char]]], device=device)
model.generate(context, samples=20)

['<s> inquire into .',
 '<s> increase in value or to a higher point .',
 '<s> be adjacent or come together .',
 '<s> likely to attract attention .',
 '<s> not having a roof .',
 '<s> put forward as of an idea .',
 '<s> incapable of being .',
 '<s> protect or defend a position in order to reach a game .',
 '<s> use dental floss to clean .',
 '<s> composing as in an idea .',
 '<s> a farm that gathers .',
 '<s> computer science a set of data on which samples changes feed or a number of issues of data is changed from which come to which one previously been paid .',
 '<s> equip with a fuse provide with a fuse .',
 '<s> move through by or as if by whistling .',
 '<s> deem wrong or inappropriate .',
 '<s> the act of communicating with a deity especially as a petition or in adoration or contrition or thanksgiving .',
 '<s> make a rupture in the ranks of the enemy or ones own by quitting or fleeing .',
 '<s> have as a as a will hold .',
 '<s> the scent of a greasy glandular secretion from the m