# Student model training & retrieval

This notebook trains the `StudentModel` (text-only baseline) on `dataset.json`, implements a retrieval function that returns the closest words by cosine similarity given a definition, and computes recall@1/5/10 on a held-out test subset.

Notes:
- Make sure required packages are installed (`numpy`, `torch`, `transformers`, `tqdm`).
- Run the cells sequentially.

In [1]:
# Check environment and show install suggestions
try:
    import torch, numpy as np
    import transformers
    print('Found: torch, numpy, transformers')
except Exception as e:
    print('Missing packages. Install with:')
    print('  python3 -m pip install numpy torch transformers tqdm')
    raise


  from .autonotebook import tqdm as notebook_tqdm


Found: torch, numpy, transformers


In [2]:
# Imports and helpers
import torch
import numpy as np
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm.auto import tqdm

from student import (
    load_glove_embeddings,
    TextOnlyDataset,
    StudentModel,
    collate_examples,
    train_one_epoch,
    evaluate,
)

print('Imports OK')


Imports OK


In [3]:
# Config
DATASET = 'dataset.json'
GLOVE = './glove.6B.300d.txt'
BERT = 'bert-base-uncased'
BATCH = 16
EPOCHS = 20
LR = 2e-5
MAXLEN = 128
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print('Device:', DEVICE)


Device: mps


In [4]:
# Load GloVe and dataset
print('Loading GloVe...')
glove = load_glove_embeddings(GLOVE)
print('GloVe dim =', next(iter(glove.values())).shape[0])

print('Loading dataset...')
ds = TextOnlyDataset(DATASET, glove, tokenizer_name=BERT, max_length=MAXLEN)

n = len(ds)
print('Dataset size:', n)

# 90/10 split for train/test
train_n = int(0.9 * n)
indices = list(range(n))
train_idx = indices[:train_n]
test_idx = indices[train_n:]

train_ds = torch.utils.data.Subset(ds, train_idx)
test_ds = torch.utils.data.Subset(ds, test_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, collate_fn=lambda x: collate_examples(x))
val_loader = DataLoader(test_ds, batch_size=BATCH, shuffle=False, collate_fn=lambda x: collate_examples(x))

print('Train size:', len(train_ds), 'Test size:', len(test_ds))


Loading GloVe...
GloVe dim = 300
Loading dataset...
GloVe dim = 300
Loading dataset...
Dataset size: 971
Train size: 873 Test size: 98
Dataset size: 971
Train size: 873 Test size: 98


In [5]:
# Build model, optimizer, criterion
model = StudentModel(bert_model_name=BERT, target_dim=next(iter(glove.values())).shape[0])
model.to(DEVICE)
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LR)
criterion = torch.nn.MSELoss()

print('Model created; training will run for', EPOCHS, 'epochs')


Model created; training will run for 20 epochs


In [6]:
# Training loop
for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, DEVICE, criterion)
    val_loss = evaluate(model, val_loader, DEVICE, criterion)
    print(f"Epoch {epoch}: train_loss={train_loss:.6f} val_loss={val_loss:.6f}")

# Optionally save the model
model_path = Path('./ckpts/student_model.pt')
torch.save(model.state_dict(), model_path)
print('Saved model to', model_path)


Epoch 1: train_loss=0.145486 val_loss=0.136229
Epoch 2: train_loss=0.125500 val_loss=0.132963
Epoch 2: train_loss=0.125500 val_loss=0.132963
Epoch 3: train_loss=0.123059 val_loss=0.133009
Epoch 3: train_loss=0.123059 val_loss=0.133009
Epoch 4: train_loss=0.121099 val_loss=0.132300
Epoch 4: train_loss=0.121099 val_loss=0.132300
Epoch 5: train_loss=0.118541 val_loss=0.131882
Epoch 5: train_loss=0.118541 val_loss=0.131882
Epoch 6: train_loss=0.116033 val_loss=0.131826
Epoch 6: train_loss=0.116033 val_loss=0.131826
Epoch 7: train_loss=0.113363 val_loss=0.130568
Epoch 7: train_loss=0.113363 val_loss=0.130568
Epoch 8: train_loss=0.110353 val_loss=0.129420
Epoch 8: train_loss=0.110353 val_loss=0.129420
Epoch 9: train_loss=0.107267 val_loss=0.128645
Epoch 9: train_loss=0.107267 val_loss=0.128645
Epoch 10: train_loss=0.103967 val_loss=0.128849
Epoch 10: train_loss=0.103967 val_loss=0.128849
Epoch 11: train_loss=0.100713 val_loss=0.129115
Epoch 11: train_loss=0.100713 val_loss=0.129115
Epoch 12:

