In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from google.colab import files
uploaded=files.upload()
print(uploaded)

In [None]:
!mv /content/WikiSumDataset.jsonl /content/drive/MyDrive/WikiSumDataset.jsonl

In [None]:
!pip install -U "transformers[torch]" "datasets" "evaluate" "nltk"

In [None]:
!pip install rouge_score

In [6]:
import zipfile
import os

zip_file_path = '/content/WikiSumDataset.jsonl.zip'
extract_path = '/content/'

# Unzip the file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print(f"Extracted {zip_file_path} to {extract_path}")

# Verify the file exists after extraction
extracted_file_path = os.path.join(extract_path, 'WikiSumDataset.jsonl')
if os.path.exists(extracted_file_path):
    print(f"'{extracted_file_path}' found after extraction.")
else:
    print(f"Error: '{extracted_file_path}' not found after extraction.")

Extracted /content/WikiSumDataset.jsonl.zip to /content/
'/content/WikiSumDataset.jsonl' found after extraction.


In [12]:
import warnings
warnings.filterwarnings('ignore')

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartModel,
    BartConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    BartPretrainedModel
)
from transformers.modeling_outputs import BaseModelOutput
import evaluate
import time
import json # Import the json library


print("All libraries imported successfully.")

# ==============================================================================
# 2. DATA LOADING & PREPARATION
# ==============================================================================
def load_jsonl(file_path, nrows=None):
    """Loads a JSONL file into a Pandas DataFrame."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            if nrows is not None and idx >= nrows:
                break
            data.append(json.loads(line))
    return pd.DataFrame(data)

NUM_SAMPLES = 2000
file_path = '/content/WikiSumDataset.jsonl'  # Corrected path to the extracted file
df = load_jsonl(file_path, nrows=NUM_SAMPLES)
print(f"Loaded {len(df)} samples from the WikiSum dataset.")

# 3. PARALLEL HIERARCHICAL TRANSFORMER (PHT) MODEL DEFINITION
# ==============================================================================
class HierarchicalEncoder(nn.Module):
    def __init__(self, config: BartConfig, segment_size=128):
        super().__init__()
        self.config = config
        self.segment_size = segment_size

        self.word_level_encoder = BartModel(config).encoder

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            activation='relu',
            batch_first=True
        )
        self.paragraph_level_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    @property
    def embed_tokens(self):
        return self.word_level_encoder.embed_tokens

    @embed_tokens.setter
    def embed_tokens(self, value):
        self.word_level_encoder.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        word_level_outputs = self.word_level_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        word_embeddings = word_level_outputs.last_hidden_state
        batch_size, seq_len, hidden_size = word_embeddings.shape

        effective_seq_len = (seq_len // self.segment_size) * self.segment_size

        if effective_seq_len == 0:
            return BaseModelOutput(
                last_hidden_state=word_embeddings,
                hidden_states=word_level_outputs.hidden_states,
                attentions=word_level_outputs.attentions,
            )

        truncated_embeddings = word_embeddings[:, :effective_seq_len, :]
        num_segments = effective_seq_len // self.segment_size

        segmented_embeddings = truncated_embeddings.reshape(
            batch_size, num_segments, self.segment_size, hidden_size
        )
        paragraph_representations = segmented_embeddings.mean(dim=2)

        paragraph_level_outputs = self.paragraph_level_encoder(paragraph_representations)

        upsampled_paragraph_outputs = paragraph_level_outputs.repeat_interleave(
            self.segment_size, dim=1
        )

        combined_embeddings = word_embeddings.clone()
        combined_embeddings[:, :effective_seq_len, :] += upsampled_paragraph_outputs

        return BaseModelOutput(
            last_hidden_state=combined_embeddings,
            hidden_states=word_level_outputs.hidden_states,
            attentions=word_level_outputs.attentions,
        )

class PHTModel(BartForConditionalGeneration):
    def __init__(self, config, segment_size=128):
        super().__init__(config)
        self.model.encoder = HierarchicalEncoder(config, segment_size=segment_size)
        self.tie_weights()

def create_pht_model(model_name="facebook/bart-base", segment_size=128):
    base_model = BartForConditionalGeneration.from_pretrained(model_name)
    config = base_model.config
    pht_model = PHTModel(config, segment_size=segment_size)

    pht_model.model.encoder.word_level_encoder.load_state_dict(base_model.model.encoder.state_dict())
    pht_model.model.decoder.load_state_dict(base_model.model.decoder.state_dict())
    pht_model.lm_head.load_state_dict(base_model.lm_head.state_dict())

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("Parallel Hierarchical Transformer (PHT) model created successfully.")
    print(f"✓ Hierarchical encoder with segment_size={segment_size}")
    print("✓ Pretrained weights loaded and properly tied")

    return pht_model, tokenizer

# ==============================================================================
# 4. DATASET AND DATALOADER
# ==============================================================================
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128
MODEL_NAME = "facebook/bart-base"

model, tokenizer = create_pht_model(model_name=MODEL_NAME, segment_size=128)

class WikiSumDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.inputs = df['article'].tolist()
        self.targets = df['summary'].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        article = str(self.inputs[idx])
        summary = str(self.targets[idx])
        model_inputs = self.tokenizer(
            article,
            max_length=MAX_INPUT_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        labels = self.tokenizer(
            summary,
            max_length=MAX_TARGET_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        labels_ids = labels["input_ids"].clone()
        labels_ids[labels_ids == self.tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels_ids
        return {k: v.squeeze(0) for k, v in model_inputs.items()}

train_df = df.sample(frac=0.9, random_state=42)
val_df = df.drop(train_df.index)
train_dataset = WikiSumDataset(train_df, tokenizer)
val_dataset = WikiSumDataset(val_df, tokenizer)

print(f"Created {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")

# ==============================================================================
# 5. TRAINING SETUP AND EXECUTION
# ==============================================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model moved to {device}.")

training_args = Seq2SeqTrainingArguments(
    output_dir='./results_pht',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir='./logs_pht',
    logging_steps=50,
    save_steps=500,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    generation_num_beams=4,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

print("Starting training with the PHT model...")
print("🔄 Training PHT on WikiSum dataset...")
train_result = trainer.train()
print("✅ PHT training finished!")

trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

print("\n******** PHT Training Results ********")
print(f"Final training loss: {metrics.get('train_loss', 'N/A'):.4f}")
print(f"Training runtime: {metrics.get('train_runtime', 'N/A'):.1f} seconds")
print(f"Samples per second: {metrics.get('train_samples_per_second', 'N/A'):.3f}")

# ==============================================================================
# 6. PHT INFERENCE AND ROUGE EVALUATION
# ==============================================================================
print("\n🔍 Starting PHT model evaluation...")
rouge = evaluate.load('rouge')
N_EVAL_SAMPLES = min(100, len(val_df))

eval_articles = val_df['article'].tolist()[:N_EVAL_SAMPLES]
gold_summaries = val_df['summary'].tolist()[:N_EVAL_SAMPLES]
pred_summaries = []

model.eval()
start_time = time.time()

print(f"Generating summaries for {N_EVAL_SAMPLES} articles with PHT model...")

for i, article in enumerate(eval_articles):
    inputs = tokenizer(
        article,
        padding="max_length",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
        return_tensors="pt"
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            num_beams=4,
            max_length=MAX_TARGET_LENGTH + 2,
            early_stopping=True,
            no_repeat_ngram_size=3,
            do_sample=False
        )

    pred_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    pred_summaries.append(pred_summary)

    elapsed = time.time() - start_time
    avg_time = elapsed / (i + 1)
    if (i + 1) % 10 == 0 or i == N_EVAL_SAMPLES - 1:
        eta = avg_time * (N_EVAL_SAMPLES - (i + 1))
        print(f"Progress: {i+1}/{N_EVAL_SAMPLES} | Avg: {avg_time:.2f}s/sample | ETA: {eta/60:.1f} min")

results = rouge.compute(
    predictions=pred_summaries,
    references=gold_summaries,
    use_stemmer=True
)
end_time = time.time()
total_time = (end_time - start_time) / 60

print("\n" + "="*60)
print("🎯 PHT MODEL ROUGE EVALUATION RESULTS")
print("="*60)
for key, value in results.items():
    print(f"{key.upper():<12}: {value:.4f}")
print(f"\nEvaluation completed in {total_time:.2f} minutes")
print(f"Average time per sample: {(total_time*60)/N_EVAL_SAMPLES:.2f} seconds")

print("\n" + "="*60)
print("📄 SAMPLE PHT GENERATIONS")
print("="*60)

for i in range(min(5, N_EVAL_SAMPLES)):
    print(f"\n--- PHT SAMPLE {i+1} ---")
    print(f"📖 SOURCE ARTICLE (truncated):")
    print(f"{eval_articles[i][:350]}[...]")
    print(f"\n🎯 REFERENCE SUMMARY:")
    print(f"{gold_summaries[i]}")
    print(f"\n🤖 PHT MODEL SUMMARY:")
    print(f"{pred_summaries[i]}")
    print("-" * 50)

print("\n" + "="*60)
print("✅ PHT IMPLEMENTATION SUCCESSFULLY COMPLETED!")
print("="*60)
print("🔬 Key PHT Features Implemented:")
print("  ✓ Hierarchical word and paragraph-level encoding")
print("  ✓ Segment-based parallel processing (segment_size=128)")
print("  ✓ Combined representations via residual connections")
print("  ✓ Standard BART decoder for high-quality generation")
print("  ✓ Trained and evaluated on WikiSum dataset")
print("\n📊 Results ready for comparison with PHT research paper!")

All libraries imported successfully.
Loaded 2000 samples from the WikiSum dataset.
Parallel Hierarchical Transformer (PHT) model created successfully.
✓ Hierarchical encoder with segment_size=128
✓ Pretrained weights loaded and properly tied
Created 1800 training samples and 200 validation samples.
Model moved to cuda.
Starting training with the PHT model...
🔄 Training PHT on WikiSum dataset...


Step,Training Loss
50,7.6495
100,6.656
150,6.5164
200,6.5048
250,6.4445
300,6.3476
350,6.3227
400,6.2858
450,6.3041
500,6.1901


✅ PHT training finished!
***** train metrics *****
  epoch                    =        1.0
  total_flos               =   567867GF
  train_loss               =     6.3419
  train_runtime            = 0:04:10.24
  train_samples_per_second =      7.193
  train_steps_per_second   =      3.597

******** PHT Training Results ********
Final training loss: 6.3419
Training runtime: 250.2 seconds
Samples per second: 7.193

🔍 Starting PHT model evaluation...
Generating summaries for 100 articles with PHT model...
Progress: 10/100 | Avg: 1.39s/sample | ETA: 2.1 min
Progress: 20/100 | Avg: 1.57s/sample | ETA: 2.1 min
Progress: 30/100 | Avg: 1.51s/sample | ETA: 1.8 min
Progress: 40/100 | Avg: 1.46s/sample | ETA: 1.5 min
Progress: 50/100 | Avg: 1.41s/sample | ETA: 1.2 min
Progress: 60/100 | Avg: 1.37s/sample | ETA: 0.9 min
Progress: 70/100 | Avg: 1.37s/sample | ETA: 0.7 min
Progress: 80/100 | Avg: 1.34s/sample | ETA: 0.4 min
Progress: 90/100 | Avg: 1.31s/sample | ETA: 0.2 min
Progress: 100/100 | Avg

In [13]:
print("\n" + "="*60)
print("🎯 PHT MODEL ROUGE EVALUATION RESULTS")
print("="*60)
for key, value in results.items():
    print(f"{key.upper():<12}: {value:.4f}")
print("="*60)


🎯 PHT MODEL ROUGE EVALUATION RESULTS
ROUGE1      : 0.2002
ROUGE2      : 0.0143
ROUGEL      : 0.1422
ROUGELSUM   : 0.1421
