In [1]:
import numpy as np

### Load Data

In [None]:
# Read the raw corpus in binary mode, so `text` and every token derived from
# it are `bytes` objects (later lookups therefore use b"..." literals).
with open("../data/text8", "rb") as f:
    text = f.read()

# Split on whitespace and keep only the first 800,000 tokens.
tokens = text.split()[:800000]

In [9]:
from collections import Counter

# Count how many times each distinct token occurs in the truncated corpus.
counts = Counter(tokens)

print("Total tokens:", len(tokens))
print("Unique words:", len(counts))

Total tokens: 800000
Unique words: 45149


### Create Vocabulary

In [10]:
# Cap the vocabulary at the `vocab_size` most frequent words.
# `most_common` returns (word, count) pairs ordered from most to least frequent.
vocab_size = 30000
common_words = counts.most_common(vocab_size)
print("Unique words after capping:", len(common_words))

Unique words after capping: 30000


In [18]:
# Split the (word, count) pairs into parallel structures. Because
# `common_words` is ordered by descending count, a word's id is its
# frequency rank (id 0 is the most frequent word).
vocab = [w for w, _ in common_words]
freqs = np.array([c for _, c in common_words], dtype=np.float64)

# Forward and reverse lookup tables between words and integer ids.
word2id = {w: i for i, w in enumerate(vocab)}
id2word = {i: w for w, i in word2id.items()}

# Show only a small preview: printing the full mapping would dump
# every vocabulary entry into the notebook output.
print("Vocabulary size:", len(word2id))
print({w: word2id[w] for w in vocab[:10]})



In [19]:
# Drop every token that is not in the capped vocabulary (this rebinds `tokens`).
tokens = [w for w in tokens if w in word2id]
print("Tokens after filtering:", len(tokens))

Tokens after filtering: 784851


In [21]:
# Encode the filtered corpus as an int32 array of word ids.
ids = np.array([word2id[w] for w in tokens], dtype=np.int32)
# Largest id that actually occurs in the corpus.
print(ids.max())

29999


In [22]:
freqs

array([5.0533e+04, 2.9112e+04, 2.0409e+04, ..., 1.0000e+00, 1.0000e+00,
       1.0000e+00], shape=(30000,))

In [23]:
# Negative-sampling distribution: raw word counts raised to the 0.75 power
# (the smoothing exponent used in the original word2vec negative-sampling
# setup), then normalised to sum to 1. `train_step` passes this array as
# `p` to `rng.choice` when drawing negative word ids.
neg_probs = freqs ** 0.75
neg_probs /= neg_probs.sum()

print("Check sum:", neg_probs.sum())


Check sum: 1.0


In [24]:
neg_probs

array([1.62377561e-02, 1.07373902e-02, 8.22642223e-03, ...,
       4.81775490e-06, 4.81775490e-06, 4.81775490e-06], shape=(30000,))

In [25]:
ids

array([ 505, 3359,   11, ...,    2,    0,  175],
      shape=(784851,), dtype=int32)

### Create training pairs

In [None]:
def make_pairs(ids, window=5):
    """Build (center, context) skip-gram training pairs.

    For every position ``i`` in ``ids``, pair ``ids[i]`` with each element
    that lies at most ``window`` positions to its left or right. The window
    is clipped at both ends of the sequence.

    Parameters
    ----------
    ids : sequence
        Indexable sequence of word ids (list or 1-D array).
    window : int, default 5
        Maximum distance between a center word and a context word.

    Returns
    -------
    list of tuple
        ``(center, context)`` tuples in corpus order; elements keep whatever
        type indexing ``ids`` yields.

    Raises
    ------
    ValueError
        If ``window`` is negative.
    """
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    n = len(ids)
    pairs = []
    for i, center in enumerate(ids):
        left = max(0, i - window)
        right = min(n, i + window + 1)
        # Every neighbour inside the clipped window except the center itself.
        pairs.extend((center, ids[j]) for j in range(left, right) if j != i)
    return pairs

In [32]:
window_pairs = make_pairs(ids=ids)
window_pairs[0]

(np.int32(505), np.int32(3359))

In [38]:
len(ids)

784851

In [None]:
D = 100  # embedding dimension

