## Setup
- define tiny corpus
- lowercase + whitespace tokenize
- build vocabulary and mappings:
    - stoi (string -> int ID)
    - itos (int ID -> string)
    - these IDs are exactly what nn.Embedding will look up

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
from itertools import chain
import random

# Seed both Python's `random` module and PyTorch's global RNG with a fixed value.
random.seed(0)
torch.manual_seed(0)

# Toy corpus: five short sentences, one string per sentence.
corpus = [
    "we all live in a yellow submarine",
    "we all live in a blue submarine",
    "we all love bright yellow flowers",
    "they all live in a green house",
    "they all love blue flowers",
]

# simple whitespace tokenizer
def tokenize(s):
    """Lowercase a sentence and split it on runs of whitespace.

    Args:
        s: The sentence to tokenize.

    Returns:
        A list of lowercase tokens; empty if `s` is empty or only whitespace.
    """
    lowered = s.lower()
    return lowered.split()

# Tokenize each sentence separately: `tokens` is a list of token lists.
tokens = list(map(tokenize, corpus))
print(f"tokens:\n{tokens}\n")

# Flatten the per-sentence lists into one stream of tokens.
all_tokens = [tok for sentence in tokens for tok in sentence]
print(f"all tokens:\n{all_tokens}\n")

# Vocabulary = distinct tokens, sorted so the ID assignment is reproducible.
vocab = sorted(set(all_tokens))
print(f"sorted unique tokens (vocab):\n{vocab}\n")

# string -> id: a token's ID is its position in the sorted vocabulary.
stoi = dict(zip(vocab, range(len(vocab))))
print(f"string to id mapping:\n{stoi}\n")

# id -> string: the inverse mapping.
itos = dict(enumerate(vocab))
print(f"id to string mapping:\n{itos}\n")

# Vocabulary size; used below as the number of rows of every embedding table.
V = len(stoi)
print(f"vocab length (V):\n{V}")

tokens:
[['we', 'all', 'live', 'in', 'a', 'yellow', 'submarine'], ['we', 'all', 'live', 'in', 'a', 'blue', 'submarine'], ['we', 'all', 'love', 'bright', 'yellow', 'flowers'], ['they', 'all', 'live', 'in', 'a', 'green', 'house'], ['they', 'all', 'love', 'blue', 'flowers']]

all tokens:
['we', 'all', 'live', 'in', 'a', 'yellow', 'submarine', 'we', 'all', 'live', 'in', 'a', 'blue', 'submarine', 'we', 'all', 'love', 'bright', 'yellow', 'flowers', 'they', 'all', 'live', 'in', 'a', 'green', 'house', 'they', 'all', 'love', 'blue', 'flowers']

sorted unique tokens (vocab):
['a', 'all', 'blue', 'bright', 'flowers', 'green', 'house', 'in', 'live', 'love', 'submarine', 'they', 'we', 'yellow']

string to id mapping:
{'a': 0, 'all': 1, 'blue': 2, 'bright': 3, 'flowers': 4, 'green': 5, 'house': 6, 'in': 7, 'live': 8, 'love': 9, 'submarine': 10, 'they': 11, 'we': 12, 'yellow': 13}

