### HIGH LEVEL OUTLINE
- Download the text8 corpus (and, optionally, Hacker News data)
- Tokenise the data
- Build a skip-gram dataset from the text8 data using a rolling context window
- Connect to Weights & Biases (WandB)
- Read configuration from environment variables
- Train the model on the dataset
- Save the learned embeddings

In [2]:
import os
import sys
import logging
from pathlib import Path
from typing import Union

import urllib.request
from zipfile import ZipFile

# Directory where downloaded / processed data is cached; override via the DATA_DIR env var.
DATA_DIR = os.environ.get("DATA_DIR", ".data")
TRAIN_PROCESSED_FILENAME = "hn_posts_train_processed.parquet"

# text8 corpus: download URL, name of the member inside the zip, and its expected size.
TEXT8_ZIP_URL = "http://mattmahoney.net/dc/text8.zip"
TEXT8_FILENAME = "text8"
TEXT8_EXPECTED_LENGTH = 10**8  # text8 is expected to be 100 million chars

# Hard-coded copies of settings that the data-ingest step reads from env vars
# (see .env.example / db-utils.py); keep them in sync manually.
TITLES_FILE = "hn_posts_titles.parquet"
MINIMAL_FETCH_ONLY_TITLES = True

logger = logging.getLogger(__name__)


def get_text8(
    cache_dir: Union[str, Path] = DATA_DIR,
    text8_filename: str = TEXT8_FILENAME,
    expected_length: Union[int, None] = TEXT8_EXPECTED_LENGTH,
) -> Path:
    """
    Ensure the Matt Mahoney 'text8' corpus is present locally, downloading it
    once and caching it for subsequent runs.

    Both the download and the extraction write to a temporary ``.part`` file
    that is renamed into place only once it is complete. An interrupted run
    therefore never leaves a truncated file that later runs would mistake for
    a valid cache.

    Args:
        cache_dir: Directory for the cached zip and extracted text (created if missing).
        text8_filename: Name of the member inside the zip; also used as the local
            filename (the zip is cached as ``<text8_filename>.zip``).
        expected_length: Expected size in bytes of the extracted file. A cached
            file of a different size is treated as incomplete and re-extracted.
            Pass None to skip the size check.

    Returns:
        Absolute path to the extracted text file.
    """
    from zipfile import BadZipFile

    cache_dir = Path(cache_dir).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    zip_path = cache_dir / f"{text8_filename}.zip"
    txt_path = cache_dir / text8_filename

    # txt already present and complete? then we are done
    if _text8_file_is_complete(txt_path, expected_length):
        logger.info(f"text8 file found - using cached data at {txt_path}")
        return txt_path
    if txt_path.exists():
        logger.warning(f"Cached {txt_path} has an unexpected size - re-extracting")

    if not zip_path.exists():
        logger.info(f"Downloading text8 corpus to {zip_path}...")
        _download_atomic(TEXT8_ZIP_URL, zip_path)

    logger.info(f"Extracting {zip_path} to {txt_path}...")
    try:
        _extract_member_atomic(zip_path, text8_filename, txt_path)
    except BadZipFile:
        # A zip left behind by an older, interrupted download: fetch it again once.
        logger.warning(f"{zip_path} is corrupt - downloading it again")
        zip_path.unlink()
        _download_atomic(TEXT8_ZIP_URL, zip_path)
        _extract_member_atomic(zip_path, text8_filename, txt_path)

    if not _text8_file_is_complete(txt_path, expected_length):
        logger.warning(
            f"Extracted {txt_path} is {txt_path.stat().st_size} bytes, "
            f"expected {expected_length}"
        )
    return txt_path


def _text8_file_is_complete(path: Path, expected_length: Union[int, None]) -> bool:
    """Return True if ``path`` is a regular file of the expected size (any size if None)."""
    if not path.is_file():
        return False
    return expected_length is None or path.stat().st_size == expected_length


def _download_atomic(url: str, dest: Path) -> None:
    """Download ``url`` to ``dest`` via a ``.part`` file renamed into place on success."""
    tmp = dest.with_name(dest.name + ".part")
    try:
        urllib.request.urlretrieve(url, tmp)
        tmp.replace(dest)
    finally:
        # Remove any partial download; after a successful replace this is a no-op.
        tmp.unlink(missing_ok=True)


