In [1]:
!pip install py7zr evaluate rouge_score

Collecting py7zr
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyzstd>=0.15.9 (from py7zr)
  Downloading pyzstd-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting pyppmd<1.2.0,>=1.1.0 (from py7zr)
  Downloading pyppmd-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting pybcj<1.1.0,>=1.0.0 (from py7zr)
  Downloading pybcj-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting multivolumefile>=0.2.3 (from py7zr)
  Downloading multivolumefile-0.2.3-py3-none-any.whl.metadata (6.3 kB)
Collecting inflate64<1.1.0,>=1.0.0 (from py7zr)
  Downloading inflate64-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Co

In [2]:
import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

## Data Loading

In [3]:
dataset = load_dataset("allenai/scitldr", trust_remote_code=True)
dataset

README.md:   0%|          | 0.00/8.81k [00:00<?, ?B/s]

scitldr.py:   0%|          | 0.00/7.21k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/7.56k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.16M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.20M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1992 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/618 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/619 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['source', 'source_labels', 'rouge_scores', 'paper_id', 'target'],
        num_rows: 1992
    })
    test: Dataset({
        features: ['source', 'source_labels', 'rouge_scores', 'paper_id', 'target'],
        num_rows: 618
    })
    validation: Dataset({
        features: ['source', 'source_labels', 'rouge_scores', 'paper_id', 'target'],
        num_rows: 619
    })
})

In [4]:
dataset['train'][0]

{'source': ['Due to the success of deep learning to solving a variety of challenging machine learning tasks, there is a rising interest in understanding loss functions for training neural networks from a theoretical aspect.',
  'Particularly, the properties of critical points and the landscape around them are of importance to determine the convergence performance of optimization algorithms.',
  'In this paper, we provide a necessary and sufficient characterization of the analytical forms for the critical points (as well as global minimizers) of the square loss functions for linear neural networks.',
  'We show that the analytical forms of the critical points characterize the values of the corresponding loss functions as well as the necessary and sufficient conditions to achieve global minimum.',
  'Furthermore, we exploit the analytical forms of the critical points to characterize the landscape properties for the loss functions of linear neural networks and shallow ReLU networks.',
  '

## Model Loading

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [6]:
model_name = "google/pegasus-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

tokenizer_config.json:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.09k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-large and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/260 [00:00<?, ?B/s]

In [7]:
for param in model.parameters():
    param.requires_grad = False

for layer in model.model.encoder.layers[-2:]:
    for param in layer.parameters():
        param.requires_grad = True
        
for layer in model.model.decoder.layers[-2:]:
    for param in layer.parameters():
        param.requires_grad = True

In [8]:
# for name, param in model.named_parameters():
#     print(f"{name}: {'Trainable' if param.requires_grad else 'Frozen'}")

## Data Preprocessing

In [9]:
class ScientificDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, tokenizer, max_source_length=512, max_target_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        source_text = ' '.join(item['source'])  # Join source sentences
        target_text = item['target'][0]  # Take first target summary

        # Tokenize source
        source_encoding = self.tokenizer(
            source_text,
            max_length=self.max_source_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize target
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': source_encoding['input_ids'].squeeze(),
            'attention_mask': source_encoding['attention_mask'].squeeze(),
            'labels': target_encoding['input_ids'].squeeze()
        }

In [10]:
train_dataset = ScientificDataset(dataset['train'], tokenizer)
val_dataset = ScientificDataset(dataset['validation'], tokenizer)
test_dataset = ScientificDataset(dataset['test'], tokenizer)

In [11]:
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [12]:
metric = evaluate.load("rouge")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [13]:
def compute_metrics(predictions, labels):
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return result

In [14]:
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)
num_epochs = 5
num_training_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_training_steps // 10,
    num_training_steps=num_training_steps
)

## Model Training

In [15]:
# Faster Loop (without generating predictions)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    epoch_loss = 0
    model.train()
    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Average training loss: {avg_loss}")

    model.eval()
    val_loss = 0
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Average validation loss: {avg_val_loss}")

