In [None]:
!pip install datasets



In [None]:
# Install required packages
import sys
import subprocess
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'datasets', 'transformers','wandb'])

import datasets
import pandas as pd
import numpy as np
import torch
import random
from tqdm import tqdm
import wandb

# Seed every RNG this notebook draws from (Python, NumPy, PyTorch) so that
# sampling and model initialisation are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize wandb
wandb.init(project="bilstm-summarization", name="bilstm-seq2seq")

# Load the CNN/DailyMail dataset
cnn_dataset = datasets.load_dataset("abisee/cnn_dailymail", "3.0.0")

# Convert each split to a pandas DataFrame for easier manipulation.
# Dataset.to_pandas() converts the backing Arrow table directly, whereas
# pd.DataFrame(split) has to iterate the split row by row as Python dicts.
train_df = cnn_dataset["train"].to_pandas()
val_df = cnn_dataset["validation"].to_pandas()
test_df = cnn_dataset["test"].to_pandas()

# Sample a smaller portion of the training data for faster processing
SAMPLE_FRACTION = 0.001  # fraction of the training split kept for experiments
sample_size = int(len(train_df) * SAMPLE_FRACTION)
train_sample = train_df.sample(n=sample_size, random_state=SEED)

print(f"Full training set size: {len(train_df)}")
print(f"Sample size: {len(train_sample)}")

# Log dataset info to wandb
wandb.config.update({
    "dataset": "CNN/DailyMail",
    "full_train_size": len(train_df),
    "sample_size": len(train_sample),
    "val_size": len(val_df),
    "test_size": len(test_df)
})

Full training set size: 287113
Sample size: 287


In [None]:
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, word_tokenize

# Fetch the NLTK data packages used by this notebook, in a fixed order
for nltk_resource in ('punkt', 'stopwords', 'punkt_tab'):
    nltk.download(nltk_resource)
def preprocess_text(text):
    """Normalise a piece of text to lowercase letters separated by single spaces.

    The text is lowercased, every character that is neither an ASCII letter
    nor whitespace is dropped (digits and punctuation disappear without
    leaving a space behind), and whitespace runs are collapsed before the
    result is stripped.
    """
    letters_only = re.sub(r'[^a-zA-Z\s]', '', text.lower())
    return re.sub(r'\s+', ' ', letters_only).strip()

def tokenize_document(document):
    """Split a document into sentences and return the cleaned, non-empty ones.

    Each sentence produced by ``sent_tokenize`` is passed through
    ``preprocess_text``; sentences that end up blank are discarded.
    """
    cleaned_sentences = (preprocess_text(raw) for raw in sent_tokenize(document))
    return [sentence for sentence in cleaned_sentences if sentence.strip()]

# Sentence-split and clean both text columns of the sampled data,
# articles first and then their highlights.
for raw_column, processed_column in (
    ('article', 'processed_article'),
    ('highlights', 'processed_highlights'),
):
    train_sample[processed_column] = train_sample[raw_column].apply(tokenize_document)

# Create labels for extractive summarization (1 if sentence is in highlights, 0 otherwise)
def create_extractive_labels(article_sentences, highlight_sentences, threshold=0.7):
    """Label each article sentence by whether it resembles any highlight.

    Args:
        article_sentences: iterable of article sentences.
        highlight_sentences: iterable of highlight sentences; it is traversed
            once per article sentence, so it must be re-iterable (e.g. a list).
        threshold: a sentence is labelled 1 when ``similarity_score`` against
            at least one highlight is strictly greater than this value.
            Defaults to 0.7, the cutoff that used to be hard-coded.

    Returns:
        A list of ints (1 or 0), one per article sentence, in input order.
        With no highlights every label is 0.
    """
    return [
        int(any(
            similarity_score(sentence, highlight) > threshold
            for highlight in highlight_sentences
        ))
        for sentence in article_sentences
    ]

def similarity_score(sent1, sent2):
    """Word-overlap similarity between two sentences.

    Both sentences are tokenised with ``word_tokenize`` and reduced to sets of
    distinct tokens. The score is the number of shared tokens divided by the
    size of the larger set. Returns 0 when either sentence has no tokens.
    """
    vocab_first, vocab_second = (set(word_tokenize(s)) for s in (sent1, sent2))

    if not (vocab_first and vocab_second):
        return 0

    shared = vocab_first & vocab_second
    return len(shared) / max(len(vocab_first), len(vocab_second))

