In [1]:
!pip install datasets transformers

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading xx

In [2]:
import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
dataset = load_dataset("ag_news", split="train")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

In [4]:
len(dataset)

120000

In [5]:
def split_headline_body(example):
  # Try splitting using " - " if present, else use the first sentence as headline.
  if " - " in example["text"]:
      headline, body = example["text"].split(" - ", 1)
  else:
      parts = example["text"].split(".")
      headline = parts[0]
      body = ".".join(parts[1:]).strip()
  example["headline"] = headline.strip()
  example["body"] = body.strip() if body.strip() else headline.strip()
  return example

In [6]:
dataset = dataset.map(split_headline_body)

Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

In [7]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
class NewsDataset:
  def __init__(self, dataset, tokenizer, max_input_length=32, max_output_length=128):
    self.dataset = dataset
    self.tokenizer = tokenizer
    self.max_input_length = max_input_length
    self.max_output_length = max_output_length

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    sample = self.dataset[idx]
    headline = sample["headline"]
    body = sample["body"]

    input_enc = self.tokenizer(headline,
                                truncation=True,
                                max_length=self.max_input_length,
                                return_tensors="pt")
    target_enc = self.tokenizer(body,
                                truncation=True,
                                max_length=self.max_output_length,
                                return_tensors="pt")

    # Remove the extra batch dimension.
    input_ids = input_enc.input_ids.squeeze(0)
    target_ids = target_enc.input_ids.squeeze(0)

    return {"input_ids": input_ids, "target_ids": target_ids}

news_dataset = NewsDataset(dataset, tokenizer)

In [9]:
def collate_fn(batch):
  input_ids = [item["input_ids"] for item in batch]
  target_ids = [item["target_ids"] for item in batch]
  input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
  target_ids = nn.utils.rnn.pad_sequence(target_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
  return input_ids, target_ids

In [10]:
dataloader = DataLoader(news_dataset,
                        batch_size=128,
                        shuffle=True,
                        collate_fn=collate_fn)

In [39]:
for i, (src, trg) in enumerate(dataloader):
  if i == 0:
    print(src)
    print(src.shape)
    print(trg)
    break

tensor([[   35, 13228,  9340,  ..., 50256, 50256, 50256],
        [30507, 33500,  8835,  ...,   428,  1285,  1303],
        [ 9914, 44394,   287,  ...,  3321,    11,  5170],
        ...,
        [13587, 24833, 38116,  ..., 50256, 50256, 50256],
        [16977,   352,    12,  ...,  3173,   319,  3139],
        [ 5956,   582,  5055,  ..., 50256, 50256, 50256]])
torch.Size([128, 32])
tensor([[  464,  8872,  3405,  ..., 50256, 50256, 50256],
        [30507, 33500,  8835,  ..., 50256, 50256, 50256],
        [ 9914, 44394,   287,  ..., 50256, 50256, 50256],
        ...,
        [ 1722,  2031,   338,  ..., 50256, 50256, 50256],
        [16977,   352,    12,  ..., 50256, 50256, 50256],
        [ 2953,  1551,   326,  ..., 50256, 50256, 50256]])


In [11]:
# Positional Encoding
class PositionalEncoding(nn.Module):
  def __init__(self,
               d_model: int,
               dropout: float = 0.1,
               max_length: int = 5000):
    """
    d_model: dimensions of the embeddings (number of values in each embedding vector)
    dropout: probability of dropout
    max_length: max length of a sequence
    """
    super().__init__()

    self.dropout = nn.Dropout(p=dropout)

    pe = torch.zeros(max_length, d_model) # (max_length, d_model)
    # Create position column
    k = torch.arange(0, max_length).unsqueeze(dim=1)
    # Use the log version of the function for positional encodings
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

    # Use sine for the even indices and cosine for the odd indices
    pe[:, 0::2] = torch.sin(k * div_term)
    pe[:, 1::2] = torch.cos(k * div_term)

    pe = pe.unsqueeze(dim=0) # Add the batch dimension

    # We use a buffer because the positional encoding is fixed and not a model paramter that we want to be updated during backpropagation.
    self.register_buffer("pe", pe) # Buffers are saved with the model state and are moved to the correct device

  def forward(self, x):
    # x shape: (batch_size, seq_length, d_model)
    # Add the positional encoding to the embeddings that are passed in
    x += self.pe[:, :x.size(1)]
    return self.dropout(x)

In [36]:
def generate_target_mask(sz):
  # Generate the target mask to prevent the decoder from peeking at future target values during training
  mask = torch.tril(torch.ones(sz, sz), diagonal=0)
  mask = mask.masked_fill(mask == 0, float('-inf'))
  return mask

print(generate_target_mask(10))
print(generate_target_mask(10).shape)

tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
torch.Size([10, 10])


In [44]:
class TransformerNewsGenerator(nn.Module):
  def __init__(self,
               vocab_size,
               pad_token_id,
               d_model=512,
               n_head=8,
               n_layers=3,
               dim_ffn=2048,
               dropout=0.1,
               max_seq_length=512):
    super().__init__()
    self.d_model = d_model
    self.pad_token_id = pad_token_id

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
    self.pos_decoder = PositionalEncoding(d_model, dropout, max_seq_length)
    self.transformer = nn.Transformer(d_model=d_model,
                                      nhead=n_head,
                                      num_encoder_layers=n_layers,
                                      num_decoder_layers=n_layers,
                                      dim_feedforward=dim_ffn,
                                      dropout=dropout)
    self.fc_out = nn.Linear(d_model, vocab_size)

  def generate_target_mask(self, seq_len):
    # Generate the target mask to prevent the decoder from peeking at future target values during training
    mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0)
    mask = mask.masked_fill(mask == 0, float('-inf'))

    # The target mask has -inf in top right triangle and 1 in bottom right triangle
    return mask

  def forward(self, src, trg):
    trg_seq_len = trg.size(1)
    trg_mask = self.generate_target_mask(trg_seq_len).to(trg.device)

    # Create key padding masks
    src_key_padding_mask = (src == self.pad_token_id)
    trg_key_padding_mask = (trg == self.pad_token_id)

    # Embedding and Positional Encodings
    src_emb = self.embedding(src) * math.sqrt(self.d_model)
    src_emb = self.pos_encoder(src_emb)
    src_emb = src_emb.transpose(0, 1) # We want a shape of (seq_len, batch_size, d_model)

    trg_emb = self.embedding(trg) * math.sqrt(self.d_model)
    trg_emb = self.pos_decoder(trg_emb)
    trg_emb = trg_emb.transpose(0, 1) # We want a shape of (seq_len, batch_size, d_model)

    output = self.transformer(src_emb,
                              trg_emb,
                              tgt_mask=trg_mask,
                              src_key_padding_mask=src_key_padding_mask,
                              tgt_key_padding_mask=trg_key_padding_mask)

    # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
    output = output.transpose(0, 1)

    output += torch.randn_like(output) * 0.001 # Random noise
    logits = self.fc_out(output)
    return logits

