# Lesson 6 (Bonus) â€” Bag of Embeddings (BoE) Classifier

In the previous notebook, we used a powerful LSTM model. However, sometimes a simple model is enough!

In this notebook, we will build a **Bag of Embeddings** model. It works like this:
1. **Lookup**: Get the word embedding for each word in the sentence.
2. **Aggregate**: Sum (or average) all the embeddings into one vector representing the whole sentence.
3. **Classify**: Pass that single vector through a linear layer to predict the class.

We will still use **pretrained embeddings** (from Wikitext-103) to ensure the model knows what words mean before we even start training.

In [None]:
!pip install fastai

In [1]:
from fastai.text.all import *
import pandas as pd

## 1. Load Data
We use the same AG News dataset.

In [2]:
path = untar_data(URLs.AG_NEWS)
df = pd.read_csv(path/'train.csv', header=None, names=['label', 'title', 'description'])
df['text'] = df['title'] + " " + df['description']

dls = TextDataLoaders.from_df(
    df, 
    text_col='text', 
    label_col='label', 
    valid_pct=0.2, 
    bs=64
)
dls.show_batch(max_n=3)

Unnamed: 0,text,category
0,"xxbos xxmaj kyoto is xxmaj dead - xxmaj long xxmaj live xxmaj xxunk xxmaj there 's troubling news ( ft subscription xxunk , alternate copy here ) coming from xxmaj japan , where the xxmaj kyoto protocol on xxmaj greenhouse xxmaj emissions was born in 1997 . xxmaj it seems that the xxmaj japanese are n't going to be able to meet their emissions targets specified in the agreement in time . xxmaj indeed , unless they buy a "" large quantity "" of emissions credits from other countries , they 're not going to be able to meet their commitment at all . xxmaj xxunk xxmaj sugiyama , a climate expert at the xxmaj central xxmaj research xxmaj institute of xxmaj electric xxmaj power xxmaj industry in xxmaj japan , said emissions were rising 1 per cent a year due to a larger - than - expected impact from",4
1,"xxbos xxmaj cape xxmaj clear boosts business processes in xxup esb xxmaj cape xxmaj clear xxmaj software this week is upgrading its xxup esb with the release of xxmaj cape xxmaj clear 6 , enabling development of business process xxunk based on xxunk > advertisement < / p><p><img src=""http : / / ad.doubleclick.net / ad / idg.us.ifw.general / solaris;sz=1x1;ord=200301151450 ? "" width=""1 "" height=""1 "" border=""0 "" / > < a href=""http : / / ad.doubleclick.net / clk;12204780;10550054;n?http : / / ad.doubleclick.net / clk;12165994;105 xxrep 3 2 95;g?http : / / xxrep 3 w .sun.com / solaris10"">solaris 10(tm ) xxup os : xxmaj position your business ten moves ahead . < / a><br / > solaris 10 xxup os has arrived and provides even more \ reasons for the world 's most demanding businesses \ to operate on this , the leading xxup unix platform . \ xxmaj like the",4
2,"xxbos xxmaj wireless xxmaj san xxmaj francisco \ \ xxmaj reuters has a story about xxmaj gavin xxmaj newsom finally getting wifi and wanting to \ hookup xxmaj san xxmaj francisco : \ "" san xxup francisco ( reuters ) - xxmaj san xxmaj francisco xxmaj mayor xxmaj gavin xxmaj newsom has set a goal of \ providing free wireless xxmaj internet activity in his city that sees itself as a \ vanguard of the xxmaj internet revolution . "" \ "" "" we will not stop until every xxmaj san xxmaj xxunk has access to free wireless \ xxmaj internet service , "" he said in his annual state of the city address on \ xxmaj thursday . "" these technologies will connect our residents to the skills and the \ jobs of the new economy . "" \ \ xxmaj the issue i have is that xxmaj i",4


## 2. Get Pretrained Embeddings (Word2Vec / GloVe)

Instead of using the AWD-LSTM embeddings, we can use standard **Word2Vec** or **GloVe** vectors. 
We will load them using `gensim`, create a matrix that maps our dataset's vocabulary to these pretrained vectors, and use that to initialize our model.

In [4]:
# Install gensim if needed
!pip install gensim

import gensim.downloader as api

# Load pretrained vectors (we use GloVe 100d as a lightweight proxy for Word2Vec)
# The official Word2Vec is very large (1.5GB+), so we use a smaller high-quality one here.
print("Loading pretrained vectors...")
word_vectors = api.load("glove-wiki-gigaword-100") 
print("Loaded!")