def _extract_member_atomic(zip_path: Path, member: str, dest: Path) -> None:
    """Copy ``member`` out of ``zip_path`` into ``dest`` via a ``.part`` file."""
    import shutil

    tmp = dest.with_name(dest.name + ".part")
    try:
        with ZipFile(zip_path, "r") as zf, zf.open(member) as src, open(tmp, "wb") as dst:
            shutil.copyfileobj(src, dst)
        tmp.replace(dest)
    finally:
        tmp.unlink(missing_ok=True)




In [3]:
# Download and extract text8 on the first run (cached under DATA_DIR afterwards);
# build_sgram_dataset() in a later cell reads the corpus from txt8_path.
txt8_path = get_text8()

In [4]:
from collections import Counter
import random
import logging

# getLogger returns the same logger object for the same name, so this simply
# re-binds the module logger for use in this cell.
logger = logging.getLogger(__name__)

UNK_TOKEN = "<UNK>"
PUNCTUATION_MAP = {
    "<": "<LESS>",
    ">": "<GREATER>",
    ",": "<COMMA>",
    ".": "<PERIOD>",
    "!": "<EXCLAMATION>",
    "?": "<QUESTION>",
    ":": "<COLON>",
    ";": "<SEMICOLON>",
    "-": "<DASH>",
    "(": "<LPAREN>",
    ")": "<RPAREN>",
    "[": "<LBRACKET>",
    "]": "<RBRACKET>",
    "{": "<LBRACE>",
    "}": "<RBRACE>",
    '"': "<QUOTE>",
    "'": "<APOSTROPHE>",
    "/": "<SLASH>",
    "\\": "<BACKSLASH>",
    "&": "<AMPERSAND>",
    "@": "<AT>",
    "#": "<HASH>",
    "$": "<DOLLAR>",
    "%": "<PERCENT>",
    "*": "<ASTERISK>",
    "+": "<PLUS>",
    "=": "<EQUALS>",
    "|": "<PIPE>",
    "~": "<TILDE>",
    "`": "<BACKTICK>",
}

# Single-pass translation table derived from PUNCTUATION_MAP. str.translate maps
# each input character exactly once, so inserted replacement tokens are never
# re-scanned. Chained str.replace calls did re-scan them: the ">" inside the
# already-inserted "<LESS>" was rewritten to "<GREATER>", turning "<" into the
# two bogus tokens "<LESS" and "<GREATER>".
_PUNCTUATION_TRANSLATION = str.maketrans(
    {punct: f" {replacement} " for punct, replacement in PUNCTUATION_MAP.items()}
)


# TODO: can we improve? e.g. remove stop words, stem/lemmatise
def tokenise(text: str) -> list[str]:
    """
    Tokenises a string by lowercasing it and replacing each punctuation character
    with its angle-bracket word from PUNCTUATION_MAP (e.g. "," -> "<COMMA>").

    Args:
        text (str): A single string.

    Returns:
        list[str]: The tokens, split on any run of whitespace.
    """
    # Lowercase first so the (uppercase) replacement tokens are left intact.
    spaced = text.lower().translate(_PUNCTUATION_TRANSLATION)
    return spaced.split()


