In [1]:
# !pip install transformers[torch]
# !pip install pandas
# !pip install datasets 

In [1]:
#imports
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import Trainer, TrainingArguments
import pandas as pd
from tqdm.notebook import tqdm

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
from datasets import load_dataset

dataset = load_dataset("tomasg25/scientific_lay_summarisation", "elife", trust_remote_code=True)

In [3]:
from elife_dataset.preprocessing import clean_dataset
dataset = clean_dataset(dataset)

In [5]:
class ElifeDataset(Dataset):
    def __init__(self, data, tokenizer, prompt_template: str):
        self.data = data
        self.tokenizer = tokenizer
        self.promp_template = prompt_template

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        article = self.promp_template.format(article=item['article'])
        summary = item['summary']
        inputs = self.tokenizer(article, return_tensors='pt', max_length=512, truncation=True)
        targets = self.tokenizer(summary, return_tensors='pt', max_length=150, truncation=True)
        return {
            'input_ids': inputs.input_ids.flatten(),
            'attention_mask': inputs.attention_mask.flatten(),
            'labels': targets.input_ids.flatten()
        }

In [6]:
model_name = "facebook/bart-base"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)
prompt_template = "Summarize the following article in simple terms: {article}"

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)



BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [7]:
#training_dataset = DailyMailDataset(validation_data, tokenizer)
train = ElifeDataset(dataset['train'], tokenizer, prompt_template)

In [8]:
#validation_dataset = DailyMailDataset(train_data, tokenizer)
validation = ElifeDataset(dataset['validation'], tokenizer, prompt_template)

In [9]:
from transformers import DataCollatorForSeq2Seq

# Create a data collator specifically for sequence-to-sequence models
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,  # This ensures dynamic padding within each batch
    label_pad_token_id=-100  # Ensures labels are padded with -100 for loss calculation
)

training_args = TrainingArguments(
    output_dir='./elife_bart_model_v2/training',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    fp16=True,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=validation,
    data_collator=data_collator
)


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [10]:
trainer.train()

 17%|█▋        | 273/1632 [10:47<46:58,  2.07s/it]  

KeyboardInterrupt: 

In [10]:
# Save the model

model.save_pretrained("models/elife_bart_model_v2")

In [None]:
# clear the data left on the gpu
model = None