In [25]:
%pip install --quiet datasets torch numpy scikit-learn wandb tqdm nbformat

Note: you may need to restart the kernel to use updated packages.


In [15]:
!wget --no-check-certificate http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip

--2025-06-13 12:02:20--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-06-13 12:02:20--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-06-13 12:02:21--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
  Issued certificate has expired.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zi

In [16]:
import numpy as np
import torch

def load_glove_embeddings(glove_file_path, vocab_limit=None):
    vocab = {}
    vectors = []

    with open(glove_file_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            parts = line.strip().split()
            if len(parts) < 10:
                continue  # Skip bad lines

            word = parts[0]
            vec = np.array(parts[1:], dtype=np.float32)

            vocab[word] = len(vectors)
            vectors.append(vec)

            if vocab_limit and len(vocab) >= vocab_limit:
                break

    embedding_weights = torch.tensor(np.stack(vectors))
    return vocab, embedding_weights

In [28]:
glove_vocab, glove_embedding_weights = load_glove_embeddings('glove.6B.100d.txt', vocab_limit=100000)

In [35]:
len(glove_vocab)

100000

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import numpy as np
import wandb
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
from torch.nn import Embedding

In [19]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Device:", device)

Device: mps


In [5]:
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/anton/.netrc
[34m[1mwandb[0m: Currently logged in as: [33madergunov-grotto[0m ([33madergunov-grotto-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [41]:
def run_evaluation(vocab, embedding_weights):
    # -------- CONFIG --------
    embedding_dim = 100
    batch_size = 64
    max_length = 20
    epochs = 5

    # -------- WANDB INIT --------
    wandb.init(project="embedding-eval",
        settings=wandb.Settings(silent="true"),
        config={
            "batch_size": batch_size,
            "max_length": max_length,
            "embedding_dim": embedding_dim,
            "epochs": epochs,
        })

    # -------- EMBEDDINGS --------
    embedding_layer = nn.Embedding.from_pretrained(embedding_weights, freeze=False)

    # -------- HELPERS --------
    def tokenize(text, vocab, max_length=20):
        tokens = text.lower().split()
        indices = [vocab.get(tok, 0) for tok in tokens][:max_length]
        if len(indices) < max_length:
            indices += [0] * (max_length - len(indices))
        return indices

    class TextDataset(Dataset):
        def __init__(self, texts, labels, vocab):
            self.data = [tokenize(text, vocab, max_length) for text in texts]
            self.labels = labels

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])

    class FFN(nn.Module):
        def __init__(self, embedding_layer, hidden_dim=128, num_classes=2):
            super().__init__()
            self.embeddings = embedding_layer
            self.fc1 = nn.Linear(embedding_dim, hidden_dim)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_dim, num_classes)

        def forward(self, x):
            embedded = self.embeddings(x)                  # (B, L, D)
            avg_embed = embedded.mean(dim=1)               # (B, D)
            out = self.fc1(avg_embed)
            out = self.relu(out)
            out = self.fc2(out)
            return out

    def train_and_evaluate(model, train_loader, val_loader, num_classes=2):
        model = model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        for epoch in range(epochs):
            model.train()
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        model.eval()
        all_preds, all_labels = [], []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(targets.numpy())

        acc = accuracy_score(all_labels, all_preds)
        report = classification_report(all_labels, all_preds, output_dict=True)

        wandb.log({
            "accuracy": acc,
            "precision_macro": report['macro avg']['precision'],
            "recall_macro": report['macro avg']['recall'],
            "f1_macro": report['macro avg']['f1-score'],
        })

        print(f"Accuracy: {acc:.4f}")
        return acc

    # -------- RUNNER --------
    def run_on_dataset(name, subset=None, split_name="train", text_key="text", label_key="label", num_classes=2):
        print(f"📦 Loading {name}" + (f"/{subset}" if subset else ""))
        if subset:
            ds = load_dataset(name, subset)
        else:
            ds = load_dataset(name)
        train_texts = ds[split_name][text_key][:4000]
        val_texts = ds["validation" if "validation" in ds else "test"][text_key][:1000]
        train_labels = ds[split_name][label_key][:4000]
        val_labels = ds["validation" if "validation" in ds else "test"][label_key][:1000]

        train_ds = TextDataset(train_texts, train_labels, vocab)
        val_ds = TextDataset(val_texts, val_labels, vocab)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size)

        model = FFN(embedding_layer, num_classes=num_classes)
        return train_and_evaluate(model, train_loader, val_loader, num_classes=num_classes)

    # -------- RUN TESTS --------
    datasets_to_test = [
        {"name": "glue", "subset": "sst2", "text_key": "sentence", "label_key": "label", "num_classes": 2},
        {"name": "trec", "text_key": "text", "label_key": "coarse_label", "num_classes": 6},
        {"name": "ag_news", "text_key": "text", "label_key": "label", "num_classes": 4},
    ]

    for ds in datasets_to_test:
        wandb.run.name = f"embedding_test_{ds['name']}"
        if "subset" in ds:
            acc = run_on_dataset(
                name=ds["name"],
                subset=ds["subset"],
                text_key=ds["text_key"],
                label_key=ds["label_key"],
                num_classes=ds["num_classes"]
            )
        else:
            acc = run_on_dataset(
                name=ds["name"],
                text_key=ds["text_key"],
                label_key=ds["label_key"],
                num_classes=ds["num_classes"]
            )
        print(f"{ds['name']} accuracy: {acc:.4f}")

    print("View run at:", wandb.run.get_url())
    wandb.finish()


In [29]:
run_evaluation(glove_vocab, glove_embedding_weights)

📦 Loading glue/sst2
Accuracy: 0.7569
glue accuracy: 0.7569
📦 Loading trec


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.7440
trec accuracy: 0.7440
📦 Loading ag_news
Accuracy: 0.8380
ag_news accuracy: 0.8380


0,1
accuracy,▂▁█
f1_macro,▆▁█
precision_macro,▆▁█
recall_macro,▅▁█

0,1
accuracy,0.838
f1_macro,0.83165
precision_macro,0.83653
recall_macro,0.82973


In [33]:
checkpoint = torch.load("../../data/cbow_final_with_vocab.pt", map_location=torch.device('cpu'))
my_vocab = checkpoint['word2idx']
my_embedding_weights = checkpoint['model_state_dict']['embeddings.weight']

In [34]:
run_evaluation(my_vocab, my_embedding_weights)

📦 Loading glue/sst2
Accuracy: 0.7511
glue accuracy: 0.7511
📦 Loading trec


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.7080
trec accuracy: 0.7080
📦 Loading ag_news
Accuracy: 0.8030
ag_news accuracy: 0.8030


0,1
accuracy,▄▁█
f1_macro,▆▁█
precision_macro,▆▁█
recall_macro,▆▁█

0,1
accuracy,0.803
f1_macro,0.79488
precision_macro,0.79794
recall_macro,0.79397


In [37]:
random_embeddings = torch.randn_like(my_embedding_weights)

In [42]:
run_evaluation(my_vocab, random_embeddings)

📦 Loading glue/sst2
Accuracy: 0.6548
glue accuracy: 0.6548
📦 Loading trec


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5220
trec accuracy: 0.5220
📦 Loading ag_news




Accuracy: 0.6750
ag_news accuracy: 0.6750
View run at: https://wandb.ai/adergunov-grotto-personal/embedding-eval/runs/3gu43e9i
