In this project, I developed a text summarization model using my own Transformer architecture. This work is mainly for practice; the model needs further training before it is suitable for real-world applications. I trained it on only one-third of my dataset: 100,000 samples out of 300,000. Despite this limitation, the model performs reasonably well. If you plan to undertake a similar project, I recommend training your model on the entire set of 300,000 samples from the CNN/DailyMail dataset to achieve better generalization.

My own Transformer architecture for the summarization task

Adjust the number of encoder and decoder layers, the number of attention heads, the feed-forward neural network (FFNN) dimension, and the learning rate based on your dataset size and computational resources. This model falls into the medium-complexity range; to generalize effectively, it requires at least 500,000 training samples.

To achieve optimal performance, carefully tune the hyperparameters and model architecture to balance accuracy and efficiency. If your dataset is large enough, consider increasing the model depth and the number of attention heads, while making sure that training remains feasible on your available hardware.









In [2]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first embeddings.

    The encoding table is precomputed for ``max_len`` positions and stored as a
    buffer ``pe`` of shape ``(1, max_len, d_model)``. Even columns hold
    ``sin(pos / 10000^(2i/d_model))`` and odd columns the matching cosine.
    """

    def __init__(self, d_model, max_len=256):
        """
        Args:
          d_model: Embedding dimension (may be odd).
          max_len: Longest sequence length this module can encode.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions.

        Args:
          x: Batch-first tensor ``(batch, seq_len, d_model)``.
        Raises:
          ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]

###############################################
# 2. Transformer Encoder & Decoder Layers
###############################################
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention, then a position-wise feed-forward network. Each sub-layer
    output goes through dropout, a residual add, and LayerNorm (normalization
    after the residual add).

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so tensors are sequence-first: ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names define the state_dict keys, and creation order fixes
        # how random initialization consumes the RNG -- keep both unchanged.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.activation = nn.ReLU()

    def _feed_forward(self, x):
        """Position-wise FFN: linear -> ReLU -> dropout -> linear."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the layer on ``src`` of shape ``(seq_len, batch, d_model)``.

        Args:
          src: Sequence-first input tensor.
          src_mask: Optional attention mask forwarded as ``attn_mask``.
          src_key_padding_mask: Optional ``(batch, seq_len)`` padding mask
            forwarded as ``key_padding_mask``.
        Returns:
          Tensor with the same shape as ``src``.
        """
        attended = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        x = self.norm1(src + self.dropout(attended))
        return self.norm2(x + self.dropout(self._feed_forward(x)))

class TransformerDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer runs three sub-layers in order: masked self-attention over the
    target, cross-attention from target queries to encoder ``memory``, and a
    position-wise feed-forward network. Each sub-layer output goes through
    dropout, a residual add, and LayerNorm.

    ``nn.MultiheadAttention`` uses its default ``batch_first=False``, so tensors
    are sequence-first: ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names define the state_dict keys, and creation order fixes
        # how random initialization consumes the RNG -- keep both unchanged.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.activation = nn.ReLU()

    def _feed_forward(self, x):
        """Position-wise FFN: linear -> ReLU -> dropout -> linear."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run the layer on ``tgt`` while attending to encoder ``memory``.

        Args:
          tgt: Sequence-first target tensor ``(tgt_len, batch, d_model)``.
          memory: Sequence-first encoder output ``(src_len, batch, d_model)``.
          tgt_mask: Optional self-attention mask (e.g. causal).
          memory_mask: Optional cross-attention mask.
          tgt_key_padding_mask: Optional ``(batch, tgt_len)`` padding mask.
          memory_key_padding_mask: Optional ``(batch, src_len)`` padding mask.
        Returns:
          Tensor with the same shape as ``tgt``.
        """
        self_out = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        x = self.norm1(tgt + self.dropout(self_out))

        cross_out = self.multihead_attn(
            x, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        x = self.norm2(x + self.dropout(cross_out))

        return self.norm3(x + self.dropout(self._feed_forward(x)))

#################################################
# 3. Custom Transformer Summarizer Model
#################################################
class TransformerSummarizer(nn.Module):
    """Encoder-decoder Transformer for abstractive summarization.

    One ``nn.Embedding`` is shared by the encoder and decoder inputs. The public
    methods take batch-first token ids; internally the embeddings are
    transposed to sequence-first ``(seq_len, batch, d_model)`` for the layers.
    """

    def __init__(self, vocab_size, d_model=384, nhead=4, d_ff=1024,
                 num_encoder_layers=3, num_decoder_layers=3, dropout=0.1,
                 max_seq_length=256, pre_trained_embeddings=None):
        """
        Args:
          vocab_size: Vocabulary size.
          d_model: Embedding dimension (should match BERT's hidden size if using its embeddings).
          nhead: Number of attention heads.
          d_ff: Feed-forward hidden dimension.
          num_encoder_layers: Number of encoder layers.
          num_decoder_layers: Number of decoder layers.
          dropout: Dropout rate (attention, feed-forward and input embeddings).
          max_seq_length: Maximum sequence length (for positional encoding).
          pre_trained_embeddings: Optional tensor used to initialize the embedding
            layer. It is copied only if its shape is exactly (vocab_size, d_model);
            otherwise a warning is printed and the random init is kept.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        if pre_trained_embeddings is not None:
            if pre_trained_embeddings.shape == (vocab_size, d_model):
                self.embedding.weight.data.copy_(pre_trained_embeddings)
            else:
                print("Warning: Pre-trained embedding dimensions do not match. Using random init.")
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_length)
        self.pos_decoder = PositionalEncoding(d_model, max_len=max_seq_length)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, nhead, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, nhead, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encode(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode source ids ``(batch, src_len)`` into memory ``(src_len, batch, d_model)``.

        ``src_key_padding_mask`` is an optional ``(batch, src_len)`` mask passed
        to every encoder layer's self-attention.
        """
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        # Dropout on (scaled embedding + positional encoding), as in the
        # original Transformer; it is a no-op in eval mode.
        src_emb = self.dropout(self.pos_encoder(src_emb)).transpose(0, 1)
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask, src_key_padding_mask)
        return src_emb

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None,
               tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode target ids ``(batch, tgt_len)`` against encoder ``memory``.

        Returns decoder hidden states of shape ``(tgt_len, batch, d_model)``.
        """
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(self.pos_decoder(tgt_emb)).transpose(0, 1)
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, memory, tgt_mask, memory_mask,
                            tgt_key_padding_mask, memory_key_padding_mask)
        return tgt_emb

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """
        Args:
          src: Source tokens (batch_size, src_seq_len)
          tgt: Target tokens (batch_size, tgt_seq_len)
          src_mask: Optional encoder self-attention mask.
          tgt_mask: Optional decoder self-attention mask (e.g. causal).
          src_key_padding_mask: Optional (batch_size, src_seq_len) padding mask,
            used both in the encoder and in decoder cross-attention.
          tgt_key_padding_mask: Optional (batch_size, tgt_seq_len) padding mask.
        Returns:
          Logits over vocabulary for each target token.
          Shape: (tgt_seq_len, batch_size, vocab_size)
        """
        memory = self.encode(src, src_mask, src_key_padding_mask)
        decoder_output = self.decode(tgt, memory, tgt_mask, None,
                                     tgt_key_padding_mask, src_key_padding_mask)
        return self.output_layer(decoder_output)

    def generate(self, src, src_mask=None, src_key_padding_mask=None,
                 max_length=100, beam_width=5, start_token_id=101, end_token_id=102, device="cpu"):
        """
        Generate a summary for a single source sequence using beam search.

        Args:
          src: Source tensor (1, src_seq_len).
          src_mask: Optional encoder self-attention mask.
          src_key_padding_mask: Optional (1, src_seq_len) padding mask. It is
            applied in the encoder and, as in ``forward``, in decoder cross-attention.
          max_length: Maximum number of decoding steps after the start token.
          beam_width: Number of hypotheses kept per step; also the number of
            candidate tokens expanded from each live hypothesis.
          start_token_id: Start token (e.g., [CLS] for BERT is 101).
          end_token_id: End token (e.g., [SEP] for BERT is 102).
          device: Device on which decoder inputs and masks are created; must
            match the device of ``src`` and the model.
        Returns:
          List of token IDs (starting with ``start_token_id``) for the
          highest-scoring hypothesis.

        Note: calls ``self.eval()`` and leaves the model in eval mode.
        Hypotheses are ranked by the unnormalized sum of log-probabilities,
        which favors shorter sequences.
        """
        self.eval()
        with torch.no_grad():
            memory = self.encode(src, src_mask, src_key_padding_mask)
            beams = [([start_token_id], 0.0)]
            for _ in range(max_length):
                new_beams = []
                for seq, score in beams:
                    # Finished hypotheses are carried forward unchanged.
                    if seq[-1] == end_token_id:
                        new_beams.append((seq, score))
                        continue
                    tgt_seq = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
                    seq_len = tgt_seq.size(1)
                    tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)
                    # Mask padded source positions in cross-attention too, so the
                    # generated tokens do not depend on the content of [PAD] slots.
                    decoder_output = self.decode(
                        tgt_seq, memory, tgt_mask=tgt_mask,
                        memory_key_padding_mask=src_key_padding_mask,
                    )
                    # decoder_output is (seq_len, 1, d_model); keep the last step.
                    last_output = decoder_output[-1, :, :]
                    logits = self.output_layer(last_output)
                    log_probs = torch.log_softmax(logits, dim=-1).squeeze(0)
                    top_log_probs, top_indices = torch.topk(log_probs, beam_width)
                    for log_p, token_id in zip(top_log_probs.tolist(), top_indices.tolist()):
                        new_beams.append((seq + [token_id], score + log_p))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
                if all(seq[-1] == end_token_id for seq, _ in beams):
                    break
            best_sequence = beams[0][0]
            return best_sequence

#################################################
# 4. Custom Dataset Class for Summarization Data
#################################################
class SummarizationDataset(Dataset):
    """Map-style dataset of (article, summary) pairs read from a CSV file.

    Rows are tokenized lazily in ``__getitem__``. Source and target are each
    tokenized with truncation and ``padding="max_length"`` to their own
    length limits.
    """

    def __init__(self, csv_file, tokenizer, max_src_len=512, max_tgt_len=128):
        """
        Args:
          csv_file: Path (or anything ``pd.read_csv`` accepts) to a CSV with
            'article' and 'highlights' columns.
          tokenizer: Pre-trained tokenizer (BERT in our case), called with
            ``return_tensors="pt"``.
          max_src_len: Maximum token length for the source article.
          max_tgt_len: Maximum token length for the target summary.
        """
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        return self.data.shape[0]

    def _encode(self, text, max_len):
        """Tokenize ``text`` to a 1-D tensor of ids, truncated/padded to ``max_len``."""
        encoded = self.tokenizer(
            text,
            max_length=max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # The tokenizer returns a (1, max_len) batch; drop the batch axis.
        return encoded["input_ids"].squeeze(0)

    def __getitem__(self, idx):
        """Return ``{"src": article_ids, "tgt": summary_ids}`` for row ``idx``."""
        record = self.data.iloc[idx]
        return {
            "src": self._encode(str(record["article"]), self.max_src_len),
            "tgt": self._encode(str(record["highlights"]), self.max_tgt_len),
        }

Model Training

In [None]:
def _batch_loss(model, batch, criterion, vocab_size, pad_token_id, device):
    """Return the teacher-forced cross-entropy loss for one batch.

    The decoder is fed ``tgt[:, :-1]`` and scored against ``tgt[:, 1:]``.
    Source positions equal to ``pad_token_id`` are passed as a key padding
    mask so the encoder and decoder cross-attention ignore them.
    """
    src = batch["src"].to(device)
    tgt = batch["tgt"].to(device)
    tgt_input = tgt[:, :-1]
    tgt_labels = tgt[:, 1:]
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
    src_key_padding_mask = src == pad_token_id
    output_logits = model(src, tgt_input, tgt_mask=tgt_mask,
                          src_key_padding_mask=src_key_padding_mask)
    # The model returns (tgt_len, batch, vocab) logits, so labels are transposed
    # to (tgt_len, batch) before flattening to keep rows aligned. reshape() is
    # used because the logits are not guaranteed to be contiguous.
    return criterion(output_logits.reshape(-1, vocab_size),
                     tgt_labels.transpose(0, 1).reshape(-1))


def _generate_summaries(model, loader, tokenizer, device, max_length=50, beam_width=5):
    """Beam-search one summary per item of ``loader`` (batch size 1) and return the decoded strings."""
    model.eval()
    results = []
    for batch in tqdm(loader, desc="Testing"):
        src = batch["src"].to(device)
        generated_ids = model.generate(src,
                                       src_key_padding_mask=src == tokenizer.pad_token_id,
                                       max_length=max_length, beam_width=beam_width,
                                       start_token_id=tokenizer.cls_token_id,
                                       end_token_id=tokenizer.sep_token_id,
                                       device=device)
        results.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
    return results


def main():
    """Train the summarizer, keep the best checkpoint, and summarize the test set.

    The model is initialized with BERT word embeddings and trained on the CSV
    splits under /content. The checkpoint with the lowest validation loss is
    saved, reloaded, and used to beam-search summaries for the test split,
    which are written to test_generated_summaries.csv.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Loading BERT tokenizer and model for embeddings...")
    bert_model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertModel.from_pretrained(bert_model_name)
    bert_model.eval()
    bert_embeddings = bert_model.embeddings.word_embeddings.weight.data.clone()
    # Only the word-embedding matrix is needed; release the full BERT model.
    del bert_model

    vocab_size = tokenizer.vocab_size
    d_model = bert_embeddings.size(1)
    nhead = 4
    d_ff = 1024
    num_encoder_layers = 3
    num_decoder_layers = 3
    dropout = 0.1
    max_seq_length = 256
    batch_size = 64
    num_epochs = 10
    pad_token_id = tokenizer.pad_token_id

    print("Initializing custom Transformer summarization model...")
    model = TransformerSummarizer(
        vocab_size=vocab_size,
        d_model=d_model,
        nhead=nhead,
        d_ff=d_ff,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dropout=dropout,
        max_seq_length=max_seq_length,
        pre_trained_embeddings=bert_embeddings
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    train_dataset = SummarizationDataset("/content/part1.csv", tokenizer, max_src_len=max_seq_length, max_tgt_len=128)
    val_dataset   = SummarizationDataset("/content/validation.csv", tokenizer, max_src_len=max_seq_length, max_tgt_len=128)
    test_dataset  = SummarizationDataset("/content/test.csv", tokenizer, max_src_len=max_seq_length, max_tgt_len=128)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    # generate() decodes one source sequence at a time, hence batch_size=1.
    test_loader  = DataLoader(test_dataset, batch_size=1, shuffle=False)

    ##########################################################
    # 6. Training and Validation Loop
    ##########################################################
    best_val_loss = float("inf")
    best_model_path = "best_transformer_summarizer.pt"
    saved_checkpoint = False
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
            optimizer.zero_grad()
            loss = _batch_loss(model, batch, criterion, vocab_size, pad_token_id, device)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch} Train Loss: {avg_train_loss:.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch}"):
                val_loss += _batch_loss(model, batch, criterion, vocab_size, pad_token_id, device).item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "model_config": {
                    "vocab_size": vocab_size,
                    "d_model": d_model,
                    "nhead": nhead,
                    "d_ff": d_ff,
                    "num_encoder_layers": num_encoder_layers,
                    "num_decoder_layers": num_decoder_layers,
                    "dropout": dropout,
                    "max_seq_length": max_seq_length,
                    "bert_model_name": bert_model_name
                }
            }, best_model_path)
            saved_checkpoint = True
            print(f"Saved best model at epoch {epoch} to {best_model_path}")

    ##########################################################
    # 7. Testing / Generation on Test Dataset
    ##########################################################
    # Evaluate the best (lowest validation loss) weights, not the final epoch's.
    if saved_checkpoint:
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"Loaded best model from {best_model_path}")

    print("Generating summaries on test dataset...")
    results = _generate_summaries(model, test_loader, tokenizer, device)

    # Save test results to a CSV file.
    test_results_df = pd.DataFrame({"generated_summary": results})
    test_results_df.to_csv("test_generated_summaries.csv", index=False)
    print("Test summaries saved to test_generated_summaries.csv")

if __name__ == "__main__":
    main()