# Pair every processed article with its processed highlights and label the
# article sentences against them.
train_sample['extractive_labels'] = list(map(
    create_extractive_labels,
    train_sample['processed_article'],
    train_sample['processed_highlights'],
))




[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [None]:
!pip install rouge-score




In [None]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, TrainingArguments, Trainer, EarlyStoppingCallback,TrainerCallback
import torch.nn as nn
import torch.nn.functional as F

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [None]:
# Define BiLSTM model architecture
class BiLSTMSeq2Seq(nn.Module):
    """Sequence-to-sequence model: bidirectional LSTM encoder, attentive LSTM decoder.

    Encoder and decoder share one embedding table. The decoder is a single
    unidirectional LSTM layer whose hidden size is ``2 * hidden_dim`` so it can
    be initialised from the concatenated forward/backward encoder states.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embedding_dim: size of the token embeddings.
        hidden_dim: hidden size of each encoder direction.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability applied to embeddings and, when the
            encoder has more than one layer, between encoder layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.num_layers = num_layers

        # Encoder (BiLSTM)
        self.encoder = nn.LSTM(embedding_dim, hidden_dim,
                               bidirectional=True, batch_first=True, num_layers=num_layers,
                               dropout=dropout if num_layers > 1 else 0)

        # Decoder (LSTM with attention). Its input is the token embedding
        # concatenated with the attention context. nn.LSTM only applies its
        # `dropout` argument between stacked layers, so on this single-layer
        # decoder it would do nothing except emit a warning; it is omitted.
        self.decoder = nn.LSTM(embedding_dim + hidden_dim * 2, hidden_dim * 2,
                               batch_first=True, num_layers=1)

        # Attention mechanism
        self.attention = nn.Linear(hidden_dim * 2 + hidden_dim * 2, hidden_dim * 2)
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)

        # Final projection layer
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

        # Dropout layers
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _bridge_state(state):
        """Join the last layer's forward/backward encoder state into (1, batch, 2*hidden_dim)."""
        return torch.cat((state[-2, :, :], state[-1, :, :]), dim=1).unsqueeze(0)

    def _attend(self, hidden, encoder_outputs):
        """Return the attention context (batch, 1, 2*hidden_dim) for the current decoder state."""
        src_len = encoder_outputs.size(1)
        # (1, batch, 2H) -> (batch, src_len, 2H): the same decoder state is
        # scored against every encoder position.
        query = hidden.transpose(0, 1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attention(torch.cat((query, encoder_outputs), dim=2)))
        weights = F.softmax(self.v(energy).squeeze(2), dim=1)
        return torch.bmm(weights.unsqueeze(1), encoder_outputs)

    def forward(self, src, trg=None, max_len=128, teacher_forcing_ratio=0.5):
        """Run the encoder and decode step by step.

        Args:
            src: (batch, src_len) tensor of source token ids.
            trg: optional (batch, trg_len) tensor of target token ids. Column 0
                is used as the first decoder input; later columns are used for
                teacher forcing. Every entry must be a valid embedding index.
            max_len: number of output positions when ``trg`` is None, and an
                upper bound on them when ``trg`` is given.
            teacher_forcing_ratio: per-step probability of feeding the
                reference token instead of the model's own argmax. It is only
                applied when ``trg`` is given.

        Returns:
            Logits of shape (batch, steps, vocab_size), where ``steps`` is
            ``max_len`` without a target and ``min(max_len, trg_len)`` with
            one. Position 0 is never predicted and stays all zeros.
        """
        batch_size = src.size(0)
        has_target = trg is not None

        # Encoder Forward Pass
        embedded = self.dropout(self.embedding(src))
        encoder_outputs, (hidden, cell) = self.encoder(embedded)

        # Prepare decoder initial states
        hidden = self._bridge_state(hidden)
        cell = self._bridge_state(cell)

        # Decoder Setup
        if has_target:
            # Never step past the end of the reference: indexing trg[:, t]
            # beyond its length would raise during teacher forcing.
            steps = min(max_len, trg.size(1))
        else:
            steps = max_len
            trg = torch.zeros((batch_size, max_len), dtype=torch.long, device=src.device)
            trg[:, 0] = 1  # Start with SOS token

        decoder_input = self.embedding(trg[:, 0].unsqueeze(1))
        outputs = torch.zeros(steps, batch_size, self.fc.out_features, device=src.device)

        # Decoding Loop
        for t in range(1, steps):
            context = self._attend(hidden, encoder_outputs)

            # Decoder Step
            decoder_output, (hidden, cell) = self.decoder(
                torch.cat((decoder_input, context), dim=2),
                (hidden, cell)
            )

            # Project to vocabulary space
            output = self.fc(decoder_output.squeeze(1))
            outputs[t] = output

            # Teacher forcing only makes sense with a real reference; the
            # placeholder target built above must never be fed back in.
            use_teacher_forcing = has_target and random.random() < teacher_forcing_ratio
            next_ids = trg[:, t] if use_teacher_forcing else output.argmax(1)
            decoder_input = self.dropout(self.embedding(next_ids.unsqueeze(1)))

        return outputs.permute(1, 0, 2)

    def generate(self, src, max_len=128, temperature=1.0, sos_token_id=1, eos_token_id=2):
        """Sample a token sequence for a single source sequence.

        Args:
            src: (1, src_len) tensor of source token ids. Only batch size 1 is
                supported: the decoder is seeded with one token and each
                sampled token is read with ``.item()``.
            max_len: maximum number of tokens to produce.
            temperature: divisor applied to the logits before sampling. A
                value <= 0 selects greedy (argmax) decoding instead of
                dividing by it.
            sos_token_id: id fed to the decoder at the first step.
            eos_token_id: id that terminates generation; it is not included
                in the result.

        Returns:
            A list of generated token ids (Python ints).
        """
        with torch.no_grad():
            # Encoder forward pass
            encoder_outputs, (hidden, cell) = self.encoder(self.embedding(src))

            # Prepare decoder initial states
            hidden = self._bridge_state(hidden)
            cell = self._bridge_state(cell)

            outputs = []
            decoder_input = torch.tensor([[sos_token_id]], device=src.device)

            for _ in range(max_len):
                decoder_emb = self.embedding(decoder_input)
                context = self._attend(hidden, encoder_outputs)

                # Decoder step
                decoder_output, (hidden, cell) = self.decoder(
                    torch.cat((decoder_emb, context), dim=2),
                    (hidden, cell)
                )

                # Output projection
                logits = self.fc(decoder_output.squeeze(1))
                if temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    probabilities = F.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probabilities, 1)

                if next_token.item() == eos_token_id:
                    break

                outputs.append(next_token.item())
                decoder_input = next_token

            return outputs