In [7]:
# Prepare database vectors and retrieval helpers
# db_words / db_vectors cover the entire dataset (used for retrieval)
db_words = [ex['word'] for ex in ds.examples]
db_vectors = np.stack([ex['vector'] for ex in ds.examples])  # shape (N, D)

# normalize db vectors (L2)
db_norms = np.linalg.norm(db_vectors, axis=1)
# avoid division by zero
db_norms[db_norms == 0] = 1.0


def retrieve_topk(definition, k=10, model=model, tokenizer=ds.tokenizer, maxlen=MAXLEN, device=DEVICE):
    # encode definition
    toks = tokenizer(definition, truncation=True, padding='max_length', max_length=maxlen, return_tensors='pt')
    input_ids = toks['input_ids'].to(device)
    attention_mask = toks['attention_mask'].to(device)
    model.eval()
    with torch.no_grad():
        q = model(input_ids=input_ids, attention_mask=attention_mask)
    q = q.cpu().numpy().reshape(-1)  # (D,)
    q_norm = np.linalg.norm(q)
    if q_norm == 0:
        q_norm = 1.0
    sims = (db_vectors @ q) / (db_norms * q_norm)
    idxs = np.argsort(-sims)[:k]
    return [(db_words[i], float(sims[i])) for i in idxs]

print('Retrieval helper ready')


Retrieval helper ready


In [9]:
# Compute recall@1,5,10 on the held-out test set
from collections import defaultdict
k_values = [1,5,10, 20]
correct = defaultdict(int)

for idx in tqdm(test_idx, desc='Evaluating recall'):
    ex = ds.examples[idx]
    gt_word = ex['word']
    topk = retrieve_topk(ex['definition'], k=max(k_values))
    retrieved = [w for w,_ in topk]
    for k in k_values:
        if gt_word in retrieved[:k]:
            correct[k] += 1

num_test = len(test_idx)
print('Num test examples:', num_test)
for k in k_values:
    print(f'Recall@{k}: {correct[k]}/{num_test} = {correct[k]/num_test:.4f}')


Evaluating recall: 100%|██████████| 98/98 [00:01<00:00, 56.67it/s]

Num test examples: 98
Recall@1: 6/98 = 0.0612
Recall@5: 15/98 = 0.1531
Recall@10: 26/98 = 0.2653
Recall@20: 36/98 = 0.3673





Demo retrievals

In [None]:
import random
import os
from pathlib import Path

from student import load_glove_embeddings, TextOnlyDataset, StudentModel
import torch
import numpy as np

GLOVE = './glove.6B.300d.txt'
DATASET = 'dataset.json'
MODEL_PATH = './ckpts/student_model.pt'

print('Loading GloVe...')
glove = load_glove_embeddings(GLOVE)
dim = next(iter(glove.values())).shape[0]
print('GloVe dim:', dim)

print('Loading dataset...')
ds = TextOnlyDataset(DATASET, glove, tokenizer_name='bert-base-uncased', max_length=128)
print('Dataset examples:', len(ds))

# build model
# model = StudentModel(bert_model_name='bert-base-uncased', target_dim=dim)
# prefer CPU for reproducibility here
device = torch.device('cpu')
model.to(device)

# try load weights if available
if Path(MODEL_PATH).exists():
    try:
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        print('Loaded model weights from', MODEL_PATH)
    except Exception as e:
        print('Failed to load model weights:', e)
else:
    print('No model weights found at', MODEL_PATH, '- using randomly initialized model')

# prepare DB vectors and mapping
db_words = [ex['word'] for ex in ds.examples]
db_defs = [ex['definition'] for ex in ds.examples]
db_vectors = np.stack([ex['vector'] for ex in ds.examples])
# normalize
db_norms = np.linalg.norm(db_vectors, axis=1)
db_norms[db_norms==0] = 1.0

# retrieval function
from student import TextOnlyDataset

def retrieve_topk(definition, k=5):
    tokenizer = ds.tokenizer
    toks = tokenizer(definition, truncation=True, padding='max_length', max_length=128, return_tensors='pt')
    input_ids = toks['input_ids'].to(device)
    attention_mask = toks['attention_mask'].to(device)
    model.eval()
    with torch.no_grad():
        q = model(input_ids=input_ids, attention_mask=attention_mask)
    q = q.cpu().numpy().reshape(-1)
    q_norm = np.linalg.norm(q)
    if q_norm == 0:
        q_norm = 1.0
    sims = (db_vectors @ q) / (db_norms * q_norm)
    idxs = np.argsort(-sims)[:k]
    return [(db_words[i], db_defs[i], float(sims[i])) for i in idxs]

