In [None]:
from dataclasses import dataclass
import torch
import inspect
import torch.nn as nn
import torch.nn.functional as F
import math
import transformers
import matplotlib.pyplot as plt

# Report the PyTorch build and which accelerator (if any) it can see.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU count: {torch.cuda.device_count()}")

# Device string used by later cells when moving tensors.
device = "cuda" if cuda_ready else "cpu"
print(f"Using device: {device}")

In [None]:
from datasets import load_dataset, VerificationMode
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

# Stream the training split instead of materialising it; only the examples consumed below are read.
# dataset = load_dataset("...", verification_mode=VerificationMode.NO_CHECKS)
dataset = load_dataset("argilla/cnn-dailymail-summaries", split='train', streaming=True)
print(dataset)

# How many (article, highlights) pairs to pull from the stream.
# Deliberately NOT called `max`: that would shadow the builtin max() for every later cell.
MAX_SAMPLES = 3

news = []
summaries = []
for d in dataset:
  news.append(d['article'])
  summaries.append(d['highlights'])
  # Stop right after the last wanted example so no extra record is fetched from the stream.
  if len(news) >= MAX_SAMPLES:
    break
print(f'Number of news: {len(news)}')
print(f'Number of summaries: {len(summaries)}')

In [None]:
"""Tokens <s> and </s> are added from default in BartTokenizer"""

# Articles (and summaries) differ in length, so every sequence is padded up to the longest one
# in its list after truncation: https://huggingface.co/docs/transformers/en/pad_truncation
news_token_ids = tokenizer(news, padding='longest', truncation=True, max_length=1024)['input_ids']
# Summaries are capped at 128 tokens, then padded to the longest remaining summary.
summary_token_ids = tokenizer(summaries, padding='longest', truncation=True, max_length=128)['input_ids']
print(tokenizer.decode(news_token_ids[0]))
print(tokenizer.decode(summary_token_ids[0]))

# Padding made the nested lists rectangular, so they convert directly to 2-D tensors.
inputs = torch.tensor(news_token_ids)
targets = torch.tensor(summary_token_ids)
print(inputs.size())
print(targets.size())

# Padded sequence lengths; later cells use them to size the positional embedding tables.
# (Fixed alternatives that were tried: max_input_seq = 1024, max_output_seq = 128.)
max_input_seq, max_output_seq = inputs.size(1), targets.size(1)
print(f'Max input seq: {max_input_seq}')
print(f"Max output seq: {max_output_seq}")

vocab_size = tokenizer.vocab_size
print(f'Vocab size: {vocab_size}')

In [None]:
from torch.utils.data import Dataset, DataLoader

# Seed PyTorch's global RNG before the DataLoader below is built with shuffle=True.
torch.manual_seed(42)

# Number of (input, target) pairs the DataLoader groups into one batch.
BATCH_SIZE = 8

class CustomDataset(Dataset):
    """Map-style dataset that pairs the i-th input with the i-th target.

    Attributes:
        inputs: indexable collection of source sequences.
        targets: indexable collection of target sequences, aligned with `inputs` by position.
    """

    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        # The number of inputs defines the dataset size.
        n_examples = len(self.inputs)
        return n_examples

    def __getitem__(self, idx):
        source = self.inputs[idx]
        target = self.targets[idx]
        return source, target

# Wrap the tokenised tensors and batch them.
# Note: this rebinds `dataset`, which earlier held the streaming Hugging Face dataset.
dataset = CustomDataset(inputs, targets)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# len(dataloader) is the number of batches per pass over the data.
num_batches = len(dataloader)
print(f"Num batches: {num_batches}")

In [None]:
# Model hyperparameters; the attention, encoder and decoder classes read these as module-level globals.
# Embedding / hidden width (the attention module asserts it is divisible by n_heads).
emb_dim = 16
# Number of EncoderLayer blocks in the Encoder and of DecoderLayer blocks in the Decoder.
n_layers = 6
# Number of attention heads; each head works on emb_dim // n_heads features.
n_heads = 4
# Drop probability passed to the nn.Dropout modules of the encoder and decoder layers.
dropout = 0.2

