In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import math
import random
import matplotlib.pyplot as plt

import re
from collections import Counter


In [38]:
# -----------------------------------------------------
#                    Dataset class
# -----------------------------------------------------

class TextDataset(Dataset):
    """Sliding-window next-token dataset over one flat sequence of token IDs.

    Sample ``i`` is the pair ``(x, y)`` with
    ``x = token_ids[i : i + seq_len]`` and ``y = token_ids[i + 1 : i + seq_len + 1]``,
    i.e. ``y`` is ``x`` shifted by one token (the "next token" targets).
    Consecutive samples overlap: each one starts one token later than the previous.
    """

    # preparing the dataset
    # -----------------------
    def __init__(self, token_ids, seq_len=20):
        """
        Args:
            token_ids : flat sequence of integer token IDs (anything supporting len() and slicing)
            seq_len   : context size, i.e. number of tokens in one sample
        """
        self.token_ids = token_ids
        self.seq_len = seq_len

    # Number of (x, y) pairs in the dataset
    # -----------------------
    def __len__(self):
        # The last valid window needs seq_len + 1 tokens (x plus one extra target token).
        # max() keeps the length at 0 when the sequence is shorter than that.
        return max(0, len(self.token_ids) - self.seq_len)

    # Get training example (x, y) given an index
    # -----------------------
    def __getitem__(self, idx):
        """Return the ``(x, y)`` LongTensor pair for sample ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Without this
                check, slicing would silently return short or empty windows, and plain
                iteration over the dataset would never terminate.
        """
        n_samples = len(self)
        if idx < 0:
            idx = idx + n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(f"sample index out of range for TextDataset of length {n_samples}")

        stop = idx + self.seq_len
        # x: input window of length seq_len
        x = self.token_ids[idx:stop]
        # y: same window shifted right by one token
        y = self.token_ids[idx + 1:stop + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

In [39]:
# -----------------------------------------------------
#         Model with transformer architecture
# -----------------------------------------------------

class TransformerLM(nn.Module):
    """Causally masked Transformer language model.

    Token and learned position embeddings are summed, passed through an
    ``nn.TransformerEncoder`` with a causal attention mask, and projected to
    vocabulary logits. ``forward`` returns, for every position t, logits that
    depend only on tokens 0..t.

    Args:
        vocab_size      : number of distinct token IDs
        d_model         : embedding / hidden size
        nhead           : number of attention heads
        num_layers      : number of encoder layers
        dim_feedforward : width of the feed-forward sub-layer
        dropout         : dropout probability inside the encoder layers
        max_seq_length  : longest sequence ``forward`` accepts (size of the position
                          embedding table); must be >= the training seq_len
    """

    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=2, dim_feedforward=256, dropout=0.1, max_seq_length=20):
        super().__init__()
        self.model_type = "Transformer"
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_length

        # token IDs ---> vectors
        # ---------------------------
        # Token + position embeddings
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_length, d_model)

        # ---------------------------
        #    Transformer encoder
        # ---------------------------
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,  # dimensionality of the input embeddings
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq_len, d_model)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # vectors --> logits
        # ---------------------------
        # Final linear layer to vocab logits
        self.decoder = nn.Linear(d_model, vocab_size)

        self._reset_parameters()

    # initialize embedding / output-projection parameters --> small random weights
    # ---------------------------
    def _reset_parameters(self):
        """Uniformly initialise the embeddings and output projection in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.tok_emb.weight, -initrange, initrange)
        nn.init.uniform_(self.pos_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    # Mask to prevent attention to future tokens
    # ---------------------------
    def _generate_causal_mask(self, seq_length, device):
        """Return a (T, T) float mask: 0 on and below the diagonal, -inf above it."""
        neg_inf = torch.full((seq_length, seq_length), float("-inf"), device=device)
        return torch.triu(neg_inf, diagonal=1)

    def forward(self, src):
        """
        src: LongTensor of shape (batch_size, seq_len)
        returns: logits of shape (batch_size, seq_len, vocab_size)

        Raises:
            IndexError: if seq_len exceeds ``max_seq_length`` (the position embedding
                table has no entry for such positions).
        """
        batch_size, seq_len = src.shape
        if seq_len > self.max_seq_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_length {self.max_seq_len}"
            )
        device = src.device

        # Token embeddings, scaled by sqrt(d_model)                 # (B, T, C)
        tok_emb = self.tok_emb(src) * math.sqrt(self.d_model)

        # Position ids [0, 1, ..., seq_len-1]; (T, C) broadcasts over the batch
        pos_emb = self.pos_emb(torch.arange(seq_len, device=device))

        x = tok_emb + pos_emb                                       # (B, T, C)

        # Causal mask so token t can't see future tokens > t
        src_mask = self._generate_causal_mask(seq_len, device)      # (T, T)

        # Transformer encoder
        x = self.transformer_encoder(x, mask=src_mask)              # (B, T, C)

        # Project to vocabulary
        return self.decoder(x)                                      # (B, T, vocab_size)

    # Text generating function
    # ---------------------------
    def sample(self, batch_size=1, num_steps=30, temperature=1.0, start_tokens=None, device=None):
        """
        Autoregressively sample tokens from the model.

        Args:
            batch_size   : number of sequences to generate
            num_steps    : number of tokens to append to each sequence
            temperature  : > 0; logits are divided by it before the softmax
            start_tokens : optional LongTensor prompt, shape (T0,) or (B, T0);
                           a 1-D prompt is repeated for every sequence in the batch.
                           If omitted, generation starts from a single token id 0.
            device       : device for the generated tensor; defaults to the device
                           of the model parameters

        Returns: LongTensor of shape (batch_size, initial_len + num_steps)

        Only the last ``max_seq_length`` tokens are fed to the model at each step;
        the returned sequence itself is never truncated.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        # if device is not specified, use the device of the model parameters
        if device is None:
            device = next(self.parameters()).device

        if start_tokens is None:
            # Start with a single dummy token (token id 0)
            x = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        else:
            x = start_tokens.to(device)
            if x.dim() == 1:
                x = x.unsqueeze(0).expand(batch_size, -1)  # (B, T0)

        # Sample without dropout and without building an autograd graph,
        # then restore whatever mode the caller had the model in.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # each loop appends one token
                for _ in range(num_steps):
                    # Crop the model *input* to the context window, not the output sequence
                    context = x[:, -self.max_seq_len:]
                    logits_last = self(context)[:, -1, :]   # (B, V) – last time step

                    # apply temperature, convert logits to probabilities
                    probs = F.softmax(logits_last / temperature, dim=-1)

                    # randomly pick a token ID according to the probabilities
                    next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
                    x = torch.cat([x, next_token], dim=1)
        finally:
            self.train(was_training)

        # returns prompt + generated tokens
        return x