In [None]:
# Create a wrapper model compatible with HuggingFace Trainer
class BiLSTMWrapper(nn.Module):
    """Adapts a seq2seq model to the keyword interface HuggingFace Trainer calls.

    The wrapped model is invoked as ``model(src=..., trg=..., max_len=...)``
    and must return logits of shape (batch, steps, vocab).
    """

    # Label value excluded from the loss.
    IGNORE_INDEX = -100

    def __init__(self, base_model):
        super().__init__()
        self.model = base_model

    def forward(self, input_ids=None, labels=None, attention_mask=None, **kwargs):
        """Return ``{"loss", "logits"}`` when labels are given, else ``{"logits"}``.

        Args:
            input_ids: (batch, src_len) source token ids.
            labels: optional (batch, trg_len) target token ids; entries equal
                to -100 are excluded from the loss.
            attention_mask: accepted for Trainer compatibility; not used.
            **kwargs: any further batch fields; ignored.
        """
        if labels is None:
            # Inference mode
            return {"logits": self.model(src=input_ids)}

        # -100 is not a valid embedding index, so it must not reach the
        # decoder as a teacher-forcing input. Those positions are excluded
        # from the loss anyway; 0 is used as a stand-in id.
        decoder_targets = labels.masked_fill(labels == self.IGNORE_INDEX, 0)
        outputs = self.model(src=input_ids, trg=decoder_targets, max_len=labels.size(1))

        # Position 0 is the decoder's start token and is never predicted.
        steps = min(outputs.size(1), labels.size(1))
        step_logits = outputs[:, 1:steps, :]
        step_labels = labels[:, 1:steps]

        # One loss over all predicted positions. Without ignored labels this
        # equals the mean of per-position batch means; with them it avoids the
        # NaN a per-position mean gives when a whole column is ignored.
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        loss = loss_fct(
            step_logits.reshape(-1, step_logits.size(-1)),
            step_labels.reshape(-1),
        )

        return {"loss": loss, "logits": outputs}

