In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import BartTokenizer
from torchmetrics.text.rouge import ROUGEScore

# Enable cuDNN benchmarking for performance: cuDNN times candidate
# algorithms the first time it sees a given input configuration and reuses
# the fastest one afterwards. This pays off when tensor shapes stay the same
# across iterations; SummarizationDataset below pads every example to a
# fixed length (padding='max_length'), so training batch shapes are constant.
torch.backends.cudnn.benchmark = True

# --------------------- Data Preparation ---------------------
class SummarizationDataset(Dataset):
    """Map-style dataset that tokenizes article/summary pairs lazily.

    Every access tokenizes one record: the ``'article'`` text becomes the
    encoder input and the ``'highlights'`` text becomes the target. Both are
    truncated and padded to a fixed length, so every item of a given field
    has the same shape and batches can be built with ``torch.stack``.

    Args:
        dataset_split: Indexable collection of records with ``'article'``
            and ``'highlights'`` keys (e.g. a Hugging Face dataset split).
        tokenizer: Callable tokenizer returning an object with
            ``input_ids`` and ``attention_mask`` tensors.
        max_input_length: Fixed token length for the article.
        max_target_length: Fixed token length for the summary.
    """

    def __init__(self, dataset_split, tokenizer, max_input_length=256, max_target_length=128):
        self.dataset = dataset_split
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text, length):
        """Tokenize ``text`` to exactly ``length`` tokens as a (1, length) batch."""
        return self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        record = self.dataset[idx]
        article = self._encode(record['article'], self.max_input_length)
        summary = self._encode(record['highlights'], self.max_target_length)
        # The tokenizer returns a batch of one; index it away to get 1-D tensors.
        return {
            'input_ids': article.input_ids[0],
            'attention_mask': article.attention_mask[0],
            'target_ids': summary.input_ids[0],
        }

def collate_fn(batch):
    """Combine a list of dataset items into one batch dict.

    Each field ('input_ids', 'attention_mask', 'target_ids') is stacked along
    a new leading batch dimension; ``torch.stack`` requires every item's
    tensor for a field to have the same shape.
    """
    fields = ('input_ids', 'attention_mask', 'target_ids')
    return {name: torch.stack([example[name] for example in batch]) for name in fields}

# --------------------- Model Definition ---------------------
class Seq2SeqGRUModel(nn.Module):
    """GRU encoder-decoder sharing one token embedding between both sides.

    The encoder's final hidden state initializes the decoder, and decoder
    outputs are projected to vocabulary logits by a single linear layer.

    Args:
        vocab_size: Number of rows in the embedding / output logits width.
        embed_dim: Token embedding size.
        hidden_dim: GRU hidden size (encoder and decoder).
        num_layers: Number of stacked GRU layers (encoder and decoder).
        pad_idx: Token id whose embedding is kept fixed (``padding_idx``).
        dropout: Dropout probability applied to embeddings and decoder outputs.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx, dropout=0.3):
        super(Seq2SeqGRUModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)
        self.encoder = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.decoder = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, src, trg):
        """Teacher-forced decoding.

        Args:
            src: Source token ids, (batch, src_len).
            trg: Decoder input token ids, (batch, trg_len).

        Returns:
            Logits of shape (batch, trg_len, vocab_size).
        """
        # Encode the source; only the final hidden state is used.
        embedded_src = self.dropout(self.embedding(src))
        _, hidden = self.encoder(embedded_src)
        # Decode the whole target in one call (teacher forcing).
        embedded_trg = self.dropout(self.embedding(trg))
        decoder_outputs, _ = self.decoder(embedded_trg, hidden)
        return self.fc(self.dropout(decoder_outputs))

    def generate(self, src, sos_token, eos_token, max_len=128, beam_width=3):
        """Beam-search decode each row of ``src`` independently.

        Beams are ranked by summed token log-probability (no length
        normalization). Decoding for a row stops once all of its beams end in
        ``eos_token`` or after ``max_len`` steps.

        The module is switched to eval mode for decoding. Its previous
        train/eval mode is restored on return, so calling this from inside a
        training loop has no lasting side effect.

        Args:
            src: Source token ids, (batch, src_len).
            sos_token: Id fed to the decoder as the first token.
            eos_token: Id that terminates a beam.
            max_len: Maximum number of generated tokens per row.
            beam_width: Number of beams kept per row.

        Returns:
            A list with one 1-D tensor per source row. Each tensor starts with
            ``sos_token`` and has at most ``max_len + 1`` tokens.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                embedded_src = self.dropout(self.embedding(src))
                _, hidden = self.encoder(embedded_src)
                return [
                    self._beam_search_single(
                        hidden[:, i:i + 1, :], sos_token, eos_token,
                        max_len, beam_width, src.device,
                    )
                    for i in range(src.size(0))
                ]
        finally:
            # Restore the caller's mode even if decoding raised.
            self.train(was_training)

    def _beam_search_single(self, hidden, sos_token, eos_token, max_len, beam_width, device):
        """Run beam search for one example from its encoder hidden state slice."""
        # Each beam is (token sequence, summed log-prob, decoder hidden state).
        beams = [(torch.tensor([sos_token], device=device), 0.0, hidden)]
        for _ in range(max_len):
            candidates = []
            for seq, score, h in beams:
                # Finished beams compete unchanged against extended ones.
                if seq[-1].item() == eos_token:
                    candidates.append((seq, score, h))
                    continue
                embedded = self.dropout(self.embedding(seq[-1].view(1, 1)))
                output, h_new = self.decoder(embedded, h)
                log_probs = torch.log_softmax(self.fc(output.squeeze(1)), dim=-1)
                topk_log_probs, topk_indices = torch.topk(log_probs, beam_width)
                for k in range(beam_width):
                    new_seq = torch.cat([seq, topk_indices[0, k].unsqueeze(0)], dim=0)
                    candidates.append((new_seq, score + topk_log_probs[0, k].item(), h_new))
            # Stable sort: ties keep candidate order.
            beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
            if all(b[0][-1].item() == eos_token for b in beams):
                break
        # ``beams`` is kept sorted by score, so the first entry is the best.
        return beams[0][0]