# Number of embedding rows. Use the vectorised `ids.max()` rather than the
# builtin `max`, which would iterate the NumPy array element by element in
# Python; `int(...)` yields a plain Python int instead of `np.int32`.
V = int(ids.max()) + 1
# `rng.choice(V, p=neg_probs)` requires one probability per id.
assert V == len(neg_probs), "V must match the size of the negative-sampling distribution"

rng = np.random.default_rng(42)

# Input (center-word) embeddings start as small uniform values in
# [-0.5/D, 0.5/D); output (context-word) embeddings start at zero.
W_in  = (rng.random((V, D)) - 0.5) / D
W_out = np.zeros((V, D), dtype=np.float64)

In [40]:
W_in.shape

(30000, 100)

In [41]:
W_out.shape

(30000, 100)

### Building the network

In [None]:
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    Accepts a scalar or an array-like. The naive formula overflows in
    ``exp(-x)`` for large negative ``x``; here the exponential is only ever
    taken of a non-positive number, so it cannot overflow.

    Returns a NumPy scalar for scalar input and an array of the same shape
    for array input.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))  # always in (0, 1]
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) -- same value.
    out = np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # Indexing with an empty tuple turns a 0-d array into a scalar and
    # leaves arrays of higher rank unchanged.
    return out[()]


def train_step(center_id, pos_id, W_in, W_out, V, neg_probs, rng, K=5, lr=0.025):
    """Run one SGD update of skip-gram with negative sampling for one pair.

    The loss for a (center, context) pair with ``K`` sampled negatives is

        L = -log sigmoid(u_pos . v) - sum_k log sigmoid(-u_k . v)

    where ``v = W_in[center_id]``, ``u_pos = W_out[pos_id]`` and ``u_k`` are
    the rows of ``W_out`` for the sampled negative ids.

    Parameters
    ----------
    center_id, pos_id : int
        Row indices of the center word and the observed context word.
    W_in, W_out : ndarray
        Input and output embedding matrices; the touched rows are updated
        in place.
    V : int
        Number of ids to sample negatives from (``rng.choice(V, ...)``).
    neg_probs : ndarray
        Sampling probabilities for the negative ids (one per id).
    rng : numpy.random.Generator
        Random generator used to draw the negatives.
    K : int, default 5
        Number of negative samples.
    lr : float, default 0.025
        Learning rate.

    Returns
    -------
    float
        The scalar loss of this pair, evaluated before the update.
    """
    v = W_in[center_id]    # view: updated in place further down
    u_pos = W_out[pos_id]  # view

    neg_ids = rng.choice(V, size=K, p=neg_probs)
    u_neg = W_out[neg_ids]  # fancy indexing -> independent copy, shape (K, D)

    s_pos = u_pos @ v
    s_neg = u_neg @ v

    # -log sigmoid(s) == log(1 + exp(-s)) == logaddexp(0, -s); this form is
    # stable for scores of any magnitude and needs no epsilon inside the log.
    nll_pos = np.logaddexp(0.0, -s_pos)
    nll_neg = np.logaddexp(0.0, s_neg)
    # Sum over the K negatives so the loss is a scalar, not a length-K array.
    loss = float(nll_pos + nll_neg.sum())

    # sigmoid(s) == exp(-logaddexp(0, -s)); the exponent is never positive.
    g_pos = np.exp(-nll_pos) - 1.0              # dL/ds_pos = sigmoid(s_pos) - 1
    g_neg = np.exp(-np.logaddexp(0.0, -s_neg))  # dL/ds_neg = sigmoid(s_neg)

    # All gradients are computed from the pre-update parameters.
    grad_v = g_pos * u_pos + (g_neg[:, None] * u_neg).sum(axis=0)
    grad_u_pos = g_pos * v
    grad_u_neg = g_neg[:, None] * v[None, :]

    W_in[center_id] -= lr * grad_v
    W_out[pos_id]   -= lr * grad_u_pos

    # Unbuffered in-place subtraction: a negative id that was drawn more
    # than once receives every one of its updates.
    np.subtract.at(W_out, neg_ids, lr * grad_u_neg)

    return loss