In [None]:
# Hold out 10% of the sampled articles for validation.
# Note: this rebinds train_df / val_df, which until now held the full
# CNN/DailyMail train and validation splits.
train_df, val_df = train_test_split(train_sample, random_state=42, test_size=0.1)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

from datasets import Dataset, DatasetDict

# Wrap both pandas frames as Hugging Face datasets with a common schema:
# the article becomes "input_text" and the highlights become "target_text".
dataset = DatasetDict({
    split_name: Dataset.from_dict({
        "input_text": frame["article"].tolist(),
        "target_text": frame["highlights"].tolist(),
    })
    for split_name, frame in (("train", train_df), ("validation", val_df))
})

# Convenience handles for the two splits
train_dataset, val_dataset = dataset["train"], dataset["validation"]

# Define tokenization function
def tokenize_function(batch):
    """Tokenise a batch of examples for seq2seq training.

    Articles (``input_text``) are padded/truncated to 512 tokens and
    summaries (``target_text``) to 128 tokens. The summary token ids are
    attached to the article encoding under the ``labels`` key.
    """
    encoded = tokenizer(
        batch["input_text"], max_length=512, truncation=True, padding="max_length"
    )
    encoded["labels"] = tokenizer(
        batch["target_text"], max_length=128, truncation=True, padding="max_length"
    )["input_ids"]
    return encoded

# Tokenize datasets (batched, so tokenize_function receives lists of texts)
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_val = val_dataset.map(tokenize_function, batched=True)

# Initialize the BiLSTM model
embedding_dim = 256
hidden_dim = 512
num_layers = 2
dropout = 0.2
# The vocabulary size is taken from the tokenizer so every token id it can
# produce has an embedding row and an output logit.
base_model = BiLSTMSeq2Seq(len(tokenizer), embedding_dim, hidden_dim, num_layers=num_layers, dropout=dropout).to(device)
# Trainer is given the wrapper; base_model is kept for generation and saving.
model = BiLSTMWrapper(base_model)

# Log model hyperparameters to wandb
wandb.config.update({
    "model_type": "BiLSTM Seq2Seq with Attention",
    "embedding_dim": embedding_dim,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "dropout": dropout,
    "vocab_size": len(tokenizer)
})

# Setup optimizer
# NOTE(review): this optimizer is never passed to the Trainer constructed
# later in this cell (no `optimizers=` argument), so it does not take part in
# training. The lr=1e-4 here also differs from TrainingArguments'
# learning_rate=2e-5. Confirm which was intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Define custom data collator to handle the batch preparation
from transformers import DataCollatorWithPadding

class CustomDataCollator(DataCollatorWithPadding):
    """DataCollatorWithPadding restricted to tokenizer, padding and max_length.

    Collation itself is delegated unchanged to the parent class.
    """

    def __init__(self, tokenizer, padding=True, max_length=None):
        super().__init__(max_length=max_length, padding=padding, tokenizer=tokenizer)

    def __call__(self, features):
        # The batch is deliberately NOT moved to a device here; that is left
        # to the Trainer.
        return super().__call__(features)

# Define training arguments with wandb integration
training_args = TrainingArguments(
    output_dir="./biLSTMS_model",
    # Evaluate and checkpoint once per epoch; the two strategies match because
    # load_best_model_at_end is enabled below.
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=20,  # upper bound; early stopping is also registered as a callback
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",  # best checkpoint is selected on the evaluation loss
    report_to="wandb",  # Enable wandb reporting
    run_name="bilstm-seq2seq",
    dataloader_pin_memory=False,
)