In [40]:
# -----------------------------------------------------
#           Model with RNN (LSTM) architecture
# -----------------------------------------------------

class RNNLM(nn.Module):
    """LSTM language model: embedding -> multi-layer LSTM -> linear projection to vocab logits.

    ``forward`` follows the same convention as ``TransformerLM.forward``: the logits at
    position t depend on tokens 0..t and are meant to be compared with the *next* token,
    which is what ``TextDataset`` provides as ``y``. The inputs are therefore NOT shifted
    inside the model; shifting them here as well would hide token t from position t.

    Args:
        vocab_size    : number of distinct token IDs
        embedding_dim : size of the token embeddings
        hidden_dim    : LSTM hidden size
        num_layers    : number of stacked LSTM layers (default 4)
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int = 4):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Embedding: token IDs ---> vectors
        # --------------------------
        # (batch_size, seq_len) to (batch_size, seq_len, embedding_dim)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # --------------------------
        #         LSTM Block
        # --------------------------
        # Takes embeddings as input --> outputs hidden states
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # Linear projection: RNN's hidden output ---> vocab logits
        # --------------------------
        # Shape: (batch_size, seq_len, hidden_dim) → (batch_size, seq_len, vocab_size)
        # The outputs are raw logits; a softmax turns them into a distribution over the vocabulary
        self.proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    # --------------------------
    #         Forward pass
    # --------------------------
    def forward(self, token_ids: torch.Tensor):
        """
        token_ids: LongTensor of shape (batch_size, seq_len), seq_len >= 1
        returns: logits of shape (batch_size, seq_len, vocab_size)
        """
        embedded = self.embeddings(token_ids)       # (B, T, E)
        hidden_states, _ = self.rnn(embedded)       # (B, T, H)
        return self.proj(hidden_states)             # (B, T, V)

    # --------------------------
    #         Generate text
    # --------------------------
    # Generate new text (token IDs) by feeding the model's own predictions back as input
    def sample(self, batch_size=1, num_steps=20, temperature: float = 1.0, start_tokens=None):
        """
        Args:
        batch_size  : how many sequences to generate at once
        num_steps   : how many tokens to generate per sequence
        temperature : controls randomness in sampling, must be > 0
                    (low = more deterministic, high = more random)
        start_tokens: optional LongTensor prompt, shape (T0,) or (B, T0), used to prime
                    the LSTM; a 1-D prompt is repeated for every sequence.
                    If omitted, a single token id 0 is used
                    (the same default as TransformerLM.sample).

        Returns:
        token_ids: tensor of generated token IDs of shape (batch_size, num_steps)
                   (the prompt / start token is not included)
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        device = self.embeddings.weight.device
        if start_tokens is None:
            current = torch.zeros((batch_size, 1), device=device, dtype=torch.long)
        else:
            current = start_tokens.to(device)
            if current.dim() == 1:
                current = current.unsqueeze(0).expand(batch_size, -1)

        generated = []
        state = None  # LSTM (h, c); None lets nn.LSTM start from zeros

        # Sample in eval mode without autograd, then restore the caller's mode
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(num_steps):
                    # Carry the LSTM state forward so each step only processes new tokens
                    hidden_states, state = self.rnn(self.embeddings(current), state)
                    # logits after the most recent token, scaled by temperature
                    logits_t = self.proj(hidden_states[:, -1, :]) / temperature   # (B, V)
                    # Categorical normalises the logits (softmax) and samples one token ID per row
                    next_tokens = torch.distributions.Categorical(logits=logits_t).sample()
                    current = next_tokens.unsqueeze(1)                           # (B, 1)
                    generated.append(current)
        finally:
            self.train(was_training)

        if not generated:
            return torch.zeros((batch_size, 0), device=device, dtype=torch.long)
        return torch.cat(generated, dim=1)


