In [None]:
import numpy as np
from collections import Counter

text = "hope can set you free"
tokens = text.lower().split()

# Word frequencies; the distinct words (Counter keys) form the vocabulary.
counts = Counter(tokens)
vocab = sorted(counts)  # iterating a Counter yields its keys
V = len(vocab)

# Two-way lookup between words and integer ids; ids follow sorted vocab order.
word2id = dict(zip(vocab, range(V)))
id2word = dict(enumerate(vocab))

print("vocab:", vocab)
print("V:", V)

vocab: ['can', 'free', 'hope', 'set', 'you']
V: 5


In [5]:
def one_hot(idx, V):
    """Build a one-hot column vector.

    Parameters
    ----------
    idx : int
        Row to set to 1.0; must satisfy ``0 <= idx < V``.
    V : int
        Length of the vector (the vocabulary size in this notebook).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(V, 1)``, all zeros except ``1.0`` at row ``idx``.

    Raises
    ------
    IndexError
        If ``idx`` is outside ``[0, V)``.
    """
    # Reject negative indices explicitly: NumPy indexing would otherwise wrap
    # them (idx=-1 -> last row) and silently encode the wrong word.
    if not 0 <= idx < V:
        raise IndexError(f"one_hot index {idx} out of range for V={V}")
    x = np.zeros((V, 1), dtype=np.float64)
    x[idx, 0] = 1.0
    return x

# Encode one sample word to check the vector layout.
x = one_hot(idx=word2id["can"], V=V)
print(f"one-hot shape: {x.shape}")
print(x.T)  # transpose so the column vector prints as a single row

one-hot shape: (5, 1)
[[1. 0. 0. 0. 0.]]


In [6]:
# Embedding size, plus a seeded generator so the random weights are reproducible.
D = 3
rng = np.random.default_rng(42)

# Draw order matters for reproducibility: W is sampled first, then Wp.
W = rng.random(size=(D, V))    # used as h = W @ x, shape (D, V)
Wp = rng.random(size=(V, D))   # used as u = Wp @ h, shape (V, D)

print(f"W shape: {W.shape}")
print(f"Wp shape: {Wp.shape}")

W shape: (3, 5)
Wp shape: (5, 3)


In [7]:
# Inspect W, shape (D, V): column j holds the D weights for the word with id j.
W

array([[0.77395605, 0.43887844, 0.85859792, 0.69736803, 0.09417735],
       [0.97562235, 0.7611397 , 0.78606431, 0.12811363, 0.45038594],
       [0.37079802, 0.92676499, 0.64386512, 0.82276161, 0.4434142 ]])

In [8]:
# Inspect Wp, shape (V, D): row j produces the raw score for the word with id j.
Wp

array([[0.22723872, 0.55458479, 0.06381726],
       [0.82763117, 0.6316644 , 0.75808774],
       [0.35452597, 0.97069802, 0.89312112],
       [0.7783835 , 0.19463871, 0.466721  ],
       [0.04380377, 0.15428949, 0.68304895]])

In [10]:
# Hidden vector: multiplying W by a one-hot x picks out that word's column of W.
h = np.matmul(W, x)
print(f"h shape: {h.shape}")   # (D, 1)
print(f"h: {h}")

h shape: (3, 1)
h: [[0.77395605]
 [0.97562235]
 [0.37079802]]


In [12]:
# Raw (unnormalised) score for every vocabulary word.
u = np.matmul(Wp, h)
print(f"u shape: {u.shape}")  # (V, 1)
print(f"raw scores: {u}")

u shape: (5, 1)
raw scores: [[0.74060141]
 [1.53791349]
 [1.55258975]
 [0.96538772]
 [0.43770367]]


In [None]:
def softmax(z, axis=None):
    """Numerically stable softmax.

    Parameters
    ----------
    z : array_like
        Raw scores (logits).
    axis : int or None, optional
        Axis to normalise along. ``None`` (default) normalises over all
        elements, which is the original behaviour and is correct for a single
        ``(V, 1)`` score vector. Pass ``axis=0`` to normalise each column of a
        ``(V, N)`` batch independently.

    Returns
    -------
    numpy.ndarray
        Probabilities with the same shape as ``z``, summing to 1 along ``axis``.
    """
    # Softmax is shift-invariant, so subtracting the max does not change the
    # result. It keeps np.exp from overflowing on large scores.
    # keepdims=True lets the reduced max/sum broadcast back for any axis.
    shifted = z - np.max(z, axis=axis, keepdims=True)
    expz = np.exp(shifted)
    return expz / np.sum(expz, axis=axis, keepdims=True)

# Convert the raw scores into a probability distribution over the vocabulary.
y_hat = softmax(u)
print(f"y_hat shape: {y_hat.shape}")
print("probs sum:", y_hat.sum())  # sanity check: should be 1.0

y_hat shape: (5, 1)
probs sum: 1.0


In [14]:
# Predicted probabilities, shape (V, 1); row j is the probability of the word with id j.
y_hat

array([[0.13400014],
       [0.29742228],
       [0.30181951],
       [0.16777556],
       [0.0989825 ]])

In [15]:
# Cross-entropy against a one-hot target: only the target word's row contributes.
target = "hope"
y = one_hot(word2id[target], V)

# The 1e-10 offset keeps np.log finite if a probability underflows to 0.
loss = -(y * np.log(y_hat + 1e-10)).sum()
print("loss:", loss)


loss: 1.1979260806815235