# Custom callback to log example predictions
class LogPredictionCallback(TrainerCallback):
    """After each evaluation, log a few generated summaries to wandb.

    Args:
        model: object exposing ``generate(input_ids, max_len=...)`` that
            returns a list of token ids, and ``parameters()``.
        tokenizer: tokenizer used to decode inputs, references and predictions.
        eval_dataset: indexable dataset whose items provide ``input_ids`` and
            ``labels`` as lists of token ids.
        num_examples: maximum number of examples logged per evaluation.
    """

    def __init__(self, model, tokenizer, eval_dataset, num_examples=3):
        self.model = model
        self.tokenizer = tokenizer
        self.eval_dataset = eval_dataset
        self.num_examples = num_examples

    def on_evaluate(self, args, state, control,metrics=None, **kwargs):
        """Sample examples, generate a summary for each, and log them as HTML."""
        # Get a few examples from evaluation dataset
        sample_count = min(self.num_examples, len(self.eval_dataset))
        indices = random.sample(range(len(self.eval_dataset)), sample_count)

        # Use the device the model actually lives on instead of a notebook
        # global, so the callback does not depend on cell execution order.
        model_device = next(self.model.parameters()).device

        for i, index in enumerate(indices):
            example = self.eval_dataset[index]
            input_text = self.tokenizer.decode(example['input_ids'], skip_special_tokens=True)
            reference = self.tokenizer.decode(example['labels'], skip_special_tokens=True)

            # Generate summary
            input_ids = torch.tensor([example['input_ids']]).to(model_device)
            with torch.no_grad():
                prediction_ids = self.model.generate(input_ids, max_len=128)
                prediction = self.tokenizer.decode(prediction_ids, skip_special_tokens=True)

            # Log to wandb (the input is cut to its first 500 characters)
            wandb.log({
                f"example_{i}/input": wandb.Html(input_text[:500] + "..."),
                f"example_{i}/reference": wandb.Html(reference),
                f"example_{i}/prediction": wandb.Html(prediction)
            })

        return control

# Initialize the early stopping callback: stop once the monitored metric has
# failed to improve for 3 consecutive evaluations.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3)

# Define Trainer with callbacks.
# No `optimizers=` argument is given, so the Adam optimizer created earlier in
# this cell is not used; Trainer builds its own from training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=CustomDataCollator(tokenizer),
    callbacks=[
        early_stopping_callback,
        # The callback receives base_model (not the wrapper) because it calls generate()
        LogPredictionCallback(base_model, tokenizer, tokenized_val)
    ]
)

# Train the model
trainer.train()

# Save the model: the weights plus the constructor arguments needed to rebuild
# BiLSTMSeq2Seq. The dropout value is not stored.
torch.save({
    'model_state_dict': base_model.state_dict(),
    'vocab_size': len(tokenizer),
    'embedding_dim': embedding_dim,
    'hidden_dim': hidden_dim,
    'num_layers': num_layers
}, "biLSTMs_model.pth")

# Log model artifact to wandb
wandb.save("biLSTMs_model.pth")

Map:   0%|          | 0/258 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,7.1775,5.994132
2,4.6425,4.882666
3,4.357,4.398864
4,4.2404,4.239155
5,3.6068,4.255985
6,3.9504,4.153348
7,3.8364,4.150993
8,3.5062,4.147541
9,3.6517,4.160427
10,3.6678,4.148197


['/content/wandb/run-20250331_142643-ywnqerul/files/biLSTMs_model.pth']

In [None]:
class BiLSTMSummarizer:
    """Loads a saved BiLSTMSeq2Seq checkpoint and uses it to summarise text."""

    def __init__(self, model_path, tokenizer, device='cuda'):
        """Rebuild the model from ``model_path`` and put it in eval mode.

        The requested device is honoured only when CUDA is available;
        otherwise everything is placed on the CPU.
        """
        resolved_device = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(resolved_device)
        self.tokenizer = tokenizer

        # The checkpoint carries the weights together with the constructor
        # arguments that were saved alongside them.
        saved = torch.load(model_path, map_location=self.device)

        network = BiLSTMSeq2Seq(
            vocab_size=saved['vocab_size'],
            embedding_dim=saved['embedding_dim'],
            hidden_dim=saved['hidden_dim'],
            num_layers=saved.get('num_layers', 1)
        )
        self.model = network.to(self.device)

        # Restore the trained weights and switch to inference behaviour
        self.model.load_state_dict(saved['model_state_dict'])
        self.model.eval()

    def generate_summary(self, input_text, max_length=128):
        """Generate summary using BiLSTM model"""
        encoded = self.tokenizer(
            input_text,
            return_tensors='pt',
            max_length=512,
            truncation=True
        )
        source_ids = encoded.input_ids.to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(source_ids, max_len=max_length)

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def evaluate(self, test_df, text_col='article', target_col='highlights'):
        """Evaluate BiLSTM performance using ROUGE metrics"""
        from rouge_score import rouge_scorer

        # One (generated, reference) pair per row, in row order
        pairs = [
            (self.generate_summary(row[text_col]), row[target_col])
            for _, row in test_df.iterrows()
        ]
        generated_summaries = [generated for generated, _ in pairs]
        reference_summaries = [reference for _, reference in pairs]

        return self._calculate_rouge(generated_summaries, reference_summaries)

    def _calculate_rouge(self, generated, references):
        """Calculate ROUGE scores"""
        from rouge_score import rouge_scorer

        metric_names = ['rouge1', 'rouge2', 'rougeL']
        scorer = rouge_scorer.RougeScorer(metric_names, use_stemmer=True)

        # Collect the F-measure of every metric for every summary pair
        scores = {name: [] for name in metric_names}
        for gen, ref in zip(generated, references):
            pair_score = scorer.score(ref, gen)
            for name in metric_names:
                scores[name].append(pair_score[name].fmeasure)

        # Mean per metric; 0 when no pairs were scored
        averages = {}
        for metric, values in scores.items():
            averages[metric] = sum(values)/len(values) if values else 0
        return averages

