In [1]:
import numpy as np
import re

# sample text
sentences = "The sun rises in the east and sets in the west"

# remove special characters
sentences = re.sub(r'[^a-zA-Z]', ' ', sentences)

# lowercase and split into words
sentences = sentences.lower()
words = sentences.split()

print("Words:", words)
vocab = set(words)
vocab_size = len(vocab)
embed_dim = 10      # each word will be represented by 10 numbers
context_size = 2    # two words before and after as context

print("Vocab size:", vocab_size)
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for i, word in enumerate(vocab)}
data = []
for i in range(2, len(words) - 2):
    context = [words[i-2], words[i-1], words[i+1], words[i+2]]
    target = words[i]
    data.append((context, target))

print("Sample training pair:", data[0])
embeddings = np.random.random_sample((vocab_size, embed_dim))
def linear(x, theta):
    return np.dot(x, theta)        # simple matrix multiplication

def log_softmax(x):
    x = x - np.max(x)              # for numerical stability
    return np.log(np.exp(x) / np.exp(x).sum())
learning_rate = 0.01
theta = np.random.randn(embed_dim, vocab_size)

for epoch in range(20):
    total_loss = 0
    for context, target in data:
        # Average embedding of context words
        context_vecs = np.mean([embeddings[word_to_ix[w]] for w in context], axis=0)

        # Predict target word
        out = linear(context_vecs, theta)
        logs = log_softmax(out.reshape(1, -1))

        # Simplified “loss”: take negative of target word log prob
        target_index = word_to_ix[target]
        loss = -logs[0][target_index]
        total_loss += loss

    if epoch % 5 == 0:
        print(f"Epoch {epoch}, Loss = {total_loss:.4f}")
for word in list(vocab)[:5]:
    print(word, ":", embeddings[word_to_ix[word]])


Words: ['the', 'sun', 'rises', 'in', 'the', 'east', 'and', 'sets', 'in', 'the', 'west']
Vocab size: 8
Sample training pair: (['the', 'sun', 'in', 'the'], 'rises')
Epoch 0, Loss = 19.6587
Epoch 5, Loss = 19.6587
Epoch 10, Loss = 19.6587
Epoch 15, Loss = 19.6587
the : [0.00112608 0.74159383 0.99206351 0.08542322 0.28640219 0.15411002
 0.5929352  0.43046107 0.23301669 0.48032745]
in : [0.59054093 0.30682978 0.35826303 0.18887838 0.20505426 0.92342662
 0.93151653 0.45124081 0.70979193 0.92289672]
sets : [0.91378971 0.97253372 0.16767662 0.53792175 0.90431207 0.98299663
 0.36514551 0.45458731 0.07404879 0.54000758]
west : [0.11169505 0.2667019  0.61038821 0.28905056 0.29913764 0.08383021
 0.4817099  0.2213171  0.80479563 0.75457591]
sun : [0.49283255 0.0923667  0.47975027 0.63819327 0.3896817  0.34262669
 0.20226906 0.66552097 0.57952104 0.19869312]