id to string mapping:
{0: 'a', 1: 'all', 2: 'blue', 3: 'bright', 4: 'flowers', 5: 'green', 6: 'house', 7: 'in', 8: 'live

## Embedding as one_hot x table and as row lookup
one-hot encoding - a way to represent categorical values (like words) as numeric vectors
- e.g. V = 5
    - 0: we
    - 1: all
    - 2: live
    - 3: yellow
    - 4: flowers
- yellow (ID 3) would be represented as one_hot = [0, 0, 0, 1, 0]
    - a vector of length V
    - all zeros except a 1 at the index of the word's ID

one-hot x table
- e.g. embedding table (V, D) where V = 5 and D = 3
```
W =
[ 0.2,  0.5, -0.1 ]   # row 0 -> "we"
[-0.4,  0.1,  0.8 ]   # row 1 -> "all"
[ 0.9, -0.7,  0.3 ]   # row 2 -> "live"
[ 0.0,  1.2, -0.6 ]   # row 3 -> "yellow"
[-0.2,  0.3,  0.5 ]   # row 4 -> "flowers"
```
- one_hot x W = [0, 0, 0, 1, 0] x W = [0.0, 1.2, -0.6]
    - which is row 3 of W, i.e. the embedding vector for "yellow"

- So, the embedding vector for a token ID is just:
    - the row of a (V, D) table
    - equivalently, a matrix multiply of a one-hot vector with that table

In [3]:
import torch

# D is the embedding width; W is a randomly initialised (V, D) table, one row per vocabulary ID.
D = 6
W = torch.randn(V, D)

def one_hot(index, size):
    """Build a one-hot vector.

    Args:
        index: Position that is set to 1.0.
        size: Passed to `torch.zeros` to create the all-zero vector.

    Returns:
        The tensor from `torch.zeros(size)` with the entry at `index` set to 1.0.
    """
    encoded = torch.zeros(size)
    encoded[index] = 1.0
    return encoded

# Demo token: "yellow" when it is in the vocabulary, otherwise the first key of stoi.
if "yellow" in stoi:
    token = "yellow"
else:
    token = list(stoi)[0]
tid = stoi[token]

# create one-hot vector
oh = one_hot(tid, V)            # (V,)

# method 1: multiply the one-hot vector by the table
via_mm = torch.matmul(oh, W)    # (D,)
# method 2: read the row directly
via_row = W[tid, :]             # (D,)

print("Token:", token, "| ID:", tid)
print("Table shape (V, D):", tuple(W.shape))
print("via_mm shape:", via_mm.shape, "| via_row shape:", via_row.shape)

# both methods must give the same vector
print("Equal (allclose)?", torch.allclose(via_mm, via_row))

# show the leading components of each result
print("\nEmbedding vector (first 3 dims):")
print("via_mm :", via_mm[:3])
print("via_row:", via_row[:3])

Token: yellow | ID: 13
Table shape (V, D): (14, 6)
via_mm shape: torch.Size([6]) | via_row shape: torch.Size([6])
Equal (allclose)? True

Embedding vector (first 3 dims):
via_mm : tensor([ 0.0335,  0.7101, -1.5353])
via_row: tensor([ 0.0335,  0.7101, -1.5353])


## Implement MyEmbedding and compare to nn.Embedding
- a custom embedding layer is just:
    - a trainable table (V, D)
    - a row gather

In [4]:
import torch
import torch.nn as nn

class MyEmbedding(nn.Module):
    """Minimal embedding layer: a trainable table indexed by token ID.

    Attributes:
        weight: Trainable parameter of shape (num_embeddings, embedding_dim),
            initialised from a standard normal distribution.
    """

    def __init__(self, num_embeddings, embedding_dim):
        """Create the table.

        Args:
            num_embeddings: Number of rows in the table.
            embedding_dim: Length of each row.
        """
        super().__init__()
        table = torch.randn(num_embeddings, embedding_dim)
        self.weight = nn.Parameter(table)

    def forward(self, ids):
        """Return the rows of `weight` selected by `ids` (plain indexing)."""
        rows = self.weight[ids]
        return rows

# reuse V, choose embedding dim
D = 6

# Build the custom layer first, then PyTorch's, each with V rows of width D.
my_emb = MyEmbedding(num_embeddings=V, embedding_dim=D)     # my embedding
pt_emb = nn.Embedding(num_embeddings=V, embedding_dim=D)    # PyTorch embedding

# for a fair comparison, copy PyTorch's weights into the custom layer
# (both modules expose a single parameter named "weight")
my_emb.load_state_dict(pt_emb.state_dict())

# IDs of three example words; a missing word falls back to ID 0
example_ids = torch.tensor([stoi.get(word, 0) for word in ("we", "yellow", "flowers")])

# calling a module goes through nn.Module.__call__, which invokes forward()
out_my = my_emb(example_ids)
out_pt = pt_emb(example_ids)

print("Shapes equal? ", out_my.shape == out_pt.shape)
print("Values allclose? ", torch.allclose(out_my, out_pt))
print("Output (first row, first 3 dims):")
print("MyEmbedding:", out_my[0, :3])
print("nn.Embedding:", out_pt[0, :3])

Shapes equal?  True
Values allclose?  True
Output (first row, first 3 dims):
MyEmbedding: tensor([ 0.5588,  0.7918, -0.1847], grad_fn=<SliceBackward0>)
nn.Embedding: tensor([ 0.5588,  0.7918, -0.1847], grad_fn=<SliceBackward0>)


### Only used rows get gradients

In [5]:
# Scalar loss = sum of all looked-up embedding values, then backprop.
# Clear any existing gradients first so only this backward pass is visible.
pt_emb.zero_grad(set_to_none=True)
my_emb.zero_grad(set_to_none=True)

# --- my embedding ---
out_my = my_emb(example_ids)    # (3, 6)
loss_my = torch.sum(out_my)     # sum over every element
loss_my.backward()

print("Non-zero grad rows in MyEmbedding:")
# a row counts as "touched" when the absolute values of its gradient sum to more than 0
nz_my = torch.nonzero(my_emb.weight.grad.abs().sum(dim=1).gt(0), as_tuple=True)[0]
print(nz_my.tolist())

# --- PyTorch embedding ---
out_pt = pt_emb(example_ids)    # (3, 6)
loss_pt = torch.sum(out_pt)     # sum over every element
loss_pt.backward()

print("Non-zero grad rows in nn.Embedding:")
nz_pt = torch.nonzero(pt_emb.weight.grad.abs().sum(dim=1).gt(0), as_tuple=True)[0]
print(nz_pt.tolist())

# both lists should equal the distinct IDs that were looked up
print("Unique example IDs:", example_ids.unique().tolist())


Non-zero grad rows in MyEmbedding:
[4, 12, 13]
Non-zero grad rows in nn.Embedding:
[4, 12, 13]
Unique example IDs: [4, 12, 13]


## Train a minimal context to next-word model
next-word predictor
- input: average of the embeddings of a 2-word context
- output: distribution over vocab for the next word

In [11]:
# build context to target pairs (window size = 2, predict next token)
def make_pairs(token_lines, window=2):
    """Turn tokenized sentences into (context IDs, target ID) training pairs.

    Each sentence is mapped to IDs through the module-level `stoi` dict, and a
    window of `window` consecutive IDs is paired with the ID that follows it.
    Windows never cross sentence boundaries. The IDs of every sentence are
    printed as they are processed.

    Args:
        token_lines: Iterable of token lists, one list per sentence.
        window: Number of context tokens per pair.

    Returns:
        A list of `(ctx, tgt)` tuples, where `ctx` is a list of `window` IDs
        and `tgt` is the ID of the next token.
    """
    pairs = []
    for sentence in token_lines:
        # map every word of the sentence to its vocabulary ID
        ids = [stoi[word] for word in sentence]
        print(f'ids from curr line: {ids}')
        pairs.extend(
            (ids[start:start + window], ids[start + window])
            for start in range(len(ids) - window)
        )
    return pairs

# Slide a 2-token window over every tokenized sentence, then show the input and the resulting pairs.
pairs = make_pairs(token_lines=tokens, window=2)
print(f'tokens: {tokens}')
print(f'pairs: {pairs}')

ids from curr line: [12, 1, 8, 7, 0, 13, 10]
ids from curr line: [12, 1, 8, 7, 0, 2, 10]
ids from curr line: [12, 1, 9, 3, 13, 4]
ids from curr line: [11, 1, 8, 7, 0, 5, 6]
ids from curr line: [11, 1, 9, 2, 4]
tokens: [['we', 'all', 'live', 'in', 'a', 'yellow', 'submarine'], ['we', 'all', 'live', 'in', 'a', 'blue', 'submarine'], ['we', 'all', 'love', 'bright', 'yellow', 'flowers'], ['they', 'all', 'live', 'in', 'a', 'green', 'house'], ['they', 'all', 'love', 'blue', 'flowers']]
pairs: [([12, 1], 8), ([1, 8], 7), ([8, 7], 0), ([7, 0], 13), ([0, 13], 10), ([12, 1], 8), ([1, 8], 7), ([8, 7], 0), ([7, 0], 2), ([0, 2], 10), ([12, 1], 9), ([1, 9], 3), ([9, 3], 13), ([3, 13], 4), ([11, 1], 8), ([1, 8], 7), ([8, 7], 0), ([7, 0], 5), ([0, 5], 6), ([11, 1], 9), ([1, 9], 2), ([9, 2], 4)]


In [16]:
# simple model = Embedding -> mean over context -> Linear (outputs logits, no softmax)
class CtxAvgNextWord(nn.Module):
    """Predict the next word from the average embedding of its context words.

    The model returns raw logits; no softmax is applied here.

    Attributes:
        emb: `nn.Embedding(vocab_size, emb_dim)` lookup table.
        proj: `nn.Linear(emb_dim, vocab_size)` mapping the averaged context
            vector to one logit per vocabulary entry.
    """

    def __init__(self, vocab_size, emb_dim):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of distinct token IDs.
            emb_dim: Width of each embedding vector.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.proj = nn.Linear(emb_dim, vocab_size)

    def forward(self, ctx_ids):
        """Compute next-word logits for one or more contexts.

        Args:
            ctx_ids: Integer tensor of token IDs whose last axis is the
                context, e.g. (B, C) for a batch or (C,) for a single context.

        Returns:
            Logits with the context axis replaced by the vocabulary axis,
            e.g. (B, V) for (B, C) input and (V,) for (C,) input.
        """
        E = self.emb(ctx_ids)   # (..., C, D)
        # Average over the context axis. After the lookup it is always the
        # second-to-last axis, so dim=-2 matches dim=1 for batched (B, C)
        # input and also works for a single un-batched (C,) context.
        h = E.mean(dim=-2)      # (..., D)
        logits = self.proj(h)   # (..., V)
        return logits

# initialize model and optimizer
# D is the embedding width; the model maps V token IDs to V output logits.
D = 6
model = CtxAvgNextWord(V, D)
# Adam receives every parameter of `model`, with learning rate 0.01.
opt = torch.optim.Adam(model.parameters(), lr=0.01)

## Training Loop

In [22]:
import random

# make mini-batches from the (context, target) pairs
def batch_iter(pairs, batch_size):
    """Yield shuffled mini-batches of (context, target) tensors.

    The pairs are shuffled on a copy, so the caller's list keeps its order.

    Args:
        pairs: Sequence of `(ctx, tgt)` items, where `ctx` is a list of token
            IDs (all the same length) and `tgt` is a single token ID.
        batch_size: Maximum number of pairs per batch; the final batch may be
            smaller.

    Yields:
        `(ctx, tgt)` where `ctx` is a long tensor of shape (B, context_len)
        and `tgt` is a long tensor of shape (B,).

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Shuffle a copy: shuffling `pairs` itself would reorder the caller's list.
    shuffled = list(pairs)
    random.shuffle(shuffled)

    for start in range(0, len(shuffled), batch_size):
        chunk = shuffled[start:start + batch_size]
        ctx = torch.tensor([c for c, _ in chunk], dtype=torch.long)  # (B, context_len)
        tgt = torch.tensor([t for _, t in chunk], dtype=torch.long)  # (B,)
        yield ctx, tgt

# train the model
epochs = 200
batch_size = 8
for epoch in range(1, epochs + 1):
    # running sums used for the per-example average loss of this epoch
    total_loss, n = 0.0, 0
    for ctx, tgt in batch_iter(pairs, batch_size):
        # clear gradients left from the previous step before computing new ones
        opt.zero_grad()

        logits = model(ctx)                 # (B, V)
        loss = F.cross_entropy(logits, tgt) # batch-mean loss against the true target IDs
        loss.backward()
        opt.step()

        # weight the batch-mean loss by the batch size so a smaller last batch is averaged correctly
        total_loss += loss.item() * len(ctx)
        n += len(ctx)

    # report every 50th epoch
    if epoch % 50 == 0:
        print(f"Epoch {epoch:3d} | Avg Loss: {total_loss/n:.3f}")

Epoch  50 | Avg Loss: 0.377
Epoch 100 | Avg Loss: 0.373
Epoch 150 | Avg Loss: 0.371
Epoch 200 | Avg Loss: 0.373