def build_vocab(
    tokens: list[str],
    min_freq: int = 5,
    subsampling_threshold: float = 1e-5,
    *,
    seed: Union[int, None] = None,
) -> dict[str, int]:
    """
    Builds a word -> index vocabulary from a list of tokens.

    Index 0 is always UNK_TOKEN. Words occurring at least ``min_freq`` times are
    candidates. Each candidate is then randomly kept with probability
    ``sqrt(subsampling_threshold / freq)`` (capped at 1), where ``freq`` is the
    word's relative frequency in ``tokens``. Kept words receive consecutive
    indices starting at 1.

    Args:
        tokens: The tokenised corpus.
        min_freq: Minimum occurrence count for a word to be a candidate.
        subsampling_threshold: Frequency-subsampling threshold. None or a value
            <= 0 disables subsampling (previously 0 discarded every word).
        seed: Optional seed for the subsampling RNG. None uses the global
            ``random`` module, as before.

    Returns:
        dict mapping each kept word (and UNK_TOKEN) to a contiguous index.

    NOTE(review): this subsamples word *types*, so a dropped word disappears
    from the vocabulary and all of its occurrences map to UNK (the cell output
    shows common words such as "a" and "of" becoming index 0). word2vec-style
    subsampling discards individual token *occurrences* from the training stream
    instead. Consider moving subsampling to the corpus.
    """
    word_counts = Counter(tokens)  # computed once and reused below
    total_count = sum(word_counts.values())
    rng = random if seed is None else random.Random(seed)

    # Frequency threshold. UNK_TOKEN is excluded from the candidates so it can never
    # be listed twice: a duplicate would leave a gap in the indices and make
    # vocab[UNK_TOKEN] point past len(vocab).
    num_discarded_freq = sum(1 for count in word_counts.values() if count < min_freq)
    candidates = [
        word
        for word, count in word_counts.items()
        if count >= min_freq and word != UNK_TOKEN
    ]

    # Frequency subsampling (remove frequent words with probability increasing with their frequency)
    if subsampling_threshold is not None and subsampling_threshold > 0:
        kept = []
        for word in candidates:
            freq = word_counts[word] / total_count
            prob_discard = 1 - (subsampling_threshold / freq) ** 0.5
            if rng.random() > prob_discard:
                kept.append(word)
    else:
        kept = candidates
    num_discarded_subsampling = len(candidates) - len(kept)

    vocab = {word: idx for idx, word in enumerate([UNK_TOKEN, *kept])}

    # Report (percentages guarded against empty input)
    freq_pct = num_discarded_freq / len(word_counts) * 100 if word_counts else 0.0
    sub_pct = num_discarded_subsampling / len(candidates) * 100 if candidates else 0.0
    logger.info(f"Total tokens in: {len(tokens)}")
    logger.info(f"Number discarded from frequency threshold: {num_discarded_freq} ({freq_pct:.2f}%)")
    logger.info(f"Number discarded from subsampling: {num_discarded_subsampling} ({sub_pct:.2f}%)")
    logger.info(f"Vocab size: {len(vocab)}")

    return vocab


def get_tokens_as_indices(tokens: list[str], vocab: dict) -> list[int]:
    """
    Map every token to its vocabulary index.

    Tokens missing from ``vocab`` fall back to the index of UNK_TOKEN, which must
    therefore be present in ``vocab``. Converting to plain integer indices up
    front gives compact, random-access data that is cheap to batch later.
    """
    fallback = vocab[UNK_TOKEN]
    return [vocab[token] if token in vocab else fallback for token in tokens]


def get_words_from_indices(indeces: list[int], vocab: dict) -> list[str]:
    """
    Converts a list of token indices back to their token strings using ``vocab``
    (a word -> index mapping).

    Indices that do not occur in ``vocab`` are silently skipped. If several words
    share an index, the first one in ``vocab``'s iteration order is returned.
    (The parameter keeps its historical spelling ``indeces`` for keyword callers.)
    """
    # Invert the vocab once: O(len(vocab) + len(indeces)) instead of rebuilding
    # and linearly scanning the key/value lists for every single index.
    index_to_word: dict = {}
    for word, idx in vocab.items():
        index_to_word.setdefault(idx, word)  # first word wins, as with list.index

    words = []
    for idx in indeces:
        # Unwrap numpy / torch scalars so they hash like plain ints.
        key = idx.item() if hasattr(idx, "item") else idx
        if key in index_to_word:
            words.append(index_to_word[key])
    return words


In [5]:
def generate_skipgram_pairs(corpus, context_size):
    """
    Build (center, context) pairs from a token sequence with a symmetric window.

    Only positions with a full window of ``context_size`` tokens on both sides are
    used as centers, so the first and last ``context_size`` tokens never act as
    centers. For each center, the left neighbours come first (in order), followed
    by the right neighbours.
    """
    pairs = []
    stop = len(corpus) - context_size
    for pos in range(context_size, stop):
        centre = corpus[pos]
        left = corpus[pos - context_size:pos]
        right = corpus[pos + 1:pos + 1 + context_size]
        pairs.extend((centre, neighbour) for neighbour in left + right)
    return pairs


