In [1]:
!pip install -q peft

In [4]:
import os
import pickle

import pandas as pd
from datasets import Dataset, load_dataset
from transformers import (
    LongT5ForConditionalGeneration,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
import torch
import numpy as np

In [30]:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('pszemraj/long-t5-tglobal-base-16384-book-summary')
base_model = LongT5ForConditionalGeneration.from_pretrained('pszemraj/long-t5-tglobal-base-16384-book-summary')

# freeze the model
for param in base_model.parameters():
    param.requires_grad = False

# use PEFT

# Load the config
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

ms2_dataset = load_dataset("allenai/mslr2022", "ms2", split="train")

# Load your CSV file
# df = pd.read_csv('../experiment_1/biobert_extractive_only_training_dataset.csv')

# ---- not available yet. in the meantime:
all_extracted_summaries = []
for fpath in os.listdir('../experiment_1/biobert_extractive_only_training_dataset'):
    all_extracted_summaries.append(
        pickle.load(open(os.path.join('../experiment_1/biobert_extractive_only_training_dataset', fpath), 'rb'))
    )
df = pd.DataFrame(all_extracted_summaries)
# ----

input_texts = df['summary'].tolist()

# target texts come from ms2 dataset. match on df's review_id for order
target_texts = [
    ms2_dataset[ms2_dataset['review_id'].index(str(i))]['target'] for i in df["review_id"]
]

# Tokenize data
def tokenize_function(examples):
    model_inputs = tokenizer(examples['input_text'], padding='max_length', truncation=True, max_length=512)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['target_text'], padding='max_length', truncation=True, max_length=128)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

dataset = Dataset.from_dict({'input_text': input_texts, 'target_text': target_texts})
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Split the dataset
shuffle_dataset = tokenized_datasets.shuffle(seed=42)
train_dataset = shuffle_dataset.select(range(len(tokenized_datasets) * 8 // 10))
val_dataset = shuffle_dataset.select(range(len(tokenized_datasets) * 8 // 10, len(tokenized_datasets)))

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=2,  # Adjust batch size according to your memory constraints
    evaluation_strategy="steps",
    save_steps=10_000,
    eval_steps=10_000,
    logging_dir='./logs',
    logging_steps=500,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainable params: 884,736 || all params: 248,472,192 || trainable%: 0.3560704289999583


  table = cls._concat_blocks(blocks, axis=0)


Map:   0%|          | 0/130 [00:00<?, ? examples/s]



In [31]:
# Train the model
trainer.train()

  0%|          | 0/156 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'train_runtime': 209.8985, 'train_samples_per_second': 1.486, 'train_steps_per_second': 0.743, 'train_loss': 13.549718612279648, 'epoch': 3.0}


TrainOutput(global_step=156, training_loss=13.549718612279648, metrics={'train_runtime': 209.8985, 'train_samples_per_second': 1.486, 'train_steps_per_second': 0.743, 'train_loss': 13.549718612279648, 'epoch': 3.0})

In [33]:
# view results
trainer.evaluate()



  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 11.186555862426758,
 'eval_runtime': 21.3928,
 'eval_samples_per_second': 1.215,
 'eval_steps_per_second': 0.187,
 'epoch': 3.0}