In [None]:
# Generate mask for <pad> tokens
def generate_padding_mask(idx, pad_token_id=None):
  """Build a boolean mask that is True wherever `idx` holds a padding token.

  Args:
    idx: integer tensor of token ids, shape [B, T].
    pad_token_id: id to treat as padding. When omitted, the notebook-level
      `tokenizer.pad_token_id` is used, which keeps existing one-argument calls working.

  Returns:
    Bool tensor of shape [B, 1, 1, T], on the same device as `idx`. The two singleton
    dimensions let it broadcast against attention scores of shape [B, heads, T_query, T].
  """
  if pad_token_id is None:
    pad_token_id = tokenizer.pad_token_id
  is_pad = idx == pad_token_id
  # unsqueeze inserts a size-1 dimension at the given position: [B, T] -> [B, 1, T] -> [B, 1, 1, T]
  return is_pad.unsqueeze(1).unsqueeze(2)

# Generate a causal mask for target sequences
def generate_tgt_mask(size, device=None):
    """Build a look-ahead mask that hides future positions from each query position.

    Args:
        size: target sequence length T.
        device: device on which to create the mask. Defaults to PyTorch's default device
            (normally CPU); pass the device of the attention scores when training on GPU.

    Returns:
        Bool tensor of shape [1, 1, T, T] where entry [.., i, j] is True (masked) when j > i.
    """
    # Strict upper triangle (diagonal=1): a position may attend to itself and to earlier ones.
    upper = torch.triu(torch.ones(size, size, device=device), diagonal=1)
    return upper.bool().unsqueeze(0).unsqueeze(1)