def build_sgram_dataset(
    context_size: int = 5,
    text_path: Union[str, Path, None] = None,
    min_freq: int = 0,
    subsampling_threshold: float = 1e-4,
):
    """
    Read a text corpus (text8 by default), tokenise it, build a vocabulary, and
    turn it into skip-gram (center, context) index pairs.

    Args:
        context_size: Number of tokens on each side of a center word.
        text_path: Corpus file to read. Defaults to the notebook-level
            ``txt8_path`` set by ``get_text8()`` in an earlier cell.
        min_freq: Forwarded to ``build_vocab`` (0 keeps every word, as before).
        subsampling_threshold: Forwarded to ``build_vocab``.

    Returns:
        tuple: (skipgram_pairs, vocab), where skipgram_pairs is a list of
        (center_idx, context_idx) tuples and vocab maps word -> index.
    """
    if text_path is None:
        text_path = txt8_path

    with open(text_path, "r", encoding="utf-8") as f:
        text = f.read()

    text_tokens = tokenise(text)
    del text  # free the raw corpus (~100M chars for text8) before the large allocations below

    vocab = build_vocab(
        text_tokens, min_freq=min_freq, subsampling_threshold=subsampling_threshold
    )
    print(text_tokens[:10])  # first 10 tokens, for verification

    text_token_inds = get_tokens_as_indices(text_tokens, vocab)
    del text_tokens  # the string tokens are no longer needed once indexed
    print(text_token_inds[:10])  # first 10 token indices, for verification

    skipgram_pairs = generate_skipgram_pairs(text_token_inds, context_size)
    print(skipgram_pairs[:10])  # first 10 pairs, for verification
    return skipgram_pairs, vocab


In [6]:
# Build the full text8 skip-gram dataset: a list of (center, context) index pairs
# plus the word -> index vocab.
# NOTE(review): with the default context_size=5 each center yields 10 Python tuples,
# all held in memory at once - expect this to be large for a 17M-token corpus.
ds_pairs, ds_vocab = build_sgram_dataset()


['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against']
[1, 2, 3, 0, 0, 0, 4, 0, 0, 0]
[(0, 1), (0, 2), (0, 3), (0, 0), (0, 0), (0, 4), (0, 0), (0, 0), (0, 0), (0, 0)]


In [None]:
%pip install torch scikit-learn tqdm
import torch
import logging
from sklearn.model_selection import train_test_split
from tqdm import tqdm
logger = logging.getLogger(__name__)


def train_skipgram_model(
    skipgram_pairs: list[tuple[int, int]],
):
    """
    Placeholder entry point for skip-gram training.

    No model is trained yet. This only logs that training was requested and
    how many pairs were supplied, then returns None. See
    ``train_skipgram_epochs`` for the working training loop.
    """
    logger.info("Training skip-gram model on provided pairs...")
    pair_count = len(skipgram_pairs)
    logger.info(f"Number of skipgram pairs: {pair_count}")
    return None

class SkipGramDataset(torch.utils.data.Dataset):
    """
    Map-style torch Dataset over (center, context) index pairs.

    Each item is returned as a tuple of two 0-d ``torch.long`` tensors.
    """

    def __init__(self, skipgram_pairs):
        # Stored by reference; the caller's sequence is not copied.
        self.pairs = skipgram_pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        center_idx, context_idx = self.pairs[idx]
        return tuple(
            torch.tensor(value, dtype=torch.long)
            for value in (center_idx, context_idx)
        )