In [45]:
def train_from_pairs(pairs, W_in, W_out, V, neg_probs, epochs=1, lr=0.025, K=5, log_every=200000, seed=42):
    """Train the embeddings on (center, context) pairs with per-pair SGD.

    Parameters
    ----------
    pairs : sequence of (int, int)
        Training pairs. They are copied into an int32 array, so shuffling
        does not affect the caller's object.
    W_in, W_out : ndarray
        Embedding matrices, updated in place by ``train_step``.
    V : int
        Number of ids to sample negatives from.
    neg_probs : ndarray
        Negative-sampling probabilities, forwarded to ``train_step``.
    epochs : int, default 1
        Number of passes over ``pairs``; the pairs are reshuffled each pass.
    lr : float, default 0.025
        Learning rate.
    K : int, default 5
        Number of negative samples per pair.
    log_every : int, default 200000
        Print the running average loss every ``log_every`` global steps;
        a falsy value disables these progress lines.
    seed : int, default 42
        Seed for the generator that drives shuffling and negative sampling.

    Raises
    ------
    ValueError
        If ``pairs`` is empty.
    """
    rng = np.random.default_rng(seed)

    pairs = np.array(pairs, dtype=np.int32)
    n_pairs = len(pairs)
    if n_pairs == 0:
        raise ValueError("pairs is empty; nothing to train on")

    step = 0  # global step counter across all epochs
    for ep in range(1, epochs + 1):
        rng.shuffle(pairs)  # shuffles rows in place

        total_loss = 0.0
        # `seen` counts pairs within this epoch. `total_loss` is reset every
        # epoch, so the running average must divide by `seen`, not by the
        # global `step`.
        for seen, (center_id, pos_id) in enumerate(pairs, start=1):
            loss = train_step(center_id, pos_id, W_in, W_out, V, neg_probs, rng, K=K, lr=lr)
            total_loss += loss

            step += 1
            if log_every and step % log_every == 0:
                print(f"epoch {ep} step {step} avg_loss={total_loss/seen:.4f}")

        print(f"epoch {ep} done | avg_loss={total_loss/n_pairs:.4f}")

### Training network with smaller sample of tokens

In [48]:
ids_train = ids[:200_000]
train_pairs = make_pairs(ids_train)

In [49]:
len(train_pairs)

1999970

In [50]:
train_from_pairs(train_pairs, W_in=W_in, W_out=W_out, V=V, neg_probs=neg_probs)

epoch 1 step 200000 avg_loss=3.9330
epoch 1 step 400000 avg_loss=3.6331
epoch 1 step 600000 avg_loss=3.4395
epoch 1 step 800000 avg_loss=3.3001
epoch 1 step 1000000 avg_loss=3.1948
epoch 1 step 1200000 avg_loss=3.1108
epoch 1 step 1400000 avg_loss=3.0429
epoch 1 step 1600000 avg_loss=2.9863
epoch 1 step 1800000 avg_loss=2.9377
epoch 1 done | avg_loss=2.8958


In [62]:
def nearest(word, W, word2id, id2word, topk=10):
    """Return the ``topk`` words most cosine-similar to ``word``.

    ``W`` is the embedding matrix whose rows are indexed by word id. The
    result is a list of ``(word, similarity)`` tuples in descending order of
    similarity, never containing the query word itself. An unknown query
    word yields an empty list.
    """
    if word in word2id:
        query_idx = word2id[word]
    else:
        return []

    query_vec = W[query_idx]

    # Cosine similarity of every row against the query; the small constant
    # keeps the division defined for all-zero rows.
    dots = W @ query_vec
    denom = np.linalg.norm(W, axis=1) * np.linalg.norm(query_vec) + 1e-10
    sims = dots / denom

    # Take one extra candidate so that dropping the query word itself still
    # leaves up to `topk` neighbours.
    ranked = np.argsort(-sims)[:topk+1]

    neighbours = []
    for idx in ranked:
        if idx != query_idx:
            neighbours.append((id2word[idx], float(sims[idx])))
    return neighbours[:topk]

In [57]:
len(word2id)

30000

In [63]:
print(b"anarchism" in word2id)

True


In [64]:
print(nearest(b"anarchism", W_in, word2id, id2word))

[(b'science', 0.9918110333955067), (b'anthropology', 0.9899102182921868), (b'archaeology', 0.9877576128149408), (b'cultural', 0.9862778180155134), (b'system', 0.9856702299597521), (b'culture', 0.9844848477957222), (b'sometimes', 0.9839632856486511), (b'ancient', 0.9822707788394794), (b'natural', 0.9821781932025976), (b'anarchist', 0.9819834109297017)]