In [2]:
# Install rouge package
# (relies on `subprocess` and `sys` imported in the first code cell)
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'rouge-score'])
from rouge_score import rouge_scorer

# Initialize BiLSTM Summarizer from the checkpoint written by the training cell
bilstm_summarizer = BiLSTMSummarizer(
    model_path="biLSTMs_model.pth",
    tokenizer=tokenizer
)

# Test on a sample article
sample_article = test_df.iloc[0]['article']
generated_summary = bilstm_summarizer.generate_summary(sample_article)
actual_summary = test_df.iloc[0]['highlights']

print("Generated Summary:")
print(generated_summary)
print("\nActual Summary:")
print(actual_summary)

# Evaluate on test set (only the first 10 test articles)
test_sample = test_df.head(10)
rouge_scores = bilstm_summarizer.evaluate(test_sample)

# Display results
print("\nBiLSTM ROUGE Scores:")
print(f"ROUGE-1: {rouge_scores['rouge1']:.4f}")
print(f"ROUGE-2: {rouge_scores['rouge2']:.4f}")
print(f"ROUGE-L: {rouge_scores['rougeL']:.4f}")

# Log final evaluation metrics to wandb
wandb.log({
    "final_rouge1": rouge_scores['rouge1'],
    "final_rouge2": rouge_scores['rouge2'],
    "final_rougeL": rouge_scores['rougeL']
})

# Create a table for the test examples
test_table = wandb.Table(columns=["Article", "Reference", "Generated"])

# Add a few examples to the table.
# NOTE(review): each summary is generated again here. Generation samples
# tokens with torch.multinomial, so these texts are not necessarily the ones
# that were scored by evaluate() above.
for i in range(min(5, len(test_sample))):
    article = test_sample.iloc[i]['article']
    reference = test_sample.iloc[i]['highlights']
    generated = bilstm_summarizer.generate_summary(article)
    # Only the first 300 characters of the article are stored in the table
    test_table.add_data(article[:300] + "...", reference, generated)

# Log the table
wandb.log({"test_examples": test_table})

# Finish the wandb run
wandb.finish()



Generated Summary:
 wish suchewitness capt women to goals. of railing inHarry northern Open toiggs . .The a- four Hughes's phenomena . .he impression
 Beat . length

Actual Summary:
Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

BiLSTM ROUGE Scores:
ROUGE-1: 17.12
ROUGE-2: 16.24
ROUGE-L: 23.03


Run history:

eval/loss	█▄▂▁▁▁▁▁▁▁▁
eval/runtime	██▇▁▇▃▇▁▇▃▅
eval/samples_per_second	▁▂▂█▂▆▂█▂▆▅
eval/steps_per_second	▁▁▂█▂▆▂█▂▆▄
final_rouge1	▁
final_rouge2	▁
final_rougeL	▁
train/epoch	▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step	▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇█████
train/grad_norm	▃▅▆▁▁▂▁▅▂▃▂▃▂▂▄▆▂▃▂▂▃▂▁▁█▂▅▃▂▂▁█▂▁▃▇▃▂▁▂
train/learning_rate	██▇▇▇▇▇▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁
train/loss	█▇▇▆▅▃▂▂▂▂▁▁▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁

Run summary:

eval/loss		4.15857
eval/runtime		4.2908
e