# INFO 159/259

# <center> Homework 1: Word Embeddings </center>
<center> Due: February 3, 2026 @ 11:59pm </center>

# HW1: Word Embeddings

In this homework, you will implement _word2vec_ with skip-grams and negative sampling, training on a small slice of Wikipedia data.

*Learning objectives*:
- Understand the implementation details of _word2vec_
- Gain familiarity with `numpy` for matrix math
- Gain familiarity with training a classifier using stochastic gradient descent.

You may want to consult SLP chapter 5 (_Embeddings_) as a reference for the implementation. This homework is designed to run on the CPU only, so if you are using Google Colab, make sure a CPU runtime is selected (under `Runtime > Change runtime type` in the top bar) so that you save your GPU allocation for later assignments in the semester.

In [3]:
# download the dataset we will be using
!wget https://github.com/dbamman/nlp-course/raw/refs/heads/main/HW/data/en_wiki_sample.txt -O en_wiki_sample.txt

--2026-02-05 20:17:23--  https://github.com/dbamman/nlp-course/raw/refs/heads/main/HW/data/en_wiki_sample.txt
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/dbamman/nlp-course/refs/heads/main/HW/data/en_wiki_sample.txt [following]
--2026-02-05 20:17:23--  https://raw.githubusercontent.com/dbamman/nlp-course/refs/heads/main/HW/data/en_wiki_sample.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42894174 (41M) [text/plain]
Saving to: ‘en_wiki_sample.txt’


2026-02-05 20:17:24 (282 MB/s) - ‘en_wiki_sample.txt’ saved [42894174/42894174]



In [4]:
import itertools
from collections import Counter

import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from tqdm import tqdm

nltk.download("punkt_tab")

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

## Data loading

We will begin by loading and tokenizing the data. The file contains a list of paragraphs from Wikipedia, separated by newlines. Because each document (a paragraph) is sampled independently, we want to maintain the document boundaries when we sample contexts later.
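
To make the boundary rule concrete, here is a minimal sketch using a hypothetical helper, `skipgram_windows` (not part of the starter code), that follows the same windowing logic as `FileDataLoader.sample_contexts()` below: each document is processed on its own, so a window never spans two paragraphs, and only positions with a full window on both sides become targets. The toy documents are invented for illustration.

```python
def skipgram_windows(tokenized_documents, window_size):
    """Yield (target, context_words) pairs without crossing document boundaries.

    Hypothetical helper mirroring the loop in FileDataLoader.sample_contexts():
    only positions with `window_size` words on both sides are used as targets.
    """
    for doc in tokenized_documents:
        # documents shorter than one full window contribute no targets
        if len(doc) < 2 * window_size + 1:
            continue
        for i in range(window_size, len(doc) - window_size):
            left = doc[i - window_size:i]
            right = doc[i + 1:i + 1 + window_size]
            yield doc[i], left + right


docs = [["the", "cat", "sat", "on", "the", "mat"], ["dogs", "bark"]]
pairs = list(skipgram_windows(docs, window_size=2))
print(pairs)
# [('sat', ['the', 'cat', 'on', 'the']), ('on', ['cat', 'sat', 'the', 'mat'])]
```

Here the second document is shorter than a full window ($2 \times 2 + 1 = 5$ tokens) and is skipped entirely, and the edge tokens of the first document never act as targets, although they still appear as context words.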

Inside `FileDataLoader`:
- `idx2vocab` is a list of unique word types
- `vocab2idx` is a dict mapping from a word type to its index in `idx2vocab`
- `word_freqs` is a dict mapping from a word type to its frequency in the corpus
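
As a toy illustration of these three structures, the sketch below rebuilds them on a made-up corpus, applying the same `[UNK]` replacement that `FileDataLoader.__init__` uses for words seen fewer than `min_threshold` times (the corpus and the threshold of 2 are invented for illustration).

```python
import itertools
from collections import Counter

docs = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "end"]]
min_threshold = 2

word_freqs = Counter(itertools.chain.from_iterable(docs))
# fold rare words into a single [UNK] type; iterate over a copy since we mutate the dict
for word, freq in list(word_freqs.items()):
    if freq < min_threshold:
        word_freqs["[UNK]"] += freq
        del word_freqs[word]

idx2vocab = list(word_freqs.keys())
vocab2idx = {word: index for index, word in enumerate(idx2vocab)}

print(idx2vocab)          # ['the', 'sat', '[UNK]']
print(dict(word_freqs))   # {'the': 3, 'sat': 2, '[UNK]': 3}

# out-of-vocabulary lookups fall back to the [UNK] index
print(vocab2idx.get("cat", vocab2idx["[UNK]"]))   # 2
```

Any word folded into `[UNK]` (here `cat`, `dog`, and `end`) maps to the `[UNK]` index on lookup, which is what `sample_contexts()` does for out-of-vocabulary tokens.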

You should implement:
1. The `negative_sample_weights()` function

   This function should calculate the weighted sample probabilities for each of the words in our vocabulary.
   Recall SLP3 eq. 5.19:
    $$
     P_{\alpha}(w) = \frac{\text{count}(w)^{\alpha}}{\sum_{w'}\text{count}(w')^{\alpha}}
    $$
   We calculate and store the sample weights to save time when generating contexts later.
2. The `negative_sample()` function

   This function should sample `num_samples` negative context words given a target word. Recall from SLP3 5.5.2:
   > A noise word is a random word from the lexicon, **constrained not to be the target word $w$**. (_emph added_)

   So, when sampling, you will want to copy the original `.sample_weights` numpy array, set the probability of the target word to 0, and renormalize the weights before sampling.

   You may want to consult the numpy documentation for [`numpy.random.Generator.choice()`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.Generator.choice.html#numpy.random.Generator.choice). We have instantiated a random generator for your convenience in `self.rng`.
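
Both steps can be checked in isolation on a toy vocabulary. The sketch below uses invented counts and a hypothetical standalone helper, `sample_negatives` (not the class method you are asked to write): it computes $P_{\alpha}$ for $\alpha = 0.75$, zeroes out the target, renormalizes, and samples with `Generator.choice`.

```python
import numpy as np

counts = np.array([81, 16, 1])        # invented counts for a 3-word vocabulary
alpha = 0.75

# P_alpha(w) = count(w)^alpha / sum_w' count(w')^alpha
weights = counts ** alpha             # approximately [27, 8, 1]
sample_weights = weights / weights.sum()
print(sample_weights)                 # approximately [0.75, 0.222, 0.028]


def sample_negatives(sample_weights, target_idx, num_samples, rng):
    """Hypothetical helper: draw noise word indices, never the target itself."""
    probs = sample_weights.copy()     # copy so the stored weights stay untouched
    probs[target_idx] = 0.0
    probs /= probs.sum()              # renormalize over the remaining words
    return rng.choice(len(probs), size=num_samples, replace=True, p=probs)


rng = np.random.default_rng(0)
negatives = sample_negatives(sample_weights, target_idx=0, num_samples=1000, rng=rng)
# empirical frequencies: word 0 never appears; words 1 and 2 roughly 8/9 and 1/9
print(np.bincount(negatives, minlength=3) / 1000)
```

Compared with the raw frequencies (81/98 ≈ 0.83 and 1/98 ≈ 0.01), the exponent shifts probability mass from the frequent word toward the rare one; the quick check on "the" below shows the same effect on the real corpus.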

_Learning objectives_:
> - Understand the implementation details of _word2vec_


In [5]:
corpus_path = "./en_wiki_sample.txt"

In [6]:
class FileDataLoader():
    def __init__(self, filepath, negative_sample_alpha=0.75, min_threshold=5):
        self.negative_sample_alpha = negative_sample_alpha
        self.min_threshold = min_threshold

        self.tokenized_documents = self.load_data(filepath)
        self.word_freqs = self.get_word_freqs(self.tokenized_documents)

        # replace words that appear fewer than min_threshold times with an [UNK] token
        for word, freq in list(self.word_freqs.items()):
            if freq < min_threshold:
                self.word_freqs["[UNK]"] += freq
                del self.word_freqs[word]

        self.idx2vocab = list(self.word_freqs.keys())
        self.vocab2idx = {word: index for index, word in enumerate(self.idx2vocab)}

        # set up a random number generator we can use for sampling
        self.rng = np.random.default_rng(159259)
        self.sample_weights = self.negative_sample_weights(alpha=negative_sample_alpha)

    def tokenize_and_lowercase(self, doc):
        """Tokenize a doc and lowercase all the words."""
        return [word.lower() for word in word_tokenize(doc)]

    def get_word_freqs(self, tokenized_documents):
        """Return a dictionary mapping each word to its frequency."""
        return Counter(itertools.chain.from_iterable(tokenized_documents))

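    def load_data(self, filepath):
        """Read one document per line from `filepath` and tokenize each one."""
        with open(filepath, "r", encoding="utf-8") as f:
            return [self.tokenize_and_lowercase(doc) for doc in tqdm(f.readlines())]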

    def negative_sample_weights(self, alpha):
        """Calculate the weighted probabilities of each word.

        Return a (v,)-shaped numpy array, where v is the size of the vocabulary.
        """
        # TODO: implement this function
        counts = np.array([self.word_freqs[word] for word in self.idx2vocab])
        powered = counts ** alpha
        total = powered.sum()
        probs = powered / total
        return probs

    def negative_sample(self, target_word_idx, num_samples):
        """Sample num_samples noise words from the lexicon that is not the target word.

        The sample probabilities should be proportional to their weighted unigram probability if the target word probability is set to 0.

        Return a (num_samples,)-shaped numpy array of sampled indices.
        """
        # TODO: implement this function
        probs = self.sample_weights.copy()
        probs[target_word_idx] = 0.0

        # Renormalize
        total = probs.sum()
        if total > 0:
            probs = probs / total
        else:
            # Fallback for the degenerate case where the target holds all the weight: sample uniformly over all other words
            probs = np.ones_like(probs)
            probs[target_word_idx] = 0.0
            probs = probs / probs.sum()

        samples = self.rng.choice(
            a=len(probs),
            size=num_samples,
            replace=True,
            p=probs
        )
        return samples

    def sample_contexts(self, window_size, sample_k):
        for doc in self.tokenized_documents:
            if len(doc) < (2 * window_size) + 1:
                # the doc is too short for our desired window size; we skip it
                continue
            for word_idx in range(window_size, len(doc) - window_size):
                # map out-of-vocabulary words to the [UNK] index
                unk_idx = self.vocab2idx["[UNK]"]
                target_word_idx = self.vocab2idx.get(doc[word_idx], unk_idx)
                # positive words: the `window_size` words on each side of the target
                window = doc[word_idx - window_size:word_idx] + doc[word_idx + 1:word_idx + 1 + window_size]
                positive_word_idxs = np.array([self.vocab2idx.get(word, unk_idx) for word in window])
                # sample len(positive_word_idxs) * sample_k number of negative words
                negative_word_idxs = self.negative_sample(target_word_idx, sample_k * len(positive_word_idxs))
                yield (target_word_idx, positive_word_idxs, negative_word_idxs)


In [7]:
# this should take roughly 30 seconds
dataloader = FileDataLoader(corpus_path)

100%|██████████| 100000/100000 [00:48<00:00, 2058.17it/s]


**Quick check**: The unweighted probability for "the" should be 0.063; the weighted probability should be 0.016.

In [8]:
print(f"Unweighted probability for `the`: \t\t{dataloader.word_freqs['the'] / dataloader.word_freqs.total():.3f}")
print(f"Weighted (alpha=0.75) probability for `the`: \t{dataloader.sample_weights[dataloader.vocab2idx['the']]:.3f}")

Unweighted probability for `the`: 		0.063
Weighted (alpha=0.75) probability for `the`: 	0.016


## Setting up the model

The word2vec model consists of two matrices: the target (or input) embedding and the context (or output) embedding. We set those up here.

You should implement:
- The `nearest_neighbors()` function

  Given a $d$-dimensional vector $\vec{v}$ and a $(v \times d)$ matrix $M$ of vectors to query against, we want to calculate the cosine similarity of $\vec{v}$ with each row of $M$ and return the indices (and the corresponding similarities) of the most similar rows in $M$.

  As a reminder, the cosine similarity of two vectors $\vec{a}$ and $\vec{b}$ is
  $$
    \text{cosine\_sim}(\vec{a}, \vec{b}) = \frac{\vec{a} \cdot \vec{b}}{\|{\vec{a}}\|\|\vec{b}\|}
  $$

  This is derived from one of the formulations for the dot product:
  $$
    \vec{a} \cdot \vec{b} = \|\vec{a}\| \|\vec{b}\| \cos({\theta})
  $$

  $\|\vec{a}\|$ denotes the $l_2$-norm of a vector, or its magnitude.

  You might want to consult the numpy documentation for [`numpy.matmul`](https://numpy.org/doc/2.1/reference/generated/numpy.matmul.html), [`numpy.argsort`](https://numpy.org/doc/2.1/reference/generated/numpy.argsort.html#numpy-argsort), and [`numpy.linalg.norm`](https://numpy.org/doc/2.1/reference/generated/numpy.linalg.norm.html).
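
A small worked example of the vectorized computation, assuming toy 2-dimensional vectors and a hypothetical helper name, `top_cosine`: dividing each row of $M$ by its norm (with `keepdims=True` so the division broadcasts row-wise) turns a single matrix-vector product into all $v$ cosine similarities at once, and reversing `argsort` puts the most similar rows first.

```python
import numpy as np


def top_cosine(query, M, n):
    """Hypothetical helper: indices and cosine similarities of the n rows of M closest to query."""
    q = query / np.linalg.norm(query)
    rows = M / np.linalg.norm(M, axis=1, keepdims=True)   # (v, d), each row has unit length
    sims = rows @ q                                        # (v,) cosine similarities
    order = np.argsort(sims)[::-1][:n]                     # indices by descending similarity
    return order, sims[order]


M = np.array([[1.0, 0.0],     # cos = 1 / sqrt(2), about 0.707
              [1.0, 2.0],     # cos = 3 / sqrt(10), about 0.949
              [-1.0, -1.0],   # opposite direction, cos = -1
              [3.0, 3.0]])    # same direction as the query, cos = 1
query = np.array([1.0, 1.0])

idx, sims = top_cosine(query, M, n=2)
print(idx, sims)   # indices [3 1], similarities about [1.0, 0.949]
```

Row 3 ranks first even though it is three times longer than the query, because cosine similarity ignores magnitude. The `nearest_neighbors()` implementation below additionally guards against zero-norm rows, which this sketch omits.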


_Learning objectives_:
> - Gain familiarity with `numpy` for matrix math


In [9]:
class Word2Vec():
    def __init__(self, dataloader, hidden_dim=100):
        self.dataloader = dataloader
        self.vocab_size = len(self.dataloader.idx2vocab)
        self.hidden_dim = hidden_dim

        np.random.seed(159259)
        # We initialize the model weights to be uniformly randomly distributed and centered around 0.
        self.target_embs = (np.random.random((self.vocab_size, hidden_dim)) - 0.5) / hidden_dim
        self.context_embs = (np.random.random((self.vocab_size, hidden_dim)) - 0.5) / hidden_dim

    def nearest_neighbors(self, query_vector, vectors, n=10):
        """Finds the `n` indices of the rows in `vectors` that have the highest cosine similarity to `query_vector`.

        query_vector: (d,)-shaped numpy array
        vectors: (v, d)-shaped numpy array
        n: int

        Return a tuple of (indices, similarities), where both are (n,)-shaped ndarrays.
        """
        # Normalize query vector
        query_norm = query_vector / (np.linalg.norm(query_vector) + 1e-10)

        # Normalize all rows in vectors (matrix)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1e-10  # prevent division by zero
        normalized_vectors = vectors / norms

        # Cosine similarity = dot product of normalized vectors
        similarities = np.dot(normalized_vectors, query_norm)

        # Get indices of top n largest similarities (descending order)
        top_indices = np.argsort(similarities)[::-1][:n]
        top_similarities = similarities[top_indices]

        return top_indices, top_similarities

    def print_nearest_neighbors(self, word, n=5):
        """Prints the `n` nearest neighbors for a word using the context embeddings.

        word: str

        Return None
        """
        query_vector = self.context_embs[self.dataloader.vocab2idx[word]]
        closest_inds, similarities = self.nearest_neighbors(query_vector, self.context_embs, n)
        words = [self.dataloader.idx2vocab[ind] for ind in closest_inds]

        print(words)


In [10]:
w2v_model = Word2Vec(dataloader)

**Quick check**: you can check your function against this toy example. The output should be:

- `(array([4, 5, 0, 6, 3]), array([0.91347529, 0.87409283, 0.84518755, 0.83396453, 0.8111933 ]))`

In [11]:
def quick_check():
    np.random.seed(159259)
    query_vec = np.random.random(size=(5,))
    other_vecs = np.random.random(size=(10, 5))
    print(w2v_model.nearest_neighbors(query_vec, other_vecs, n=5))

quick_check()

(array([4, 5, 0, 6, 3]), array([0.91347529, 0.87409283, 0.84518755, 0.83396453, 0.8111933 ]))


**Quick check**: the nearest neighbors for "the" should be random at this point; if you did not edit the `__init__` function, the nearest neighbors should be:

- `['the', 'asian', 'habilitation', 'toward', 'capacity-building']`

In [12]:
w2v_model.print_nearest_neighbors("the")

['the', 'asian', 'habilitation', 'toward', 'capacity-building']


## Setting up the training loop

### Calculating gradients

To update the weights using gradient descent, we have to find the partial derivatives of the loss with respect to the parameters. You can find the loss function and its partial derivatives in SLP 5.5.2 (eqs. 5.22 - 5.24); we've also reproduced them for you below. While we give you the derivatives, it can be a good exercise to try to derive them yourself!
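
For reference, here is the loss for a single target word $w$ with positive context vectors $c_{\text{pos}_j}$ and negative (noise) context vectors $c_{\text{neg}_i}$, together with its partial derivatives. This is a restatement following SLP 5.5.2, generalized to several positive contexts so that it matches the `(n_pos, d)` shapes described below; please check it against the textbook's single-positive form.

$$
L = -\left[\sum_{j=1}^{n_{\text{pos}}} \log \sigma(c_{\text{pos}_j} \cdot w) + \sum_{i=1}^{n_{\text{neg}}} \log \sigma(-c_{\text{neg}_i} \cdot w)\right]
$$
$$
\frac{\partial L}{\partial c_{\text{pos}_j}} = \left[\sigma(c_{\text{pos}_j} \cdot w) - 1\right] w
$$
$$
\frac{\partial L}{\partial c_{\text{neg}_i}} = \left[\sigma(c_{\text{neg}_i} \cdot w)\right] w
$$
$$
\frac{\partial L}{\partial w} = \sum_{j=1}^{n_{\text{pos}}} \left[\sigma(c_{\text{pos}_j} \cdot w) - 1\right] c_{\text{pos}_j} + \sum_{i=1}^{n_{\text{neg}}} \left[\sigma(c_{\text{neg}_i} \cdot w)\right] c_{\text{neg}_i}
$$

Since $\sigma(-x) = 1 - \sigma(x)$, each negative term in $L$ can equivalently be written as $\log\left(1 - \sigma(c_{\text{neg}_i} \cdot w)\right)$.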

These rely on the sigmoid function, which we've implemented for you as an example.

You should implement:
- `loss_fn`
- `c_pos_grad`
- `c_neg_grad`
- `w_grad`

In each of these functions, you should expect:
- `w` to be a `d`-dimensional vector,
- `c_pos` to be a `(n_pos, d)`-dimensional matrix (where `n_pos` is the number of positive context examples)
- `c_neg` to be a `(n_neg, d)`-dimensional matrix (where `n_neg` is the number of negative context examples)

As a reminder, the sigmoid function is defined as
$$
\sigma(x) = \frac{1}{1 + e^{-x}}
$$

For filling out the rest of the functions, you may want to use [`np.log`](https://numpy.org/devdocs/reference/generated/numpy.log.html#numpy.log), [`np.sum`](https://numpy.org/devdocs/reference/generated/numpy.sum.html), [`np.newaxis`](https://numpy.org/devdocs/reference/constants.html#numpy.newaxis), [`np.matmul`](https://numpy.org/devdocs/reference/generated/numpy.matmul.html#numpy-matmul), and of course, the `sigmoid` function that we have implemented for you.
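
The two shape tricks these functions need can be tried on toy arrays first (the values here are arbitrary). Indexing a length-$n$ vector with `[:, np.newaxis]` makes it an $(n, 1)$ column, so multiplying it by a $(d,)$ vector broadcasts to an $(n, d)$ matrix whose row $i$ is $\delta_i w$; and `C.T @ delta` computes $\sum_i \delta_i c_i$, a weighted sum of the rows of $C$.

```python
import numpy as np

w = np.array([1.0, 2.0, 3.0])        # (d,) with d = 3
C = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 1.0]])      # (n, d) with n = 2
delta = np.array([0.5, -1.0])        # (n,), one scalar per row of C

# (n, 1) * (d,) broadcasts to (n, d): row i is delta[i] * w
per_row = delta[:, np.newaxis] * w
print(per_row)                       # rows: [0.5, 1, 1.5] and [-1, -2, -3]

# (d, n) @ (n,) -> (d,): the sum over i of delta[i] * C[i]
weighted_sum = C.T @ delta
print(weighted_sum)                  # values: [0.5, -1, -1]
```

The first pattern is an outer product (equivalent to `np.outer(delta, w)`), which is the shape a per-context gradient takes; the second collapses per-context scalars into a single $d$-dimensional vector, which is the shape of a gradient with respect to $w$.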

In [13]:
# we wrap these functions in the @njit decorator to speed up calculations
# using just-in-time compilation
# you don't have to worry about this
from numba import njit

@njit
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [14]:
@njit
def loss_fn(w, c_pos, c_neg):
    pos_scores = c_pos @ w
    neg_scores = c_neg @ w

    # small epsilon prevents log(0) when sigmoid ≈ 1
    pos_loss = -np.log(sigmoid(pos_scores) + 1e-12)
    neg_loss = -np.log(1.0 - sigmoid(neg_scores) + 1e-12)

    return np.sum(pos_loss) + np.sum(neg_loss)

In [15]:
@njit
def c_pos_grad(w, c_pos):
    # Gradient w.r.t. positive context embeddings
    # Shape: (n_pos, d)
    scores = c_pos @ w                  # (n_pos,)
    sig = sigmoid(scores)
    delta = sig - 1.0                   # (n_pos,)

    return delta[:, np.newaxis] * w     # (n_pos, d)

In [16]:
@njit
def c_neg_grad(w, c_neg):
    # Gradient w.r.t. negative context embeddings
    # Shape: (n_neg, d)
    scores = c_neg @ w
    sig = sigmoid(scores)

    return sig[:, np.newaxis] * w       # (n_neg, d)

In [17]:
@njit
def w_grad(w, c_pos, c_neg):
    pos_scores = c_pos @ w
    neg_scores = c_neg @ w

    pos_sig = sigmoid(pos_scores)
    neg_sig = sigmoid(neg_scores)

    grad_pos = c_pos.T @ (pos_sig - 1.0)   # (d,)
    grad_neg = c_neg.T @ neg_sig           # (d,)

    return grad_pos + grad_neg

**(Not so) Quick check**: We can check the correctness of the loss function and gradient calculations by numerically approximating the gradients using neighboring points and seeing if they match up. Recall from your calculus class:

$$
\frac{d}{dx} f(x) = \lim_{h \to 0} \frac{f(x + h) - f(x - h)}{2h}
$$

We implement this in the `approximate_gradient` function so that we can estimate the local gradient and check whether the closed-form gradients you implemented in the functions above are accurate. However, we never numerically approximate the gradient during training, because the closed-form solution is both more accurate and more efficient to calculate.
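
As a one-dimensional illustration of the same idea, assuming the scalar function $f(x) = -\log \sigma(x)$ (the loss of a single positive pair, viewed as a function of the score $x = c_{\text{pos}} \cdot w$), the central difference should match the closed-form derivative $f'(x) = \sigma(x) - 1$, the same factor that appears in the positive-context gradient. The function names here are hypothetical and chosen so they do not shadow the notebook's `sigmoid`.

```python
import numpy as np


def logistic(x):
    return 1 / (1 + np.exp(-x))


def pos_pair_loss(x):
    # loss of one positive (target, context) pair as a function of the score x
    return -np.log(logistic(x))


def central_difference(func, x, h=1e-5):
    # symmetric difference quotient: (f(x + h) - f(x - h)) / 2h
    return (func(x + h) - func(x - h)) / (2 * h)


for x in [-2.0, 0.0, 3.0]:
    approx = central_difference(pos_pair_loss, x)
    exact = logistic(x) - 1.0
    print(f"x={x:+.1f}  approx={approx:.8f}  exact={exact:.8f}")
```

The `approximate_gradient` helper in the quick check below applies exactly this quotient to one coordinate of a vector or matrix at a time.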

> **Aside**: In this assignment, we have you manually calculate the loss and gradients. If you have taken other deep learning classes, you may have experience with libraries like PyTorch, which implement automatic differentiation so that you can just specify the loss function and not have to work out the gradients manually.
>
> These libraries _don't_ use numerical approximation for the gradients. Instead, they rely on the chain rule:
>
> $$
> \frac{d}{dx} f(g(x)) = f'(g(x)) g'(x)
> $$
>
> As long as all of the functions you apply to an input are differentiable and their closed-form derivatives are known (which they often are, since most functions break down into basic differentiable operations like addition, multiplication, or exponentiation), the library can construct a graph that tracks every function applied to the input and calculate the partial derivatives using this graph.
>
> You can read more about this in the [PyTorch autograd tutorial](https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph).
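
To see the chain rule at work on our own loss, the sketch below (toy random vectors, a single positive pair, and no autodiff library) breaks $L = -\log \sigma(c \cdot w)$ into three steps, $s = c \cdot w$, $u = \sigma(s)$, and $L = -\log u$, multiplies the local derivatives together, and recovers the closed form $\left[\sigma(c \cdot w) - 1\right] c$. This is roughly what a reverse-mode autodiff library does for each node of its graph, although real libraries handle arbitrary graphs and many inputs.

```python
import numpy as np


def logistic(x):
    return 1 / (1 + np.exp(-x))


rng = np.random.default_rng(0)
w = rng.normal(size=4)    # toy target vector
c = rng.normal(size=4)    # toy positive context vector

# forward pass, keeping every intermediate value
s = c @ w                 # score
u = logistic(s)           # probability that (w, c) is a true pair
L = -np.log(u)            # loss

# backward pass: multiply local derivatives from the output back to w
dL_du = -1.0 / u          # d(-log u) / du
du_ds = u * (1.0 - u)     # sigma'(s) = sigma(s) (1 - sigma(s))
ds_dw = c                 # d(c . w) / dw
grad_chain = dL_du * du_ds * ds_dw

grad_closed_form = (logistic(s) - 1.0) * c
print(np.allclose(grad_chain, grad_closed_form))   # True
```

The product $-\frac{1}{u} \cdot u(1 - u)$ simplifies to $u - 1$, which is why the closed-form gradient is so compact.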

Your loss should be roughly 8.05; if it is not, all of the assertions in the `quick_check` will likely fail even if (especially if) your gradients are implemented correctly.

In [18]:
def quick_check():
    np.random.seed(159259)

    w = np.random.random((5,))
    c_pos = np.random.random((2, 5))
    c_neg = np.random.random((4, 5))

    eps = 1e-5

    def approximate_gradient(func, vec, eps=1e-5):
        est_grad = np.zeros(vec.shape)
        for ind, el in np.ndenumerate(vec):
            perturb = np.zeros(vec.shape)
            perturb[ind] = eps
            est_grad[ind] = (func(vec + perturb) - func(vec - perturb)) / (2 * eps)
        return est_grad

    print("loss:", loss_fn(w, c_pos, c_neg))

    assert np.allclose(w_grad(w, c_pos, c_neg), approximate_gradient(lambda x: loss_fn(x, c_pos, c_neg), w)), "w_grad is not correct for loss_fn"
    assert np.allclose(c_pos_grad(w, c_pos), approximate_gradient(lambda x: loss_fn(w, x, c_neg), c_pos)), "c_pos_grad is not correct for loss_fn"
    assert np.allclose(c_neg_grad(w, c_neg), approximate_gradient(lambda x: loss_fn(w, c_pos, x), c_neg)), "c_neg_grad is not correct for loss_fn"

quick_check()

loss: 8.052986383589216


### Updating weights in the training loop

The training loop for SGD consists of sampling one instance of the data (in our case, a target word and its positive and negative contexts), and calculating the partial derivatives of the loss.

We then update the parameters using these partial derivatives, multiplying each gradient by the learning rate $\eta$. When we perform gradient descent, we subtract the gradients from the weights in order to shift the weights in a direction that decreases the loss (locally, at least). Here are the updates we make:
$$
c_{\text{pos}}^{t + 1} = c_{\text{pos}}^{t} - \eta \frac{\partial L}{\partial c_{\text{pos}}^{t}},
$$
$$
c_{\text{neg}}^{t + 1} = c_{\text{neg}}^{t} - \eta \frac{\partial L}{\partial c_{\text{neg}}^{t}},
$$
$$
w^{t + 1} = w^{t} - \eta \frac{\partial L}{\partial w^{t}},
$$
where $t + 1$ is the next timestep in the stochastic gradient descent loop.
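
A minimal sketch of a single update on toy data, assuming a tiny vocabulary of 5 words with random 3-dimensional embeddings and the negative-sampling gradients restated in plain NumPy. The helper names `sgns_loss` and `sgns_grads` are hypothetical, chosen so they do not shadow the notebook's `loss_fn` and gradient functions. With a small learning rate, one step along the negative gradient should lower the loss on that same training instance.

```python
import numpy as np


def logistic(x):
    return 1 / (1 + np.exp(-x))


def sgns_loss(w, c_pos, c_neg):
    # -[sum_j log sigma(c_pos_j . w) + sum_i log sigma(-c_neg_i . w)]
    return -np.sum(np.log(logistic(c_pos @ w))) - np.sum(np.log(logistic(-(c_neg @ w))))


def sgns_grads(w, c_pos, c_neg):
    pos_err = logistic(c_pos @ w) - 1.0              # (n_pos,)
    neg_err = logistic(c_neg @ w)                    # (n_neg,)
    grad_w = c_pos.T @ pos_err + c_neg.T @ neg_err   # (d,)
    grad_pos = pos_err[:, np.newaxis] * w            # (n_pos, d)
    grad_neg = neg_err[:, np.newaxis] * w            # (n_neg, d)
    return grad_w, grad_pos, grad_neg


rng = np.random.default_rng(0)
target_embs = rng.normal(size=(5, 3))
context_embs = rng.normal(size=(5, 3))
target, pos, neg = 0, np.array([1, 2]), np.array([3, 4])
eta = 0.01

before = sgns_loss(target_embs[target], context_embs[pos], context_embs[neg])

# all three gradients are computed from the parameters at step t ...
grad_w, grad_pos, grad_neg = sgns_grads(target_embs[target], context_embs[pos], context_embs[neg])

# ... and only then applied, giving the parameters at step t + 1
target_embs[target] -= eta * grad_w
np.subtract.at(context_embs, pos, eta * grad_pos)
np.subtract.at(context_embs, neg, eta * grad_neg)

after = sgns_loss(target_embs[target], context_embs[pos], context_embs[neg])
print(before, after)   # `after` should be slightly smaller than `before`
```

Computing every gradient before applying any of them matches the update equations above, where all three right-hand sides use step-$t$ values; updating $w$ first and then reusing it for the context gradients would be a slightly different algorithm.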

**Note**: We print some diagnostic information, including the loss, to help you monitor the training. You should convince yourself that, though we calculate the loss and print it here to track our training, SGD doesn't actually require that we compute the loss as such; we really only need the gradients.

You should implement:
- the section of the code where you calculate the gradients
- the section of the code where you use the gradients to update the embedding

You may want to read about [numpy indexing](https://numpy.org/doc/2.2/user/basics.indexing.html#), since `.sample_contexts()` yields arrays of indices; you might also want to look into [`np.subtract.at()`](https://numpy.org/doc/2.2/reference/generated/numpy.ufunc.at.html) (see the usage of `np.add.at()` in the starter code as another example).
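
The reason to reach for `np.subtract.at()` is that a context window can contain the same word more than once (for example, "the" twice), and negative samples are drawn with replacement. With plain fancy indexing, `a[idx] -= x` is buffered, so a repeated index receives only one of its updates, while `np.subtract.at` applies every update. A toy demonstration (values arbitrary):

```python
import numpy as np

idx = np.array([0, 0, 1])            # index 0 appears twice
updates = np.array([1.0, 1.0, 1.0])

buffered = np.zeros(3)
buffered[idx] -= updates             # repeated index 0 is only decremented once
print(buffered)                      # values: [-1, -1, 0]

unbuffered = np.zeros(3)
np.subtract.at(unbuffered, idx, updates)   # every occurrence is applied
print(unbuffered)                    # values: [-2, -1, 0]
```

The same behavior holds row-wise for the 2-D `context_embs` matrix, which is why the context updates need `np.subtract.at`, while a single `-=` suffices for the one target row.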

With a learning rate of 0.01, you should see some nearest neighbors start to make sense once the average loss drops below about 60. This took around 60K steps and 1m21s with our solution code; we recommend running for at least 10 minutes.

_Learning objectives_:
> - Gain familiarity with training a classifier using stochastic gradient descent.


In [19]:
NUM_EPOCHS = 1
LEARNING_RATE = 0.01

def train(model, dataloader):

    num_target_updates = np.zeros((model.target_embs.shape[0],))
    num_context_updates = np.zeros((model.context_embs.shape[0],))

    def print_diagnostic(word):
        print(f"`{word}` was updated {int(num_target_updates[dataloader.vocab2idx[word]])} times in target and {int(num_context_updates[dataloader.vocab2idx[word]])} times in context")
        model.print_nearest_neighbors(word, 4)

    for epoch in range(NUM_EPOCHS):
        losses = []
        for i, (target, pos, neg) in enumerate(tqdm(dataloader.sample_contexts(window_size=2, sample_k=100))):

            if i % 10_000 == 0:
                # Print diagnostic info every 10_000 steps.
                print("avg loss:", sum(losses) / len(losses) if losses else "")
                losses = []
                print_diagnostic("he")
                print_diagnostic("original")
                print_diagnostic("january")

            # Get the vectors from the model
            w = model.target_embs[target]
            c_pos = model.context_embs[pos]
            c_neg = model.context_embs[neg]

            # Calculate and store the loss
            losses.append(loss_fn(w, c_pos, c_neg))

            # TODO: Calculate the gradients and implement the gradient update.
            # Calculate gradients
            grad_w = w_grad(w, c_pos, c_neg)          # (d,)
            grad_c_pos = c_pos_grad(w, c_pos)         # (n_pos, d)
            grad_c_neg = c_neg_grad(w, c_neg)         # (n_neg, d)

            # Update target embedding
            model.target_embs[target] -= LEARNING_RATE * grad_w

            # Update positive context embeddings
            np.subtract.at(model.context_embs, pos, LEARNING_RATE * grad_c_pos)

            # Update negative context embeddings
            np.subtract.at(model.context_embs, neg, LEARNING_RATE * grad_c_neg)

            # Tally up how many times each word has been seen, just for fun.
            np.add.at(num_target_updates, target, 1)
            np.add.at(num_context_updates, pos, 1)
            np.add.at(num_context_updates, neg, 1)

w2v_model = Word2Vec(dataloader)
model = w2v_model
train(w2v_model, dataloader)

1it [00:00,  5.08it/s]

avg loss: 
`he` was updated 0 times in target and 0 times in context
['he', 'transponders', 'dusty', 'lse']
`original` was updated 0 times in target and 0 times in context
['original', 'quivering', 'saltire', 'bae']
`january` was updated 0 times in target and 0 times in context
['january', 'dailey', 'neutron', 'apogee']


10025it [00:19, 423.03it/s]

avg loss: 168.7864973860151
`he` was updated 62 times in target and 10654 times in context
['he', 'with', 'is', 'at']
`original` was updated 0 times in target and 1042 times in context
['original', 'over', 'can', 'no']
`january` was updated 5 times in target and 1656 times in context
['january', ';', 'not', 'two']


20001it [00:38, 396.95it/s]

avg loss: 99.46171313075446
`he` was updated 115 times in target and 21319 times in context
['he', 'his', 'is', 'for']
`original` was updated 3 times in target and 2099 times in context
['original', 'under', 'i', 'them']
`january` was updated 9 times in target and 3354 times in context
['january', 'house', '%', 'through']


30008it [00:57, 233.13it/s]

avg loss: 80.74189581539237
`he` was updated 167 times in target and 31915 times in context
['he', 'is', 'that', 'has']
`original` was updated 4 times in target and 3109 times in context
['original', 'great', 'film', ']']
`january` was updated 18 times in target and 4958 times in context
['january', '2010', 'there', 'war']


40003it [01:15, 444.42it/s]

avg loss: 66.31472379247688
`he` was updated 240 times in target and 42555 times in context
['he', 'his', 'which', '%']
`original` was updated 7 times in target and 4142 times in context
['original', 'final', 'black', 'language']
`january` was updated 25 times in target and 6531 times in context
['january', 'may', 'but', 'which']


50022it [01:33, 429.36it/s]

avg loss: 63.57901774861095
`he` was updated 295 times in target and 53255 times in context
['he', 'was', 'his', 'its']
`original` was updated 10 times in target and 5189 times in context
['original', 'role', 'society', 'east']
`january` was updated 29 times in target and 8154 times in context
['january', 'november', '2019', 'even']


60004it [01:53, 436.58it/s]

avg loss: 59.129581633381804
`he` was updated 358 times in target and 63866 times in context
['he', 'it', 'his', 'was']
`original` was updated 11 times in target and 6230 times in context
['original', 'central', 'society', 'process']
`january` was updated 32 times in target and 9829 times in context
['january', 'december', 'april', '2021']


69976it [02:11, 616.25it/s]

avg loss: 53.30099826618705
`he` was updated 423 times in target and 74495 times in context
['he', 'it', 'which', 'this']
`original` was updated 13 times in target and 7246 times in context
['original', 'central', 'football', 'study']
`january` was updated 38 times in target and 11503 times in context
['january', 'april', 'may', 'september']


80002it [02:30, 223.17it/s]

avg loss: 49.564571632336936
`he` was updated 472 times in target and 84873 times in context
['he', 'it', 'also', 'this']
`original` was updated 15 times in target and 8362 times in context
['original', 'division', 'act', 'make']
`january` was updated 43 times in target and 13165 times in context
['january', 'later', 'years', 'these']


90023it [02:49, 414.15it/s]

avg loss: 45.449508167910146
`he` was updated 524 times in target and 95370 times in context
['he', 'it', 'she', 'this']
`original` was updated 22 times in target and 9377 times in context
['original', 'royal', 'population', 'order']
`january` was updated 49 times in target and 14790 times in context
['january', 'december', 'october', 'july']


99999it [03:07, 608.21it/s]

avg loss: 45.56065816814211
`he` was updated 581 times in target and 105966 times in context
['he', 'it', 'this', 'she']
`original` was updated 26 times in target and 10471 times in context
['original', 'end', 'army', 'most']
`january` was updated 52 times in target and 16370 times in context
['january', 'september', 'november', 'february']


110000it [03:26, 600.41it/s]

avg loss: 44.76511420894129
`he` was updated 617 times in target and 116489 times in context
['he', 'she', 'it', 'also']
`original` was updated 31 times in target and 11506 times in context
['original', 'support', 'society', 'head']
`january` was updated 59 times in target and 17951 times in context
['january', 'december', 'may', 'november']


119997it [03:44, 607.85it/s]

avg loss: 43.18814245860986
`he` was updated 682 times in target and 127176 times in context
['he', 'it', 'they', 'also']
`original` was updated 34 times in target and 12551 times in context
['original', 'role', 'end', 'construction']
`january` was updated 61 times in target and 19555 times in context
['january', 'october', 'april', 'while']


129996it [04:04, 419.62it/s]

avg loss: 40.15894477202373
`he` was updated 747 times in target and 137792 times in context
['he', 'it', 'was', 'she']
`original` was updated 36 times in target and 13617 times in context
['original', 'side', 'east', 'head']
`january` was updated 70 times in target and 21162 times in context
['january', 'september', 'november', 'december']


139987it [04:22, 606.45it/s]

avg loss: 39.36942717098174
`he` was updated 805 times in target and 148711 times in context
['he', 'it', 'they', 'she']
`original` was updated 38 times in target and 14652 times in context
['original', 'front', 'goal', 'marriage']
`january` was updated 71 times in target and 22787 times in context
['january', 'september', 'october', 'november']


149993it [04:40, 614.07it/s]

avg loss: 38.31133240433373
`he` was updated 852 times in target and 159185 times in context
['he', 'it', 'also', 'was']
`original` was updated 40 times in target and 15678 times in context
['original', 'side', 'role', 'front']
`january` was updated 79 times in target and 24456 times in context
['january', 'december', 'september', 'april']


159998it [05:00, 606.35it/s]

avg loss: 38.095661517521336
`he` was updated 910 times in target and 169813 times in context
['he', 'she', 'it', 'also']
`original` was updated 42 times in target and 16779 times in context
['original', 'royal', 'role', 'start']
`january` was updated 87 times in target and 26137 times in context
['january', 'september', 'december', 'october']


169977it [05:18, 622.39it/s]

avg loss: 38.328657117761914
`he` was updated 969 times in target and 180404 times in context
['he', 'it', 'she', 'also']
`original` was updated 45 times in target and 17785 times in context
['original', 'latter', 'role', 'division']
`january` was updated 95 times in target and 27685 times in context
['january', 'august', 'december', 'october']


179991it [05:38, 565.45it/s]

avg loss: 36.58525935536353
`he` was updated 1006 times in target and 190863 times in context
['he', 'it', 'she', 'they']
`original` was updated 45 times in target and 18816 times in context
['original', 'southern', 'finish', 'next']
`january` was updated 103 times in target and 29265 times in context
['january', 'november', 'august', 'july']


189969it [05:55, 619.03it/s]

avg loss: 36.553829205456815
`he` was updated 1057 times in target and 201365 times in context
['he', 'it', 'she', 'they']
`original` was updated 46 times in target and 19872 times in context
['original', 'society', 'commission', 'administrative']
`january` was updated 105 times in target and 30867 times in context
['january', 'december', 'october', 'september']


199961it [06:14, 428.39it/s]

avg loss: 35.4696066829595
`he` was updated 1107 times in target and 211747 times in context
['he', 'it', 'she', 'they']
`original` was updated 46 times in target and 20937 times in context
['original', 'primary', 'police', 'administration']
`january` was updated 110 times in target and 32460 times in context
['january', 'september', 'december', 'october']


210023it [06:33, 432.70it/s]

avg loss: 34.65169144193829
`he` was updated 1183 times in target and 222411 times in context
['he', 'it', 'she', 'after']
`original` was updated 47 times in target and 21975 times in context
['original', 'tour', 'police', 'primary']
`january` was updated 118 times in target and 34084 times in context
['january', 'september', 'december', 'october']


220009it [06:51, 442.65it/s]

avg loss: 34.42011707709473
`he` was updated 1229 times in target and 232747 times in context
['he', 'it', 'she', 'after']
`original` was updated 52 times in target and 23032 times in context
['original', 'crew', 'police', 'band']
`january` was updated 121 times in target and 35710 times in context
['january', 'september', 'december', 'october']


229971it [07:11, 602.66it/s]

avg loss: 33.491829121224725
`he` was updated 1290 times in target and 243328 times in context
['he', 'it', 'she', 'after']
`original` was updated 58 times in target and 24122 times in context
['original', 'hospital', 'band', 'fourth']
`january` was updated 127 times in target and 37316 times in context
['january', 'september', 'december', 'october']


240001it [07:29, 447.11it/s]

avg loss: 34.30835913276074
`he` was updated 1352 times in target and 253904 times in context
['he', 'it', 'she', 'there']
`original` was updated 59 times in target and 25143 times in context
['original', 'process', 'fourth', 'match']
`january` was updated 131 times in target and 38882 times in context
['january', 'september', 'december', 'november']


249987it [07:47, 340.65it/s]

avg loss: 32.52467289011175
`he` was updated 1404 times in target and 264535 times in context
['he', 'it', 'she', 'they']
`original` was updated 62 times in target and 26207 times in context
['original', 'police', 'fourth', 'straight']
`january` was updated 133 times in target and 40482 times in context
['january', 'december', 'april', 'march']


260003it [08:07, 449.76it/s]

avg loss: 33.53477252233783
`he` was updated 1466 times in target and 275224 times in context
['he', 'she', 'it', 'they']
`original` was updated 68 times in target and 27254 times in context
['original', 'previous', 'construction', 'country']
`january` was updated 137 times in target and 42122 times in context
['january', 'september', 'july', 'december']


269978it [08:24, 617.63it/s]

avg loss: 31.533820095626037
`he` was updated 1522 times in target and 285724 times in context
['he', 'she', 'it', 'they']
`original` was updated 71 times in target and 28343 times in context
['original', 'largest', 'construction', 'laws']
`january` was updated 140 times in target and 43698 times in context
['january', 'september', 'december', 'august']


279978it [08:44, 619.89it/s]

avg loss: 31.996058609043963
`he` was updated 1564 times in target and 296331 times in context
['he', 'it', 'she', 'they']
`original` was updated 72 times in target and 29390 times in context
['original', 'british', 'construction', 'catholic']
`january` was updated 143 times in target and 45331 times in context
['january', 'december', 'november', 'july']


289989it [09:02, 621.63it/s]

avg loss: 31.57539971674375
`he` was updated 1616 times in target and 307198 times in context
['he', 'she', 'it', 'they']
`original` was updated 76 times in target and 30389 times in context
['original', 'turn', 'countries', 'western']
`january` was updated 148 times in target and 46976 times in context
['january', 'july', 'august', 'february']


300001it [09:24, 154.10it/s]

avg loss: 30.71744563332779
`he` was updated 1687 times in target and 317862 times in context
['he', 'she', 'it', 'they']
`original` was updated 77 times in target and 31462 times in context
['original', 'construction', 'administration', '1990s']
`january` was updated 151 times in target and 48556 times in context
['january', 'july', 'december', 'february']


309987it [09:43, 614.61it/s]

avg loss: 30.65385732902176
`he` was updated 1747 times in target and 328386 times in context
['he', 'she', 'it', 'they']
`original` was updated 82 times in target and 32524 times in context
['original', 'fourth', 'primary', 'largest']
`january` was updated 156 times in target and 50163 times in context
['january', 'august', 'november', 'december']


319996it [10:02, 331.78it/s]

avg loss: 29.505529299014878
`he` was updated 1808 times in target and 338859 times in context
['he', 'she', 'it', 'they']
`original` was updated 85 times in target and 33617 times in context
['original', 'fourth', 'largest', 'pacific']
`january` was updated 164 times in target and 51847 times in context
['january', 'december', 'february', 'july']


330016it [10:21, 426.12it/s]

avg loss: 30.7709866916364
`he` was updated 1841 times in target and 349312 times in context
['he', 'she', 'it', 'they']
`original` was updated 87 times in target and 34589 times in context
['original', 'fourth', 'japanese', 'largest']
`january` was updated 166 times in target and 53443 times in context
['january', 'november', 'december', 'february']


339994it [10:39, 615.61it/s]

avg loss: 29.329959593719785
`he` was updated 1920 times in target and 360077 times in context
['he', 'she', 'they', 'it']
`original` was updated 91 times in target and 35644 times in context
['original', 'parliament', 'sea', 'nations']
`january` was updated 169 times in target and 55032 times in context
['january', 'august', 'november', 'july']


350016it [10:59, 422.13it/s]

avg loss: 29.73563656224657
`he` was updated 1972 times in target and 370536 times in context
['he', 'she', 'it', 'they']
`original` was updated 93 times in target and 36651 times in context
['original', 'security', 'headquarters', 'fourth']
`january` was updated 173 times in target and 56702 times in context
['january', 'december', 'november', 'august']


359980it [11:17, 619.30it/s]

avg loss: 29.3259635888434
`he` was updated 2037 times in target and 381336 times in context
['he', 'she', 'it', 'they']
`original` was updated 94 times in target and 37693 times in context
['original', 'security', 'largest', 'sea']
`january` was updated 177 times in target and 58330 times in context
['january', 'december', 'november', 'august']


369988it [11:36, 567.18it/s]

avg loss: 28.879276263005707
`he` was updated 2111 times in target and 391925 times in context
['he', 'she', 'they', 'it']
`original` was updated 95 times in target and 38717 times in context
['original', 'security', 'current', 'primary']
`january` was updated 183 times in target and 60011 times in context
['january', 'december', 'july', 'april']


379980it [11:54, 617.08it/s]

avg loss: 30.0203139680596
`he` was updated 2156 times in target and 402550 times in context
['he', 'she', 'they', 'it']
`original` was updated 98 times in target and 39801 times in context
['original', 'sea', 'current', 'assembly']
`january` was updated 185 times in target and 61598 times in context
['january', 'november', 'july', 'september']


389956it [12:13, 485.44it/s]

avg loss: 28.45990812292521
`he` was updated 2224 times in target and 413177 times in context
['he', 'she', 'they', 'it']
`original` was updated 99 times in target and 40871 times in context
['original', 'official', 'construction', 'office']
`january` was updated 191 times in target and 63241 times in context
['january', 'december', 'july', 'november']


399982it [12:32, 613.85it/s]

avg loss: 29.050891767237353
`he` was updated 2277 times in target and 423935 times in context
['he', 'she', 'they', 'it']
`original` was updated 102 times in target and 41939 times in context
['original', 'official', 'sea', 'current']
`january` was updated 198 times in target and 64831 times in context
['january', 'december', 'september', 'august']


410005it [12:50, 439.96it/s]

avg loss: 28.58369967786614
`he` was updated 2341 times in target and 434423 times in context
['he', 'she', 'they', 'it']
`original` was updated 105 times in target and 42970 times in context
['original', 'current', 'official', 'round']
`january` was updated 201 times in target and 66471 times in context
['january', 'november', 'december', 'september']


420028it [13:10, 420.69it/s]

avg loss: 28.25948376480802
`he` was updated 2396 times in target and 445081 times in context
['he', 'she', 'they', 'it']
`original` was updated 110 times in target and 44081 times in context
['original', 'official', 'current', 'head']
`january` was updated 204 times in target and 68132 times in context
['january', 'december', 'july', 'february']


430011it [13:28, 429.08it/s]

avg loss: 29.123622156081314
`he` was updated 2450 times in target and 455525 times in context
['he', 'she', 'it', 'they']
`original` was updated 117 times in target and 45138 times in context
['original', 'current', 'sea', 'side']
`january` was updated 208 times in target and 69777 times in context
['january', 'december', 'march', 'july']


439991it [13:46, 366.88it/s]

avg loss: 27.66896580756948
`he` was updated 2516 times in target and 466165 times in context
['he', 'she', 'it', 'they']
`original` was updated 118 times in target and 46141 times in context
['original', 'council', 'official', 'largest']
`january` was updated 210 times in target and 71367 times in context
['january', 'december', 'november', 'april']


450015it [14:06, 422.89it/s]

avg loss: 28.035358396950482
`he` was updated 2565 times in target and 476758 times in context
['he', 'she', 'it', 'they']
`original` was updated 120 times in target and 47179 times in context
['original', 'current', 'official', 'entire']
`january` was updated 214 times in target and 72959 times in context
['january', 'december', 'august', 'november']


460000it [14:24, 604.99it/s]

avg loss: 28.1282208889428
`he` was updated 2624 times in target and 487344 times in context
['he', 'she', 'it', 'they']
`original` was updated 122 times in target and 48195 times in context
['original', 'official', 'entire', 'security']
`january` was updated 217 times in target and 74602 times in context
['january', 'december', 'november', 'march']


469974it [14:43, 609.33it/s]

avg loss: 27.460762006894534
`he` was updated 2665 times in target and 497842 times in context
['he', 'she', 'it', 'they']
`original` was updated 125 times in target and 49220 times in context
['original', 'official', 'entire', 'largest']
`january` was updated 218 times in target and 76245 times in context
['january', 'september', 'march', 'june']


474117it [14:50, 532.33it/s]


KeyboardInterrupt: 

Once you are satisfied with the training (you can stop it whenever you want), experiment with printing out some nearest neighbors. Do these align with your expectations? Do any surprise you?

In [20]:
model.print_nearest_neighbors("paris", 4)

['paris', '1952', 'cash', '1979']


In [21]:
model.print_nearest_neighbors("january", 4)

['january', 'december', 'august', 'july']
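The `print_nearest_neighbors` method itself is not shown in this section. As a rough sketch, assuming neighbors are ranked by cosine similarity over a single embedding matrix with one row per vocabulary word, such a lookup could look like the following. The `nearest_neighbors` helper and the toy 2-d vectors are hypothetical, made up for illustration only.

```python
import numpy as np


def nearest_neighbors(query, embeddings, vocab, k=4):
    """Hypothetical helper: return the k words whose embedding rows have the
    highest cosine similarity to the row for `query`."""
    word_to_id = {w: i for i, w in enumerate(vocab)}
    # L2-normalize every row so that a plain dot product equals cosine similarity;
    # the floor on the norm avoids dividing by zero for an all-zero row
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-12)
    # Cosine similarity between the query row and every row in the vocabulary
    sims = unit @ unit[word_to_id[query]]
    # argsort is ascending, so sort the negated scores and keep the first k
    top = np.argsort(-sims)[:k]
    return [vocab[i] for i in top]


# Toy embeddings (made up): two "month" vectors and two "city" vectors
vocab = ["january", "february", "paris", "london"]
embeddings = np.array([[1.0, 0.1],
                       [0.9, 0.2],
                       [0.1, 1.0],
                       [0.2, 0.9]])
print(nearest_neighbors("january", embeddings, vocab, k=2))  # ['january', 'february']
```

Because every word has cosine similarity 1 with itself, the query word comes first in each list, as in the outputs above. Whether to search over the target embeddings, the context embeddings, or their sum is a design choice this sketch leaves to the caller.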


## Submission

Congratulations on finishing HW1!

Please submit a PDF of this notebook to [Gradescope](https://www.gradescope.com/courses/1238346) before February 3 at 11:59pm.

You can run the cell below to generate a PDF if you are using Google Colab.

In [None]:
#EXPORT_EXCLUDE#

#@markdown This is a helper function to generate a PDF in Colab.
#@markdown If you are using Jupyter notebook, you can do `File > Save and Export Notebook as HTML`, then save the resulting HTML file as a PDF.
#@markdown Alternatively, in Jupyter notebook, you might try `File > Save and Export Notebook as PDF`, but make sure you already have `pandoc` installed.

def colab_export_pdf():
    # Modified from: https://medium.com/@jonathanagustin/convert-colab-notebook-to-pdf-0ccd8f847dd6
    try:
        import google.colab
        IN_COLAB = True
    except ImportError:  # google.colab is only importable inside Colab
        IN_COLAB = False
        print("This cell only works in Google Colab!")
        print("If you are running locally, click File > Export as HTML. Then open the HTML file and save it as a PDF.")

    if IN_COLAB:
        print("Generating PDF. This may take a few seconds.")
        import os, datetime, json, locale, pathlib, urllib, requests, werkzeug, nbformat, google, yaml, warnings
        locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
        NAME = pathlib.Path(werkzeug.utils.secure_filename(urllib.parse.unquote(requests.get(f"http://{os.environ['COLAB_JUPYTER_IP']}:{os.environ['KMP_TARGET_PORT']}/api/sessions").json()[0]["name"])))
        TEMP = pathlib.Path("/content/pdfs") / f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}_{NAME.stem}"; TEMP.mkdir(parents=True, exist_ok=True)
        NB = [cell for cell in nbformat.reads(json.dumps(google.colab._message.blocking_request("get_ipynb", timeout_sec=30)["ipynb"]), as_version=4).cells if "--Colab2PDF" not in cell.source]
        warnings.filterwarnings('ignore', category=nbformat.validator.MissingIDFieldWarning)
        with (TEMP / f"{NAME.stem}.ipynb").open("w", encoding="utf-8") as nb_copy: nbformat.write(nbformat.v4.new_notebook(cells=NB or [nbformat.v4.new_code_cell("#")]), nb_copy)
        if not pathlib.Path("/usr/local/bin/quarto").exists():
            !wget -q "https://quarto.org/download/latest/quarto-linux-amd64.deb" -P {TEMP} && dpkg -i {TEMP}/quarto-linux-amd64.deb > /dev/null && quarto install tinytex --update-path --quiet
        with (TEMP / "config.yml").open("w", encoding="utf-8") as file: yaml.dump({'include-in-header': [{"text": r"\usepackage{fvextra}\DefineVerbatimEnvironment{Highlighting}{Verbatim}{breaksymbolleft={},showspaces=false,showtabs=false,breaklines,breakanywhere,commandchars=\\\{\}}"}],'include-before-body': [{"text": r"\DefineVerbatimEnvironment{verbatim}{Verbatim}{breaksymbolleft={},showspaces=false,showtabs=false,breaklines}"}]}, file)
        !quarto render {TEMP}/{NAME.stem}.ipynb --metadata-file={TEMP}/config.yml --to pdf -M latex-auto-install -M margin-top=1in -M margin-bottom=1in -M margin-left=1in -M margin-right=1in --quiet
        google.colab.files.download(str(TEMP / f"{NAME.stem}.pdf"))

colab_export_pdf()

Generating PDF. This may take a few seconds.
