In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

In [2]:
# --- Dataset class
class TextDataset(Dataset):
    def __init__(self, token_ids, seq_len=20):
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.token_ids) - self.seq_len)

    def __getitem__(self, idx):
        x = self.token_ids[idx: idx + self.seq_len]
        y = self.token_ids[idx + 1: idx + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

In [3]:
class RNNLM(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.rnn = nn.LSTM(
            input_size = embedding_dim,
            hidden_size = hidden_dim,
            num_layers = 4,
            batch_first = True
        )

        self.proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor):
        ws = self.embeddings(token_ids)
        # shift right by one
        w0 = torch.zeros((ws.size(0), 1, self.embedding_dim), device=ws.device, dtype=ws.dtype)
        ws_shifted = torch.cat([w0, ws[:, :-1, :]], dim=1)
        hidden_states, _ = self.rnn(ws_shifted)
        logits = self.proj(hidden_states)
        return logits

    def sample(self, batch_size=1, num_steps=20, temperature: float = 1.0):
        device = self.embeddings.weight.device
        token_ids = torch.zeros((batch_size, 0), device=device, dtype=torch.long)
        for t in range(num_steps):
            logits = self.forward(token_ids)
            logits_t = logits[:, -1:, :] / temperature
            p = torch.distributions.Categorical(logits=logits_t)
            next_tokens = p.sample()
            token_ids = torch.cat([token_ids, next_tokens], dim=1)
        return token_ids


In [4]:
# --- Training / evaluation
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, model.vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, model.vocab_size), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(loader)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_articles = 1000
# 1) Load Wikimedia dataset
# Example: English Wikipedia dump 20231101
dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split=f"train[:{max_articles}]")  # lang=en
# Each example has e.g. ‘text’ field (the article content) :contentReference[oaicite:4]{index=4}

# 2) Choose tokenizer
# tokenizer_name = "bert-base-uncased"
tokenizer_name = "gpt2"

# THIS IS PRETRAINED, WE NEED TO DO IT OURSELVES !
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# 3) Tokenize dataset (we’ll use a subset for speed)
texts = dataset["text"]
token_ids = []
for t in texts:
    # ids = tokenizer.encode(t, add_special_tokens=True)
    ids = tokenizer.encode(t, add_special_tokens=True, max_length=512, truncation=True)
token_ids.extend(ids)

# 4) Split into train/test
split_idx = int(0.9 * len(token_ids))
train_ids = token_ids[:split_idx]
test_ids = token_ids[split_idx:]

# 5) Build dataloaders
seq_len = 10
train_ds = TextDataset(train_ids, seq_len=seq_len)
test_ds = TextDataset(test_ids, seq_len=seq_len)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)