# --------------------- FMAD Training Functions using autograd ---------------------
def train_model_fmad(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    epoch_start = time.time()
    num_batches = len(dataloader)

    for batch_idx, batch in enumerate(dataloader, 1):
        input_ids = batch['input_ids'].to(device)
        target_ids = batch['target_ids'].to(device)
        optimizer.zero_grad()

        # Use mixed precision for the forward pass and loss computation.
        with torch.cuda.amp.autocast():
            output = model(input_ids, target_ids[:, :-1])
            loss = criterion(output.transpose(1, 2), target_ids[:, 1:])
        total_loss += loss.item()

        # Compute gradients using autograd.grad to avoid reference cycles.
        # (create_graph=True is maintained to mimic FMAD style.)
        grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        for param, grad in zip(model.parameters(), grads):
            param.grad = grad
        optimizer.step()

        if batch_idx % 500 == 0 or batch_idx == num_batches:
            print(f"  Batch {batch_idx}/{num_batches}: Loss = {loss.item():.4f}")

    epoch_end = time.time()
    epoch_duration = epoch_end - epoch_start
    print(f"FMAD epoch completed in {epoch_duration:.2f} seconds.")
    avg_loss = total_loss / num_batches
    print("FMAD Training Loss for epoch:", avg_loss)
    return avg_loss

def evaluate_model(model, dataloader, tokenizer, device, sos_token, eos_token):
    """Beam-decode every batch and score the decoded text with ROUGE.

    Each source row is decoded with ``model.generate`` (beam width 3). Both
    the generated ids and the reference ``target_ids`` are converted back to
    text with special tokens stripped, and the full corpus is scored in one
    ``ROUGEScore`` call. Timing and scores are printed.

    Returns:
        The score dict returned by ``ROUGEScore``.
    """
    model.eval()
    rouge = ROUGEScore()
    predictions, references = [], []
    started = time.time()
    with torch.no_grad():
        for batch in dataloader:
            src = batch['input_ids'].to(device)
            gold = batch['target_ids'].to(device)
            generated = model.generate(src, sos_token, eos_token, beam_width=3)
            for hyp_ids, ref_ids in zip(generated, gold):
                predictions.append(tokenizer.decode(hyp_ids, skip_special_tokens=True))
                references.append(tokenizer.decode(ref_ids, skip_special_tokens=True))
    scores = rouge(predictions, references)
    elapsed = time.time() - started
    print(f"Evaluation completed in {elapsed:.2f} seconds.")
    print("Evaluation ROUGE scores:", scores)
    return scores

# --------------------- Setup and Training ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
pad_idx = tokenizer.pad_token_id
# Decoder start token: prefer the tokenizer's BOS id, fall back to CLS.
sos_token = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
eos_token = tokenizer.eos_token_id

# Hyperparameters
embed_dim = 256
hidden_dim = 512
num_layers = 1
dropout_rate = 0.3
batch_size = 32
num_epochs = 10
learning_rate = 0.003
# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Load 50% of each split of the CNN/DailyMail dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_dataset = dataset["train"].shuffle(seed=42).select(range(len(dataset["train"]) // 2))
val_dataset = dataset["validation"].shuffle(seed=42).select(range(len(dataset["validation"]) // 2))
test_dataset = dataset["test"].shuffle(seed=42).select(range(len(dataset["test"]) // 2))

train_data = SummarizationDataset(train_dataset, tokenizer)
val_data = SummarizationDataset(val_dataset, tokenizer)
test_data = SummarizationDataset(test_dataset, tokenizer)

# Optimize DataLoader: use more workers and pin memory.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)

# len(tokenizer) also counts tokens added on top of the base vocabulary, so
# every id the tokenizer can emit has an embedding row and an output logit.
vocab_size = len(tokenizer)
model = Seq2SeqGRUModel(vocab_size, embed_dim, hidden_dim, num_layers, pad_idx, dropout=dropout_rate).to(device)

# Build the optimizer once so AdamW's moment estimates and step count carry
# over between epochs. Re-creating it each epoch silently reset that state.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

print("Starting FMAD-based training...")
for epoch in range(num_epochs):
    print(f"\n=== FMAD Epoch {epoch + 1}/{num_epochs} ===")
    epoch_start_time = time.time()
    train_loss = train_model_fmad(model, train_loader, optimizer, criterion, device)
    total_epoch_time = time.time() - epoch_start_time
    minutes, seconds = divmod(total_epoch_time, 60)
    print(f"Epoch {epoch + 1} finished in {int(minutes)} minutes {seconds:.2f} seconds.")
    print(f"FMAD Training Loss: {train_loss:.4f}\n")



Starting FMAD-based training...

=== FMAD Epoch 1/10 ===


  with torch.cuda.amp.autocast():


  Batch 500/4487: Loss = 6.7012
  Batch 1000/4487: Loss = 6.7360
  Batch 1500/4487: Loss = 6.5546
  Batch 2000/4487: Loss = 6.8565
  Batch 2500/4487: Loss = 6.4651
  Batch 3000/4487: Loss = 6.2516
  Batch 3500/4487: Loss = 6.4660
  Batch 4000/4487: Loss = 6.2381
  Batch 4487/4487: Loss = 6.3543
FMAD epoch completed in 917.20 seconds.
FMAD Training Loss for epoch: 6.4670678153185746
Epoch 1 finished in 15 minutes 17.20 seconds.
FMAD Training Loss: 6.4671


=== FMAD Epoch 2/10 ===
  Batch 500/4487: Loss = 6.1095
  Batch 1000/4487: Loss = 6.2811
  Batch 1500/4487: Loss = 6.1863
  Batch 2000/4487: Loss = 6.0714
  Batch 2500/4487: Loss = 5.9475
  Batch 3000/4487: Loss = 5.9162
  Batch 3500/4487: Loss = 6.1775
  Batch 4000/4487: Loss = 5.9865
  Batch 4487/4487: Loss = 6.6216
FMAD epoch completed in 917.75 seconds.
FMAD Training Loss for epoch: 6.148205031949055
Epoch 2 finished in 15 minutes 17.75 seconds.
FMAD Training Loss: 6.1482


=== FMAD Epoch 3/10 ===
  Batch 500/4487: Loss = 5.9321
 

In [None]:
print("Evaluating on validation data...")
# Score the validation split, as the message says. Keep test_loader for a
# final, one-time report so test results do not influence model choices.
evaluate_model(model, val_loader, tokenizer, device, sos_token, eos_token)


Evaluating on validation data...
Evaluation completed in 490.17 seconds.
Evaluation ROUGE scores: {'rouge1_fmeasure': tensor(0.1198), 'rouge1_precision': tensor(0.2075), 'rouge1_recall': tensor(0.0891), 'rouge2_fmeasure': tensor(0.0127), 'rouge2_precision': tensor(0.0228), 'rouge2_recall': tensor(0.0093), 'rougeL_fmeasure': tensor(0.0966), 'rougeL_precision': tensor(0.1668), 'rougeL_recall': tensor(0.0719), 'rougeLsum_fmeasure': tensor(0.1088), 'rougeLsum_precision': tensor(0.1885), 'rougeLsum_recall': tensor(0.0808)}


{'rouge1_fmeasure': tensor(0.1198),
 'rouge1_precision': tensor(0.2075),
 'rouge1_recall': tensor(0.0891),
 'rouge2_fmeasure': tensor(0.0127),
 'rouge2_precision': tensor(0.0228),
 'rouge2_recall': tensor(0.0093),
 'rougeL_fmeasure': tensor(0.0966),
 'rougeL_precision': tensor(0.1668),
 'rougeL_recall': tensor(0.0719),
 'rougeLsum_fmeasure': tensor(0.1088),
 'rougeLsum_precision': tensor(0.1885),
 'rougeLsum_recall': tensor(0.0808)}