In [45]:
vocab_size = len(tokenizer)
model = TransformerNewsGenerator(vocab_size=vocab_size, pad_token_id=tokenizer.pad_token_id).to(device)



In [46]:
lr = 0.0005
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.AdamW(params=model.parameters(), lr=lr, weight_decay=1e-5)

In [47]:
epochs = 20
clip = 1
model.train()

best_valid_loss = float('inf')
model_path = "next_word_pred_model.pt"

if os.path.exists(model_path):
  print(f"Loading model from {model_path}...")
  model.load_state_dict(torch.load(model_path, map_location=device))
else:
  print("No saved model found. Starting training...")

  # Training
  for epoch in tqdm(range(epochs), desc="Training Progress", colour="#00ff00"):
    epoch_loss = 0

    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1} Progress", colour="#005500")
    for i, (src, trg) in enumerate(pbar):
      src, trg = src.to(device), trg.to(device)

      # Forward pass
      # Pass in the full target sequence without the <eos> token, and in the model we use the target mask to prevent it from peeking at future values
      logits = model(src, trg[:,:-1]) # Remove <eos> token because we want to predict it ourselves

      # Expected target
      expected_output = trg[:,1:] # Remove <bos> token because the <eos> was not generated in the logits, so we need to remove it to properly compare

      # Calculate the loss
      # contiguous() flattens so that every token position in every sequence is treated as an individual prediction.
      loss = loss_fn(logits.reshape(-1, vocab_size), expected_output.reshape(-1))
      epoch_loss += loss.item()

      optimizer.zero_grad()

      # Backpropagation
      loss.backward()

      # Gradient Clipping
      torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

      # Optimizer Step
      optimizer.step()

      pbar.set_postfix(loss=loss.item()) # Update the loss on the tqdm progress bar

    message = f"Epoch: {epoch + 1} | Loss: {epoch_loss / len(dataloader)}"

    if epoch_loss / len(dataloader) < best_valid_loss:
      best_valid_loss = epoch_loss / len(dataloader)
      torch.save(model.state_dict(), model_path)
      message += " --> STORED"

    print(message)

No saved model found. Starting training...


Training Progress:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 Progress:   0%|          | 0/938 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [None]:
def generate_text(model, tokenizer, headline, max_length=50):
  model.eval()
  with torch.no_grad():
    # Tokenize the input headline.
    input_enc = tokenizer(headline, return_tensors="pt")
    input_ids = input_enc.input_ids.to(device)

    # Start the decoder with the EOS token
    generated = torch.tensor([[tokenizer.eos_token_id]], device=device)

    for i in range(max_length):
      outputs = model(input_ids, generated)

      # Get the logits for the last time step
      next_token_logits = outputs[:, -1, :]
      next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
      generated = torch.cat((generated, next_token), dim=1)
      if next_token.item() == tokenizer.eos_token_id:
        break
    # Decode and clean up the generated tokens.
    generated_text = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
  return generated_text

In [None]:
headline_example = "Breaking News: Russia has officially started the war"
generated_article = generate_text(model, tokenizer, headline_example)

print("Headline:", headline_example)
print("Generated Article:\n", generated_article)

In [None]:
#@title News Generator
headline_input = "Russia has officially declared war" #@param ""

generated_article = generate_text(model, tokenizer, headline_input, max_length=50)

print("Headline:", headline_example)
print("Generated Article:\n", generated_article)