Epoch 1/5


100%|██████████| 498/498 [03:29<00:00,  2.38it/s]


Average training loss: 12.355617685969095


100%|██████████| 155/155 [00:39<00:00,  3.93it/s]


Average validation loss: 13.080345301474294
Epoch 2/5


100%|██████████| 498/498 [03:28<00:00,  2.38it/s]


Average training loss: 12.368804357138025


100%|██████████| 155/155 [00:39<00:00,  3.93it/s]


Average validation loss: 13.080345301474294
Epoch 3/5


100%|██████████| 498/498 [03:28<00:00,  2.39it/s]


Average training loss: 12.38977995072024


100%|██████████| 155/155 [00:39<00:00,  3.93it/s]


Average validation loss: 13.080345301474294
Epoch 4/5


100%|██████████| 498/498 [03:28<00:00,  2.38it/s]


Average training loss: 12.371697827994105


100%|██████████| 155/155 [00:39<00:00,  3.93it/s]


Average validation loss: 13.080345301474294
Epoch 5/5


100%|██████████| 498/498 [03:28<00:00,  2.38it/s]


Average training loss: 12.355947113420111


100%|██████████| 155/155 [00:39<00:00,  3.93it/s]

Average validation loss: 13.080345301474294





In [16]:
# # Slower Loop (generate predictions and calculate the rouge metric)

# calc_interval = 4  # Generate predictions every 4 batches (you can increase it to speed the loop up, but it will need more memory)
# print_interval = 128  # Print ROUGE metrics every 128 batches

# for epoch in range(num_epochs):
#     print(f"Epoch {epoch + 1}/{num_epochs}")
#     epoch_loss = 0
#     model.train()

#     # Training: Initialize storage for accumulated inputs
#     accumulated_input_ids = []
#     accumulated_attention_masks = []
#     accumulated_labels = []
#     all_train_preds = []
#     all_train_labels = []

#     for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         labels = batch["labels"].to(device)

#         optimizer.zero_grad()
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#         # Accumulate inputs and labels
#         accumulated_input_ids.append(input_ids)
#         accumulated_attention_masks.append(attention_mask)
#         accumulated_labels.append(labels)

#         # Generate predictions every `calc_interval` batches
#         if (batch_idx + 1) % calc_interval == 0:
#             with torch.no_grad():
#                 input_ids = torch.cat(accumulated_input_ids, dim=0)
#                 attention_mask = torch.cat(accumulated_attention_masks, dim=0)
#                 labels = torch.cat(accumulated_labels, dim=0)

#                 generated_ids = model.generate(
#                     input_ids=input_ids,
#                     attention_mask=attention_mask,
#                     max_length=128,
#                     num_beams=4,
#                     early_stopping=True
#                 )
#                 decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
#                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

#                 all_train_preds.extend(decoded_preds)
#                 all_train_labels.extend(decoded_labels)

#             # Reset accumulated inputs
#             accumulated_input_ids = []
#             accumulated_attention_masks = []
#             accumulated_labels = []

#         # Print metrics every `print_interval` batches
#         if (batch_idx + 1) % print_interval == 0:
#             train_metrics = metric.compute(predictions=all_train_preds, references=all_train_labels, use_stemmer=True)
#             print(f"Train ROUGE after {batch_idx + 1} batches: {train_metrics}")

#     # Final metrics for remaining training batches
#     if accumulated_input_ids:
#         with torch.no_grad():
#             input_ids = torch.cat(accumulated_input_ids, dim=0)
#             attention_mask = torch.cat(accumulated_attention_masks, dim=0)
#             labels = torch.cat(accumulated_labels, dim=0)

#             generated_ids = model.generate(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 max_length=128,
#                 num_beams=4,
#                 early_stopping=True
#             )
#             decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
#             decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

#             all_train_preds.extend(decoded_preds)
#             all_train_labels.extend(decoded_labels)

#             train_metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#             print(f"Final Train ROUGE for remaining batches: {train_metrics}")