# 6) Set up model
vocab_size = tokenizer.vocab_size
embedding_dim = 128
hidden_dim = 256
model = RNNLM(vocab_size, embedding_dim, hidden_dim)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 7) Train & evaluate
epochs = 3
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}", end=" —`")
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss = evaluate(model, test_loader, criterion, device)
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

# 8) Sample and decode
sample_ids = model.sample(batch_size=2, num_steps=30, temperature=1.0)
print("Sampled text:", [tokenizer.decode(ids.tolist()) for ids in sample_ids])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/41 [00:00<?, ?files/s]

20231101.en/train-00000-of-00041.parquet:   0%|          | 0.00/420M [00:00<?, ?B/s]

20231101.en/train-00001-of-00041.parquet:   0%|          | 0.00/351M [00:00<?, ?B/s]

20231101.en/train-00002-of-00041.parquet:   0%|          | 0.00/329M [00:00<?, ?B/s]

20231101.en/train-00003-of-00041.parquet:   0%|          | 0.00/331M [00:00<?, ?B/s]

20231101.en/train-00004-of-00041.parquet:   0%|          | 0.00/307M [00:00<?, ?B/s]

20231101.en/train-00005-of-00041.parquet:   0%|          | 0.00/244M [00:00<?, ?B/s]

20231101.en/train-00006-of-00041.parquet:   0%|          | 0.00/266M [00:00<?, ?B/s]

20231101.en/train-00007-of-00041.parquet:   0%|          | 0.00/228M [00:00<?, ?B/s]

20231101.en/train-00008-of-00041.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

20231101.en/train-00009-of-00041.parquet:   0%|          | 0.00/227M [00:00<?, ?B/s]

20231101.en/train-00010-of-00041.parquet:   0%|          | 0.00/234M [00:00<?, ?B/s]

20231101.en/train-00011-of-00041.parquet:   0%|          | 0.00/232M [00:00<?, ?B/s]

20231101.en/train-00012-of-00041.parquet:   0%|          | 0.00/239M [00:00<?, ?B/s]

20231101.en/train-00013-of-00041.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

20231101.en/train-00014-of-00041.parquet:   0%|          | 0.00/223M [00:00<?, ?B/s]

20231101.en/train-00015-of-00041.parquet:   0%|          | 0.00/235M [00:00<?, ?B/s]

20231101.en/train-00016-of-00041.parquet:   0%|          | 0.00/503M [00:00<?, ?B/s]

20231101.en/train-00017-of-00041.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

20231101.en/train-00018-of-00041.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

20231101.en/train-00019-of-00041.parquet:   0%|          | 0.00/195M [00:00<?, ?B/s]

20231101.en/train-00020-of-00041.parquet:   0%|          | 0.00/225M [00:00<?, ?B/s]

20231101.en/train-00021-of-00041.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

20231101.en/train-00022-of-00041.parquet:   0%|          | 0.00/202M [00:00<?, ?B/s]

20231101.en/train-00023-of-00041.parquet:   0%|          | 0.00/213M [00:00<?, ?B/s]

20231101.en/train-00024-of-00041.parquet:   0%|          | 0.00/221M [00:00<?, ?B/s]

20231101.en/train-00025-of-00041.parquet:   0%|          | 0.00/221M [00:00<?, ?B/s]

20231101.en/train-00026-of-00041.parquet:   0%|          | 0.00/208M [00:00<?, ?B/s]

20231101.en/train-00027-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

20231101.en/train-00028-of-00041.parquet:   0%|          | 0.00/188M [00:00<?, ?B/s]

20231101.en/train-00029-of-00041.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

20231101.en/train-00030-of-00041.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

20231101.en/train-00031-of-00041.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

20231101.en/train-00032-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

20231101.en/train-00033-of-00041.parquet:   0%|          | 0.00/203M [00:00<?, ?B/s]

20231101.en/train-00034-of-00041.parquet:   0%|          | 0.00/219M [00:00<?, ?B/s]

20231101.en/train-00035-of-00041.parquet:   0%|          | 0.00/224M [00:00<?, ?B/s]

20231101.en/train-00036-of-00041.parquet:   0%|          | 0.00/610M [00:00<?, ?B/s]

20231101.en/train-00037-of-00041.parquet:   0%|          | 0.00/674M [00:00<?, ?B/s]

20231101.en/train-00038-of-00041.parquet:   0%|          | 0.00/538M [00:00<?, ?B/s]

20231101.en/train-00039-of-00041.parquet:   0%|          | 0.00/465M [00:00<?, ?B/s]

20231101.en/train-00040-of-00041.parquet:   0%|          | 0.00/422M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6407814 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Epoch 1/3 —`Train Loss: 10.6706 | Test Loss: 10.1190
Epoch 2/3 —`Train Loss: 8.2063 | Test Loss: 7.9917
Epoch 3/3 —`Train Loss: 5.8925 | Test Loss: 7.7822
Sampled text: [' occupations Freeman include the– People males of women p the the prohibitions exclusively. viewsering However policeman nerves sexering with a pen sex among p sex of', 'cv cleaners cultures is insertion nerve nerve sex sex sex and expressedStrong nervegging leveled therefore the so. of condom and the considered sensory sensory anal denoteicone']