In [None]:
class CasualSelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Despite the name, the module is neither inherently causal nor restricted to
  self-attention: causality comes only from the `mask` passed to `forward`, and
  query/key/value may come from different sequences (cross-attention).

  Args:
    embed_dim: model width. Defaults to the notebook-level `emb_dim`.
    num_heads: number of attention heads. Defaults to the notebook-level `n_heads`.
  """
  def __init__(self, embed_dim=None, num_heads=None):
      super().__init__()
      # Fall back to the notebook globals so the zero-argument constructor keeps working.
      embed_dim = emb_dim if embed_dim is None else embed_dim
      num_heads = n_heads if num_heads is None else num_heads
      assert embed_dim % num_heads == 0, "embedding width must be divisible by the number of heads"
      self.num_heads = num_heads
      self.head_dim = embed_dim // num_heads
      self.query = nn.Linear(embed_dim, embed_dim)
      self.key = nn.Linear(embed_dim, embed_dim)
      self.value = nn.Linear(embed_dim, embed_dim)
      self.fc = nn.Linear(embed_dim, embed_dim)  # output projection applied after the heads are merged

  def forward(self, query, key, value, mask=None):
      """Attend from `query` positions to `key`/`value` positions.

      Args:
        query: tensor [B, T_q, embed_dim].
        key, value: tensors [B, T_k, embed_dim].
        mask: optional tensor broadcastable to [B, heads, T_q, T_k]; non-zero / True
          entries mark key positions that must NOT be attended to.

      Returns:
        Tensor [B, T_q, embed_dim].
      """
      batch_size = query.size(0)

      def split_heads(x):
          # [B, T, embed_dim] -> [B, heads, T, head_dim]
          return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

      query, key, value = map(split_heads, [self.query(query), self.key(key), self.value(value)])

      scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
      if mask is not None:
          # Accept bool or numeric masks, and masks created on another device (e.g. CPU).
          blocked = mask.to(device=scores.device).bool()
          # Use the most negative finite value instead of -inf: masked keys still receive
          # exactly zero weight, but a fully masked row yields a uniform distribution
          # rather than NaN from softmax.
          scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)
      attention_weights = torch.softmax(scores, dim=-1)
      attention_output = torch.matmul(attention_weights, value).transpose(1, 2).contiguous()
      # Merge heads back: [B, T_q, heads, head_dim] -> [B, T_q, embed_dim]
      attention_output = attention_output.view(batch_size, -1, self.num_heads * self.head_dim)

      return self.fc(attention_output)

In [None]:
torch.manual_seed(42)

class EncoderLayer(nn.Module):
  """One encoder block: layer norm -> self-attention -> layer norm -> feed-forward.

  Each sub-layer adds its (dropped-out) output onto the layer-normalised tensor it
  received, the same residual pattern DecoderLayer uses.
  """
  def __init__(self):
    super().__init__()
    self.ln_1 = nn.LayerNorm(emb_dim)  # layer norm before self attention
    self.self_attention = CasualSelfAttention()  #multi-head attention
    self.ln_2 = nn.LayerNorm(emb_dim)  # layer norm before the feed-forward network
    self.feed_forward = nn.Sequential(
        nn.Linear(emb_dim, emb_dim * 4),
        nn.GELU(approximate='tanh'),  # tanh approximation of GELU
        nn.Linear(emb_dim * 4, emb_dim))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, src_mask):  # src_mask is mask for inputs
    """Run the block.

    Args:
      x: tensor [B, T, emb_dim].
      src_mask: padding mask forwarded to self-attention (True = ignore that key), or None.

    Returns:
      Tensor [B, T, emb_dim].
    """
    # 1. Self-attention; src_mask hides padded key positions.
    x = self.ln_1(x)
    attn_output = self.self_attention(x, x, x, src_mask)
    # Dropout regularises the sub-layer output, not the residual stream itself
    # (applying it to `x` would randomly zero the features carried forward).
    x = x + self.dropout(attn_output)
    # 2. Position-wise feed-forward network.
    x = self.ln_2(x)
    ff_output = self.feed_forward(x)
    x = x + self.dropout(ff_output)
    return x

class Encoder(nn.Module):
    """Token embedding + learned positional embedding, followed by a stack of EncoderLayer blocks."""

    def __init__(self):
        super().__init__()
        # One emb_dim-wide vector per vocabulary entry.
        self.input_emb = nn.Embedding(vocab_size, emb_dim)
        # One learned vector per position; max_input_seq is the longest (padded) article length.
        self.pos_emb = nn.Embedding(max_input_seq, emb_dim)
        self.encoder_layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, idx, src_mask):
        """Encode a batch of token ids.

        Args:
            idx: integer tensor [B, T] of token ids.
            src_mask: padding mask handed unchanged to every layer.

        Returns:
            Tensor [B, T, emb_dim].
        """
        _, seq_len = idx.shape

        token_vectors = self.input_emb(idx)  # [B, T, emb_dim]
        position_ids = torch.arange(0, seq_len, device=idx.device).unsqueeze(0)  # [1, T]
        position_vectors = self.pos_emb(position_ids)  # [1, T, emb_dim]

        # The positional term broadcasts over the batch dimension.
        hidden = token_vectors + position_vectors
        for block in self.encoder_layers:
            hidden = block(hidden, src_mask)
        return hidden

# enc = Encoder()
# for input, target in dataloader:
#   print(input.shape)
#   src_mask = generate_padding_mask(input)
#   print(src_mask.shape)
#   print(src_mask)
#   x = enc(input, src_mask)
#   print('Shape of x:', x.shape)
#   break

In [None]:
torch.manual_seed(42)
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the encoder output, then an MLP.

    Every sub-layer layer-normalises its input and adds its dropped-out output onto that
    normalised tensor.
    """

    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(emb_dim)
        self.self_attention = CasualSelfAttention()
        self.ln_2 = nn.LayerNorm(emb_dim)
        self.cross_attention = CasualSelfAttention()
        self.ln_3 = nn.LayerNorm(emb_dim)
        hidden_dim = emb_dim * 4
        self.feed_forward = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden_dim, emb_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask, src_mask):
        """Run the block.

        Args:
            x: decoder hidden states [B, T_tgt, emb_dim].
            encoder_output: encoder hidden states [B, T_src, emb_dim].
            tgt_mask: mask for decoder self-attention.
            src_mask: mask over encoder positions for cross-attention.

        Returns:
            Tensor [B, T_tgt, emb_dim].
        """
        # Self-attention over the target sequence.
        normed = self.ln_1(x)
        self_attn = self.self_attention(normed, normed, normed, tgt_mask)
        x = normed + self.dropout(self_attn)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        normed = self.ln_2(x)
        cross_attn = self.cross_attention(normed, encoder_output, encoder_output, src_mask)
        x = normed + self.dropout(cross_attn)

        # Position-wise feed-forward network.
        normed = self.ln_3(x)
        return normed + self.dropout(self.feed_forward(normed))