#     avg_loss = epoch_loss / len(train_loader)
#     print(f"Average training loss for epoch {epoch + 1}: {avg_loss}")

#     # Final epoch ROUGE for training
#     epoch_train_metrics = metric.compute(predictions=all_train_preds, references=all_train_labels, use_stemmer=True)
#     print(f"Train ROUGE scores for epoch {epoch + 1}: {epoch_train_metrics}")

#     # Validation
#     model.eval()
#     val_loss = 0

#     # Validation: Initialize storage for accumulated inputs
#     accumulated_input_ids = []
#     accumulated_attention_masks = []
#     accumulated_labels = []
#     all_val_preds = []
#     all_val_labels = []

#     with torch.no_grad():
#         for batch_idx, batch in enumerate(tqdm(val_loader, desc="Validating")):
#             input_ids = batch["input_ids"].to(device)
#             attention_mask = batch["attention_mask"].to(device)
#             labels = batch["labels"].to(device)

#             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#             loss = outputs.loss
#             val_loss += loss.item()

#             # Accumulate inputs and labels
#             accumulated_input_ids.append(input_ids)
#             accumulated_attention_masks.append(attention_mask)
#             accumulated_labels.append(labels)

#             # Generate predictions every `calc_interval` batches
#             if (batch_idx + 1) % calc_interval == 0:
#                 input_ids = torch.cat(accumulated_input_ids, dim=0)
#                 attention_mask = torch.cat(accumulated_attention_masks, dim=0)
#                 labels = torch.cat(accumulated_labels, dim=0)

#                 generated_ids = model.generate(
#                     input_ids=input_ids,
#                     attention_mask=attention_mask,
#                     max_length=128,
#                     num_beams=4,
#                     early_stopping=True
#                 )
#                 decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
#                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

#                 all_val_preds.extend(decoded_preds)
#                 all_val_labels.extend(decoded_labels)

#                 # Reset accumulated inputs
#                 accumulated_input_ids = []
#                 accumulated_attention_masks = []
#                 accumulated_labels = []

#             # Print metrics every `print_interval` batches
#             if (batch_idx + 1) % print_interval == 0:
#                 val_metrics = metric.compute(predictions=all_val_preds, references=all_val_labels, use_stemmer=True)
#                 print(f"Validation ROUGE after {batch_idx + 1} batches: {val_metrics}")

#         # Final metrics for remaining validation batches
#         if accumulated_input_ids:
#             input_ids = torch.cat(accumulated_input_ids, dim=0)
#             attention_mask = torch.cat(accumulated_attention_masks, dim=0)
#             labels = torch.cat(accumulated_labels, dim=0)

#             generated_ids = model.generate(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 max_length=128,
#                 num_beams=4,
#                 early_stopping=True
#             )
#             decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
#             decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

#             all_val_preds.extend(decoded_preds)
#             all_val_labels.extend(decoded_labels)

#             val_metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#             print(f"Final Validation ROUGE for remaining batches: {val_metrics}")

#     avg_val_loss = val_loss / len(val_loader)
#     print(f"Average validation loss for epoch {epoch + 1}: {avg_val_loss}")

#     # Final epoch ROUGE for validation
#     epoch_val_metrics = metric.compute(predictions=all_val_preds, references=all_val_labels, use_stemmer=True)
#     print(f"Validation ROUGE scores for epoch {epoch + 1}: {epoch_val_metrics}")

## Model Evaluation

In [17]:
model.eval()

all_test_preds = []
all_test_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating on Test Data"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )

        all_test_preds.extend(generated_ids.cpu().numpy())
        all_test_labels.extend(labels.cpu().numpy())

test_metrics = compute_metrics(all_test_preds, all_test_labels)
print("Test ROUGE scores:", test_metrics)

Evaluating on Test Data: 100%|██████████| 155/155 [04:12<00:00,  1.63s/it]


Test ROUGE scores: {'rouge1': 0.2818282501874282, 'rouge2': 0.09582032387882065, 'rougeL': 0.2114834523666631, 'rougeLsum': 0.21120538514214185}