class SkipGramModel(torch.nn.Module):
    """
    Skip-gram model with separate input (center) and output (context) embedding tables.

    ``forward`` returns the raw logit (dot product of the center's input vector
    and the context's output vector) for each pair in the batch. It is meant to
    be used with a logits-based loss such as BCEWithLogitsLoss.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.in_embeddings = torch.nn.Embedding(vocab_size, embedding_dim)
        self.out_embeddings = torch.nn.Embedding(vocab_size, embedding_dim)
        # nn.Embedding defaults to N(0, 1) weights. With embedding_dim=100 the
        # initial dot products then have std ~10, which saturates the sigmoid
        # and stalls early training. Use the original word2vec init instead:
        # small uniform input vectors and all-zero output vectors.
        bound = 0.5 / embedding_dim if embedding_dim else 0.0
        torch.nn.init.uniform_(self.in_embeddings.weight, -bound, bound)
        torch.nn.init.zeros_(self.out_embeddings.weight)

    def forward(self, center, context):
        # assumes center/context are 1-D index batches of equal length, so
        # each embedding lookup is (batch, dim) and dim=1 reduces over features
        center_emb = self.in_embeddings(center)
        context_emb = self.out_embeddings(context)
        score = torch.sum(center_emb * context_emb, dim=1)
        return score

def split_dataset(pairs, seed=42):
    """
    Randomly split ``pairs`` into train / validation / test sets (about 80/10/10).

    The first train_test_split call holds out 20% of the data. The second call
    cuts that hold-out in half. Both calls use the same seed, so the split is
    reproducible.
    """
    train, holdout = train_test_split(pairs, test_size=0.2, random_state=seed)
    validation, test = train_test_split(holdout, test_size=0.5, random_state=seed)
    return train, validation, test

# Split the full text8 pair list into train / validation / test subsets (seed 42).
train_pairs, val_pairs, test_pairs = split_dataset(ds_pairs)

def train_skipgram_epochs(
    skipgram_pairs,
    vocab_size,
    embedding_dim=100,
    batch_size=64,
    lr=0.01,
    epochs=1,
    device=None,
):
    """
    Train a SkipGramModel on (center, context) index pairs with negative sampling.

    For every positive pair, one negative context index is drawn uniformly from
    [0, vocab_size). The loss is BCE-with-logits on the positive score (label 1)
    plus BCE-with-logits on the negative score (label 0).

    Args:
        skipgram_pairs: Sequence of (center_idx, context_idx) integer pairs.
        vocab_size: Rows in each embedding table; every index must be < vocab_size.
        embedding_dim: Size of each embedding vector.
        batch_size: Pairs per optimisation step (reshuffled every epoch).
        lr: Adam learning rate.
        epochs: Number of passes over ``skipgram_pairs``.
        device: Optional torch device (e.g. "cuda"). None trains on the CPU, as before.

    Returns:
        The trained SkipGramModel.

    Raises:
        ValueError: If ``skipgram_pairs`` is empty. There would be nothing to
            train on, and the per-epoch average loss would divide by zero.

    NOTE(review): negatives are uniform over the whole vocab, so they can hit the
    UNK index or the true context. word2vec samples from unigram^0.75 instead.
    """
    if len(skipgram_pairs) == 0:
        raise ValueError("skipgram_pairs is empty - nothing to train on")

    device = torch.device(device) if device is not None else torch.device("cpu")
    dataset = SkipGramDataset(skipgram_pairs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Move the model before building the optimizer so it tracks the on-device parameters.
    model = SkipGramModel(vocab_size, embedding_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for center, context in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            center = center.to(device)
            context = context.to(device)
            optimizer.zero_grad()
            # Positive samples
            pos_labels = torch.ones(center.size(0), device=device)
            pos_score = model(center, context)
            pos_loss = loss_fn(pos_score, pos_labels)
            # Negative sampling: one uniformly random context per positive pair
            neg_context = torch.randint(0, vocab_size, context.size(), device=device)
            neg_labels = torch.zeros(center.size(0), device=device)
            neg_score = model(center, neg_context)
            neg_loss = loss_fn(neg_score, neg_labels)
            loss = pos_loss + neg_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        logger.info(f"Epoch {epoch+1}/{epochs} - Average loss: {avg_loss:.4f}")
    return model

# Train on the real text8 training split (not dummy data) for a single epoch.
model = train_skipgram_epochs(train_pairs, len(ds_vocab), epochs=1)


Note: you may need to restart the kernel to use updated packages.