class Decoder(nn.Module):
    """Token embedding + learned positional embedding, followed by a stack of DecoderLayer blocks."""

    def __init__(self):
        super().__init__()
        # One emb_dim-wide vector per vocabulary entry.
        self.output_emb = nn.Embedding(vocab_size, emb_dim)
        # One learned vector per target position; max_output_seq is the longest (padded) summary length.
        self.pos_emb = nn.Embedding(max_output_seq, emb_dim)
        self.decoder_layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, x, encoder_output, tgt_mask, src_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: integer tensor [B, T] of target token ids.
            encoder_output: encoder hidden states, passed to every layer's cross-attention.
            tgt_mask: mask for decoder self-attention.
            src_mask: mask over encoder positions.

        Returns:
            Tensor [B, T, emb_dim].
        """
        _, seq_len = x.shape
        hidden = self.output_emb(x)

        position_ids = torch.arange(0, seq_len, device=hidden.device).unsqueeze(0)  # [1, T]
        hidden = hidden + self.pos_emb(position_ids)  # positional term broadcasts over the batch

        for block in self.decoder_layers:
            hidden = block(hidden, encoder_output, tgt_mask, src_mask)
        return hidden

# Smoke test: push a single batch through freshly initialised Encoder/Decoder stacks
# and print the intermediate shapes.
enc = Encoder()
dec = Decoder()
# Shape check only, so autograd bookkeeping is not needed.
with torch.no_grad():
  # Loop variables are not called `input`/`target`: `input` shadows a builtin, and leaking
  # them as globals can hide missing-argument bugs in later cells.
  for src_batch, tgt_batch in dataloader:
    src_mask = generate_padding_mask(src_batch)
    encoder_output = enc(src_batch, src_mask)
    print('Encoder output:', encoder_output.shape)

    tgt_len = tgt_batch.size(1)
    tgt_ahead_mask = generate_tgt_mask(size=tgt_len)  # [1, 1, T, T]
    print(tgt_ahead_mask.shape)
    tgt_padding_mask = generate_padding_mask(tgt_batch)  # [B, 1, 1, T]
    print(tgt_padding_mask.shape)
    # A key is blocked if it lies in the future OR is padding; the masks broadcast to [B, 1, T, T].
    tgt_mask = tgt_ahead_mask | tgt_padding_mask
    print('Final tgt mask:', tgt_mask)

    print(tgt_mask.shape)
    decoder_output = dec(tgt_batch, encoder_output, tgt_mask, src_mask)

    print(decoder_output.shape)
    break  # one batch is enough for a shape check

In [None]:
torch.manual_seed(42)
class TransformerEncoderDecoder(nn.Module):
  """Encoder-decoder transformer with a linear language-model head over the vocabulary."""

  def __init__(self):
    super().__init__()
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.ln_final =  nn.LayerNorm(emb_dim)
    self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)

  def forward(self, input, src_mask, tgt_mask, targets=None):
    """Teacher-forced training pass.

    The decoder is fed `targets[:, :-1]` and trained to predict `targets[:, 1:]`, so the
    output at position i is scored against the token at position i + 1. Without this
    shift the decoder would be asked to predict a token it already receives as input.

    Args:
      input: source token ids [B, T_src].
      src_mask: padding mask for the source (True = ignore), or None.
      tgt_mask: decoder self-attention mask built for the full target length T_tgt, or None.
        It is cropped to the shifted length T_tgt - 1 internally.
      targets: target token ids [B, T_tgt]. Required, because it supplies the decoder input.

    Returns:
      (logits, loss): logits has shape [B, T_tgt - 1, vocab_size]; loss is the mean
      cross-entropy over non-padding label positions.

    Raises:
      ValueError: if `targets` is None.
    """
    if targets is None:
      raise ValueError("targets is required: it provides the decoder input for teacher forcing")

    decoder_input = targets[:, :-1]
    labels = targets[:, 1:]
    steps = decoder_input.size(1)
    if tgt_mask is not None:
      # Works for [1, 1, T, T] causal masks and for size-1 broadcast dims of padding masks.
      tgt_mask = tgt_mask[..., :steps, :steps]

    encoder_output = self.encoder(input, src_mask)
    decoder_output = self.decoder(decoder_input, encoder_output, tgt_mask, src_mask)
    x = self.ln_final(decoder_output)
    logits = self.lm_head(x)

    # `labels` is a non-contiguous slice, hence reshape rather than view.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=tokenizer.pad_token_id
    )
    return logits, loss

  @torch.no_grad()
  def generate_summary(self, input, src_mask=None, tgt_mask=None, max_length=50):
      """Sample a summary token by token for a single input sequence.

      Args:
        input: source token ids of shape [1, T_src] (batch size must be 1).
        src_mask: optional padding mask for the source.
        tgt_mask: optional decoder self-attention mask; it is cropped to the current length
          at every step. When None, a causal mask is built so decoding matches training.
        max_length: maximum number of tokens to sample after <bos>.

      Returns:
        Long tensor [1, L] that starts with <bos> and ends with <eos> if one was sampled.
      """
      # Encode the input text once; it is reused at every decoding step.
      encoder_output = self.encoder(input, src_mask)  # Shape: [1, seq_len, emb_dim]

      # Start from <bos>, on the same device as the input (and therefore the model).
      generated = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=input.device)

      # Never feed the decoder more positions than its positional table holds.
      max_positions = self.decoder.pos_emb.num_embeddings

      for _ in range(min(max_length, max_positions)):
          cur_len = generated.size(1)
          if tgt_mask is None:
              # Strict upper triangle: position i may not look at positions j > i.
              step_mask = torch.triu(
                  torch.ones(cur_len, cur_len, device=generated.device), diagonal=1
              ).bool().unsqueeze(0).unsqueeze(1)
          else:
              step_mask = tgt_mask[..., :cur_len, :cur_len]

          # Decode the current sequence
          decoder_output = self.decoder(generated, encoder_output, src_mask=src_mask, tgt_mask=step_mask)  # Shape: [1, cur_len, emb_dim]

          logits = self.lm_head(self.ln_final(decoder_output))  # Shape: [1, cur_len, vocab_size]
          logits = logits[:, -1, :]  # Only the last position predicts the next token

          # Sample the next token from the predicted distribution
          probs = F.softmax(logits, dim=-1)
          next_token = torch.multinomial(probs, num_samples=1)

          # Append the token to the generated sequence
          generated = torch.cat([generated, next_token], dim=1)  # Shape: [1, cur_len + 1]

          # Stop if <eos> token is generated
          if next_token.item() == tokenizer.eos_token_id:
              break

      return generated

In [None]:
from torch.optim import AdamW
from transformers import get_scheduler
# Hyperparameters

learning_rate = 3e-4
# gradient_accumulation_steps = 1  # Optional, to simulate larger batches
max_grad_norm = 1.0
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)
# Warm up for at most 10% of the run. A fixed 100 warmup steps would exceed the whole
# schedule on a small dataset, leaving the learning rate at (or near) zero throughout.
num_warmup_steps = min(100, num_training_steps // 10)
print('Total num training steps:', num_training_steps)
print('Total num_epochs:', num_epochs)
# Number of full passes over the dataset, i.e. the epoch count.
print('Dataset times revision:', num_epochs)

# Build the model in this cell so the notebook runs top to bottom on a fresh kernel.
transformer = TransformerEncoderDecoder().to(device)

optimizer = AdamW(transformer.parameters(), lr=learning_rate, betas=(0.9, 0.95), eps=1e-8)
# Initialize the scheduler
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

for epoch in range(num_epochs):
  transformer.train()  # Set model to training mode
  for src, target in dataloader:
    optimizer.zero_grad()  # always start with zero grad!!

    src = src.to(device)
    target = target.to(device)

    # Hide padded source positions from the encoder and from cross-attention.
    src_mask = generate_padding_mask(src)
    # Causal mask sized to this batch and moved to the training device, combined with the
    # target padding mask (True = blocked); broadcasts to [B, 1, T, T].
    tgt_mask = generate_tgt_mask(target.size(1)).to(device) | generate_padding_mask(target)

    # Forward pass
    logits, loss = transformer(input=src, src_mask=src_mask, tgt_mask=tgt_mask, targets=target)

    # Backward pass
    loss.backward()
    torch.nn.utils.clip_grad_norm_(transformer.parameters(), max_grad_norm)  # also from GPT-3 paper
    optimizer.step()  # update parameters to decrease loss

    # Scheduler step
    lr_scheduler.step()

  # Log progress (loss of the last batch of the epoch)
  if epoch % 2 == 0:
      print(f"Epoch [{epoch}], Loss: {loss.item():.4f}")

In [None]:
torch.manual_seed(42)
# Generate
transformer.eval()

article = news[0]
# max_length only takes effect together with truncation=True. The limit is the encoder's
# positional table size, so an over-long article cannot index past it.
inp = tokenizer(article, truncation=True, max_length=max_input_seq, return_tensors="pt")
# The token ids must live on the same device as the model weights.
model_device = next(transformer.parameters()).device
with torch.no_grad():  # inference only
  summary_ids = transformer.generate_summary(input=inp['input_ids'].to(model_device), max_length=20)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)