In [41]:
# -----------------------------------------------------
#                 Training / evaluation
# -----------------------------------------------------
def train_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimisation pass over ``loader``.

    Args:
        model      : RNN / Transformer language model mapping (B, T) token IDs to (B, T, V) logits
        loader     : iterable of (x, y) batches (a DataLoader, or any iterable of tensor pairs)
        optimizer  : optimization algorithm
        criterion  : loss function taking (N, V) logits and (N,) targets
        device     : device ("cpu", "cuda", or "mps")

    Returns: Average of the per-batch loss values for the epoch

    Raises:
        ValueError: if ``loader`` yields no batches (there is nothing to average).
    """
    # training mode
    model.train()
    total_loss = 0.0
    # Count batches while iterating, so loaders without __len__ work as well
    num_batches = 0

    # Loop over each mini-batch of data
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        # Forward pass: compute predictions (logits)
        logits = model(x)
        # Flatten to (batch_size * seq_len, vocab_size) and (batch_size * seq_len,).
        # reshape() also handles non-contiguous tensors, which view() rejects.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        # back propagation
        loss.backward()
        # Update model parameters
        optimizer.step()
        # Accumulate batch loss
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_epoch received an empty loader: no batches to train on")
    # returns average loss value for the epoch
    return total_loss / num_batches

def evaluate(model, loader, criterion, device):
    """
    Compute the average loss over ``loader`` without updating the model.

    Args:
        model      : RNN / Transformer language model mapping (B, T) token IDs to (B, T, V) logits
        loader     : iterable of (x, y) batches from a held-out (validation or test) dataset
        criterion  : loss function taking (N, V) logits and (N,) targets
        device     : device ("cpu", "cuda", or "mps")

    Returns: Average of the per-batch loss values across the entire loader

    Raises:
        ValueError: if ``loader`` yields no batches (there is nothing to average).

    Note: the model is left in eval mode when this function returns.
    """
    # switch to eval mode (disables dropout)
    model.eval()
    total_loss = 0.0
    # Count batches while iterating, so loaders without __len__ work as well
    num_batches = 0

    # Disable gradient calculation: no backward pass is needed here
    with torch.no_grad():
        # Same loop shape as training, but over held-out batches and without optimizer steps
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            # forward pass, get logits
            logits = model(x)
            # Flatten to (batch_size * seq_len, vocab_size) and (batch_size * seq_len,).
            # reshape() also handles non-contiguous tensors, which view() rejects.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            # accumulate loss to compute average later
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate received an empty loader: no batches to evaluate")
    # return average loss over entire dataset
    return total_loss / num_batches

In [42]:
# -----------------------------------------------------
#                       Using GPU
# -----------------------------------------------------

# PC (CUDA) & Mac (MPS) compatible, with CPU as the fallback.
# Preference order: CUDA GPU, then Apple MPS, then CPU.
# getattr guards against torch builds that do not expose torch.backends.mps at all.
mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Report which device was picked
if device.type == "mps":
    print("Running on Apple GPU (MPS).")
elif device.type == "cuda":
    print("Running on cuda GPU")
else:
    print("Running on CPU.")


Running on Apple GPU (MPS).


In [None]:
# -----------------------------------------------------
#                 Homemade tokenizer
# -----------------------------------------------------
#     word-level tokenizer for small sample dataset

class SimpleTokenizer:
    """Word-level tokenizer with a frequency-capped vocabulary.

    Text is split into runs of word characters and single non-space characters
    (so punctuation becomes its own token). The vocabulary holds four special tokens
    (``<pad>``, ``<unk>``, ``<bos>``, ``<eos>`` at ids 0-3) followed by the most
    frequent tokens of the training texts.
    """

    # words (runs of \w) or any single non-whitespace character
    _TOKEN_PATTERN = re.compile(r"\w+|\S")

    def __init__(self, texts, vocab_size=8000, lower=True):
        """
        texts: iterable of strings (e.g. Wikipedia articles)
        vocab_size: maximum total vocab size including special tokens
        lower: lowercase all text before tokenizing
        """
        self.lower = lower

        # Special tokens
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]

        # 1) Build frequency counter over tokens
        counter = Counter()
        for text in texts:
            counter.update(self._tokenize(text))

        # 2) Keep the most common tokens (leaving room for specials)
        max_basic_tokens = max(0, vocab_size - len(special_tokens))
        most_common = [w for w, _ in counter.most_common(max_basic_tokens)]

        # 3) Build vocab
        self.itos = special_tokens + most_common                 # index → token
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}  # token → index

    @property
    def vocab_size(self):
        """Actual number of tokens in the vocabulary (can be smaller than the requested cap)."""
        return len(self.itos)

    # Split raw text into string tokens (shared by vocabulary building and encoding)
    # ----------------------
    def _tokenize(self, text):
        if self.lower:
            text = text.lower()
        return self._TOKEN_PATTERN.findall(text)

    # Encode
    # ----------------------
    def encode(self, text, add_special_tokens=True, max_length=None, truncation=False):
        """Convert ``text`` to a list of token IDs.

        Args:
            text               : string to encode; unknown tokens map to ``<unk>``
            add_special_tokens : wrap the sequence in ``<bos>`` ... ``<eos>``
            max_length         : maximum length of the returned list (only used
                                 together with ``truncation=True``)
            truncation         : cut the sequence down to ``max_length``

        When truncating with special tokens, the text tokens are shortened so that
        ``<bos>`` and ``<eos>`` both survive; cutting after wrapping would drop
        ``<eos>`` from every over-long text.
        """
        unk_id = self.stoi[self.unk_token]
        ids = [self.stoi.get(tok, unk_id) for tok in self._tokenize(text)]

        should_truncate = truncation and max_length is not None

        if add_special_tokens:
            if should_truncate:
                # reserve two slots for <bos> and <eos>
                ids = ids[:max(0, max_length - 2)]
            ids = [self.stoi[self.bos_token]] + ids + [self.stoi[self.eos_token]]

        if should_truncate:
            # final guard: enforces the bound even when max_length < 2
            ids = ids[:max(0, max_length)]

        return ids

    # Decode
    # ----------------------
    def decode(self, ids, skip_special_tokens=True):
        """Convert token IDs back to a string.

        IDs outside the vocabulary are ignored. With ``skip_special_tokens`` the four
        special tokens (including ``<unk>``) are left out of the result.
        """
        tokens = []
        special = {self.pad_token, self.unk_token, self.bos_token, self.eos_token}
        for idx in ids:
            if 0 <= idx < len(self.itos):
                tok = self.itos[idx]
                if skip_special_tokens and tok in special:
                    continue
                tokens.append(tok)

        # naive detokenization: join with spaces (punctuation ends up space-separated too)
        return " ".join(tokens)


In [None]:
# -----------------------------------------------------
#              testing with smaller dataset
# -----------------------------------------------------

# 0) Reproducibility
# ------------------------
# Seed the RNGs so weight init, DataLoader shuffling and sampling repeat across runs
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# 1) Crop dataset
# -----------------
# Only use the first `max_articles` articles from the Wikipedia dataset
max_articles = 200

# "20231101.en" selects the dump version (Nov 2023) and the language (en = English).
# Each example has a 'text' field holding the article content.
# NOTE(review): this call resolves all data files of the dump before slicing, which
# can be slow — consider streaming=True if start-up time is a problem (to be confirmed).
dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split=f"train[:{max_articles}]")

# 2) Choose tokenizer
# ------------------------
# A pretrained tokenizer (e.g. AutoTokenizer.from_pretrained("gpt2")) is deliberately
# NOT used: the vocabulary is built from our own data.
# Extract the 'text' field (the article content) from the dataset
texts = dataset["text"]

tokenizer_vocab_size = 8000
# TODO: the vocabulary is built from all loaded articles, including the tokens that
# later end up in the test split; build it from the training part only for a strict split.
tokenizer = SimpleTokenizer(texts, vocab_size=tokenizer_vocab_size, lower=True)

# 3) Tokenize dataset: words ---> one long sequence of token IDs
# ------------------------
# Limit every article to this many tokens (keeps processing fast)
max_tokens_per_article = 512

token_ids = []
for article_text in texts:
    # add_special_tokens=True wraps each article in <bos> ... <eos>;
    # truncation=True cuts articles longer than max_tokens_per_article
    article_ids = tokenizer.encode(
        article_text,
        add_special_tokens=True,
        max_length=max_tokens_per_article,
        truncation=True,
    )
    # Add this article's token IDs to the main long list
    token_ids.extend(article_ids)

# 4) Split raw data into train/test
# ------------------------
# note: don't shuffle raw data,
# let the DataLoader do the shuffling at the sequence level during training
# TODO: there is no separate validation set yet (e.g. 80% train / 10% val / 10% test)
split_idx = int(0.9 * len(token_ids))  # 90% train, 10% test
train_ids = token_ids[:split_idx]
test_ids = token_ids[split_idx:]

# 5) Build dataloaders
# ------------------------
seq_len = 20  # context size, how many tokens the model sees at once
batch_size = 64

# Structured (x, y) pairs ready for batching
#   x → a window of seq_len tokens
#   y → the same window, shifted by one token
train_ds = TextDataset(train_ids, seq_len=seq_len)
test_ds = TextDataset(test_ids, seq_len=seq_len)
assert len(train_ds) > 0 and len(test_ds) > 0, "not enough tokens for the chosen seq_len"

# Training windows are shuffled every epoch; the test order does not matter, so it is left as is
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size)

# 6) Set up model
# ------------------------
vocab_size = tokenizer.vocab_size
# token embedding dimension
embedding_dim = 32
# LSTM hidden size (only used by the RNN alternative below)
hidden_dim = 256

# ....................................................................
# Alternative: RNN model
# model = RNNLM(vocab_size, embedding_dim, hidden_dim)
# ....................................................................

# ....................................................................
# Transformer model (max_seq_length must be >= seq_len)
model = TransformerLM(
    vocab_size=vocab_size,
    d_model=embedding_dim,
    nhead=4,
    num_layers=2,
    dim_feedforward=64,
    dropout=0.1,
    max_seq_length=seq_len,
)
# ....................................................................

# move model to GPU / MPS / CPU
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 7) Train & evaluate
# ------------------------
epochs = 3

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}", end=" — ")
    # Training phase --> returns the average training loss for this epoch
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    # Evaluation phase --> returns the average test loss for this epoch
    test_loss = evaluate(model, test_loader, criterion, device)
    # monitor progress
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

# 8) Sample and decode
# ------------------------
# Ask the model to generate new text:
# - batch_size  = number of samples to generate
# - num_steps   = how many tokens to generate sequentially
# - temperature = randomness level (1.0 = normal, >1 = more random, <1 = more conservative)
sample_ids = model.sample(batch_size=2, num_steps=30, temperature=1.0)

# decode and print
print("Sampled text:", [tokenizer.decode(ids.tolist()) for ids in sample_ids])

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 1/3 —`

KeyboardInterrupt: 