In [15]:
def demo(idx):
    # pick a random definition
    query_def = ds.examples[idx]['definition']
    query_word = ds.examples[idx]['word']
    print('\nRandom query definition (from dataset index {}):\n"{}"\n'.format(rand_idx, query_def))
    print('Query word is "{}"\n'.format(query_word))

    # run retrieval
    topk = retrieve_topk(query_def, k=10)
    print('Top-10 retrieved:')
    for i,(w,d,s) in enumerate(topk,1):
        print(f"{i}. {w} (sim={s:.4f}) -- definition: {d}")

In [93]:
# pick a random definition
rand_idx = random.randrange(len(ds.examples))
demo(rand_idx)


Random query definition (from dataset index 503):
"a rugged box (usually made of wood); used for shipping"

Query word is "crate"

Top-10 retrieved:
1. crate (sim=0.6892) -- definition: a rugged box (usually made of wood); used for shipping
2. plastic_bag (sim=0.4852) -- definition: a bag made of thin plastic material
3. refrigerator (sim=0.4830) -- definition: white goods in which food can be stored at low temperatures
4. bucket (sim=0.4784) -- definition: a roughly cylindrical vessel that is open at the top
5. punching_bag (sim=0.4559) -- definition: an inflated ball or bag that is suspended and punched for training in boxing
6. carton (sim=0.4515) -- definition: a box made of cardboard; opens by flaps on top
7. sleeping_bag (sim=0.4310) -- definition: large padded bag designed to be slept in outdoors; usually rolls up like a bedroll
8. shovel (sim=0.4239) -- definition: a hand tool for lifting loose material; consists of a curved container or scoop and a handle
9. wooden_spoon (sim

In [18]:
# pick a random definition
rand_idx = random.randrange(len(ds.examples))
demo(rand_idx)


Random query definition (from dataset index 907):
"potato that has been peeled and boiled and then mashed"

Query word is "mashed_potato"

Top-10 retrieved:
1. zucchini (sim=0.4902) -- definition: small cucumber-shaped vegetable marrow; typically dark green
2. Crock_Pot (sim=0.4830) -- definition: an electric cooker that maintains a relatively low temperature
3. mashed_potato (sim=0.4733) -- definition: potato that has been peeled and boiled and then mashed
4. cauliflower (sim=0.4615) -- definition: compact head of white undeveloped flowers
5. ladle (sim=0.4563) -- definition: a spoon-shaped vessel with a long handle; frequently used to transfer liquids from one container to another
6. wok (sim=0.4535) -- definition: pan with a convex bottom; used for frying in Chinese cooking
7. cucumber (sim=0.4480) -- definition: cylindrical green fruit with thin green rind and white flesh eaten as a vegetable; related to melons
8. stove (sim=0.4414) -- definition: any heating apparatus
9. Petri_di

In [19]:
# pick a random definition
rand_idx = random.randrange(len(ds.examples))
demo(rand_idx)


Random query definition (from dataset index 174):
"English breed of strong stocky dog having a broad skull and smooth coat"

Query word is "Staffordshire_bullterrier"

Top-10 retrieved:
1. Yorkshire_terrier (sim=0.7629) -- definition: very small breed having a long glossy coat of bluish-grey and tan
2. American_Staffordshire_terrier (sim=0.7381) -- definition: American breed of muscular terriers with a short close-lying stiff coat
3. Norwich_terrier (sim=0.7319) -- definition: English breed of small short-legged terrier with a straight wiry red or grey or black-and-tan coat and erect ears
4. Irish_terrier (sim=0.7316) -- definition: medium-sized breed with a wiry brown coat; developed in Ireland
5. Norfolk_terrier (sim=0.7284) -- definition: English breed of small terrier with a straight wiry grizzled coat and dropped ears
6. Scotch_terrier (sim=0.6965) -- definition: old Scottish breed of small long-haired usually black terrier with erect tail and ears
7. soft-coated_wheaten_terrier 

In [27]:
# pick a random definition
rand_idx = random.randrange(len(ds.examples))
demo(rand_idx)


Random query definition (from dataset index 850):
"a keyboard for manually entering characters to be printed"

Query word is "typewriter_keyboard"

Top-10 retrieved:
1. typewriter_keyboard (sim=0.8040) -- definition: a keyboard for manually entering characters to be printed
2. computer_keyboard (sim=0.7213) -- definition: a keyboard that is a data input device for computers; arrangement of keys is modelled after the typewriter keyboard
3. accordion (sim=0.5960) -- definition: a portable box-shaped free-reed instrument; the reeds are made to vibrate by air from the bellows controlled by the player
4. acoustic_guitar (sim=0.5802) -- definition: sound is not amplified by electrical means
5. cello (sim=0.5788) -- definition: a large stringed instrument; seated player holds it upright while playing
6. violin (sim=0.5751) -- definition: bowed stringed instrument that is the highest member of the violin family; this instrument has four strings and a hollow body and an unfretted fingerboard a