def create_emb_matrix(vocab, word_vectors, emb_dim):
    vocab_size = len(vocab)
    # Initialize with random weights (standard normal)
    matrix = torch.randn((vocab_size, emb_dim))
    
    hits = 0
    for i, word in enumerate(vocab):
        try:
            # Check if word exists in pretrained vectors
            if word in word_vectors:
                matrix[i] = torch.tensor(word_vectors[word])
                hits += 1
        except KeyError:
            pass
            
    print(f"Loaded {hits} / {vocab_size} words ({hits/vocab_size:.1%}) from pretrained embeddings.")
    return matrix

# Create the matrix matching our dataset's vocabulary
emb_dim = 100
pretrained_weights = create_emb_matrix(dls.vocab[0], word_vectors, emb_dim)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Loading pretrained vectors...
Loaded!
Loaded 34887 / 36760 words (94.9%) from pretrained embeddings.


## 3. Define the Bag of Embeddings Model

We define a custom PyTorch module that sums the embeddings and applies a linear classifier.

In [None]:
class BoEWrapper(nn.Module):
    def __init__(self, vocab_size, emb_dim, n_classes, pad_idx=1):
        super().__init__()
        # 1. Embedding Layer
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        # 2. Linear Layer
        self.linear = nn.Linear(emb_dim, n_classes)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len)
        with torch.no_grad():
            embeddings = self.emb(x) # (batch, seq, emb_dim)
        
        # Mask padding (we don't want to sum the padding tokens)
        # The padding token usually has index 1, but we look it up to be safe
        mask = (x != self.emb.padding_idx).unsqueeze(-1) # (batch, seq, 1)
        
        # Zero out padding embeddings
        embeddings = embeddings * mask.float()
        
        # 3. Mean along sequence dimension (dim=1)
        summed = embeddings.mean(dim=1) # (batch, emb_dim)
        
        # 4. Linear layer -> logits
        return self.linear(summed)

In [12]:
# Parameters
vocab_size = len(dls.vocab[0])
emb_dim = 100                  # We used GloVe 100d
n_classes = 4                  # AG News has 4 classes
pad_idx = dls.vocab[0].index('xxpad') # Find the padding index

# Initialize model
boe_model = BoEWrapper(vocab_size, emb_dim, n_classes, pad_idx=pad_idx)

# Load the pretrained weights we created
boe_model.emb.weight.data.copy_(pretrained_weights)
print("Model initialized with Word2Vec/GloVe weights!")

Model initialized with Word2Vec/GloVe weights!


## 4. Train

We wrap our custom model in a `Learner` and train. Since this is a simple linear layer on top of embeddings, it trains very fast!

In [13]:
# Create Learner
simple_learn = Learner(dls, boe_model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)

# Train
simple_learn.fit_one_cycle(5, 5e-3)

epoch,train_loss,valid_loss,accuracy,time
0,0.394418,0.410723,0.872,00:39
1,0.366155,0.36893,0.879708,00:30
2,0.346865,0.359232,0.884708,00:29
3,0.350602,0.357008,0.884667,00:29
4,0.348423,0.356335,0.8845,00:29


## 5. Test

In [14]:
topics = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}

def predict_simple(text):
    # 1. Tokenize
    tokenized = dls.tokenizer(text)
    # 2. Map to IDs
    ids = [dls.vocab[0].index(t) if t in dls.vocab[0] else 0 for t in tokenized]
    # 3. Create tensor batch
    tensor_in = torch.tensor(ids).unsqueeze(0).to(dls.device)
    
    # 4. Predict
    with torch.no_grad():
        logits = boe_model(tensor_in)
        probs = F.softmax(logits, dim=1)
        pred_idx = probs.argmax().item()
    
    # Output
    print(f"Text: '{text}'")
    print(f"Topic: {topics[pred_idx+1]} ({probs.max():.2f})")

predict_simple("Stock markets hit a new high today.")
predict_simple("The goalkeeper saved the penalty kick.")
predict_simple("Astronomers discovered a new black hole.")

Text: 'Stock markets hit a new high today.'
Topic: Business (0.99)
Text: 'The goalkeeper saved the penalty kick.'
Topic: Sports (1.00)
Text: 'Astronomers discovered a new black hole.'
Topic: Sci/Tech (1.00)
