In [2]:
import warnings
warnings.simplefilter("ignore")

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import numpy as np
import torch

import datasets 
import pytorch_lightning as pl

from datasets import load_dataset, load_metric

from transformers import (
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [3]:
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# Define the LightningDataModule
class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    
    def prepare_data(self):
        # Download and preprocess the data
        load_dataset("cnn_dailymail", "3.0.0", split="train[:10%]")
        load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")
    
    def setup(self, stage=None):
        # Load and preprocess the data
        train_data = load_dataset("cnn_dailymail", "3.0.0", split="train[:10%]")
        val_data = load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

        self.train_ds = train_data.map(
            self.preprocess_function, 
            batched=True, 
            batch_size=self.batch_size, 
            remove_columns=["article", "highlights", "id"]
        )

        self.val_ds = val_data.map(
            self.preprocess_function, 
            batched=True, 
            batch_size=self.batch_size,
            remove_columns=["article", "highlights", "id"]
        )

    def preprocess_function(self, batch):
        inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
        outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
        batch["input_ids"] = inputs.input_ids
        batch["attention_mask"] = inputs.attention_mask
        batch["labels"] = outputs.input_ids.copy()
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)

In [7]:
class MyLightningModule(pl.LightningModule):
    def __init__(self, model_name, learning_rate, weight_decay, batch_size):
        super().__init__()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        
        # Load the pre-trained model and tokenizer
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Load the ROUGE metric
        self.metric = load_metric("rouge")

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits
    
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log('train_loss', loss, on_epoch=True, on_step=False)
        return {'loss': loss, 'logits': logits}
    
    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log('val_loss', loss, on_epoch=True, on_step=False)
        return {'loss': loss, 'logits': logits, "labels":labels}
    
    def validation_epoch_end(self, outputs):
        decoded_preds = []
        decoded_labels = []
        for output in outputs:
            logits = output['logits']
            labels = output['labels']
            decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)
            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        
        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"])["rouge1"].mid
        
        self.log('rouge1_precision', scores.precision, prog_bar=True)
        self.log('rouge1_recall', scores.recall, prog_bar=True)
        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer


In [8]:
torch.set_float32_matmul_precision("medium")
model = MyLightningModule(model_name="t5-small", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)
trainer = pl.Trainer(accelerator="gpu", devices=[0], max_epochs=10)
dm = MyDataModule(batch_size=16)
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)

  0%|                                                                                                                                               | 0/1795 [00:00<?, ?ba/

 57%|██████████████████████████████████████████████████████████████████████████▌                                                        | 1021/1795 [00:09<00:07, 107.16ba/s][A
 57%|███████████████████████████████████████████████████████████████████████████▎                                                       | 1032/1795 [00:09<00:07, 107.83ba/s][A
 58%|████████████████████████████████████████████████████████████████████████████▏                                                      | 1044/1795 [00:09<00:06, 109.92ba/s][A
 59%|█████████████████████████████████████████████████████████████████████████████                                                      | 1056/1795 [00:09<00:06, 112.47ba/s][A
 59%|█████████████████████████████████████████████████████████████████████████████▉                                                     | 1068/1795 [00:09<00:06, 113.56ba/s][A
 60%|██████████████████████████████████████████████████████████████████████████████▊                               

Sanity Checking DataLoader 0:   0%|                                                                                                                    | 0/2 [00:00<?, ?it/s]

AttributeError: 'list' object has no attribute 'size'

### Recap of what we did:
* Finetuned T5-Small on DailyCNN (summarize news articles) using HF Trainer and data loading
* Converted to Lightning code 

### To do next:
* Make it work with the evaluation somethings wrong now, don't think it's a big issue
* Clean up the code a bit
* Compare it with HF, add predict function, modify data loading so it's from scratch / more general way of doing it.