In [None]:
"""
# HuggingFace Tutorial Series
- 1. What is HuggingFace?
- 2. Common tasks we can do with HuggingFace, with a brief explanation of each task (e.g. what question answering is)
- 3. Using the HuggingFace Pipeline (High level feature)
- 4. How the pipeline works at a lower level
- 5. HuggingFace Datasets
- 6. HuggingFace Tokenizer
- 7. HuggingFace Evaluate
- 8. HuggingFace Trainer
- 9. Putting it together to finetune a news article summarizer
- 10. Making it more general and robust with Lightning and custom data loading
"""

In [None]:
import warnings
# Silence every Python warning for the rest of the session. This keeps cell
# output short, but it also hides deprecation notices from the libraries
# imported below.
warnings.simplefilter("ignore")

import os
# Enumerate GPUs in PCI bus order and expose only device "0" to this process.
# NOTE(review): these are set before `import torch` further down, presumably
# so they are in place before CUDA is initialised — confirm before reordering.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
import torch
import datasets 
import pytorch_lightning as pl
from datasets import load_dataset, load_metric

from transformers import (
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

import torch
import pandas as pd
from torch.utils.data import Dataset
import pytorch_lightning as pl

# Let PyTorch use a lower internal precision for float32 matrix multiplications
# ("medium"), trading some numerical accuracy for speed.
torch.set_float32_matmul_precision("medium")

In [None]:
# HuggingFace checkpoint identifier; the tokenizer below is loaded from it.
model_name = "t5-small"
# Tokenizer matching the checkpoint. Later cells use this object to encode
# articles and summaries.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
class cnn_dailymail(Dataset):
    """Map-style dataset that tokenizes article/summary pairs on the fly.

    The CSV is read once in the constructor; tokenization happens per item in
    ``__getitem__`` so no pre-tokenized copy of the corpus is kept.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
            The file must contain ``article`` and ``highlights`` columns.
        tokenizer: Callable used to encode text. It is invoked with
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``
            keyword arguments and must return a mapping with ``input_ids`` and
            ``attention_mask`` entries whose first axis is the batch axis.
        max_length: Length every encoded article and summary is truncated or
            padded to.
    """

    def __init__(self, csv_file, tokenizer, max_length=512):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string to a fixed-length, batch-of-one encoding."""
        return self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the encoded sample at row ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` for the article and
            ``labels`` holding the token ids of the summary. Padding positions
            in ``labels`` keep the tokenizer's padding id.
        """
        article = self.data.loc[idx, 'article']
        highlights = self.data.loc[idx, 'highlights']

        inputs = self._encode(article)
        targets = self._encode(highlights)

        # Drop only the leading batch axis. A bare ``squeeze()`` would also
        # drop the sequence axis when ``max_length == 1`` and hand 0-d tensors
        # to the collate function.
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': targets['input_ids'].squeeze(0)
        }

In [None]:
class MyDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/validation/test CSV splits.

    Args:
        train_csv: CSV file for the training split.
        val_csv: CSV file for the validation split.
        test_csv: CSV file for the test split.
        tokenizer: Tokenizer handed to every ``cnn_dailymail`` dataset.
        batch_size: Batch size used by all three dataloaders.
        max_length: Token length handed to every ``cnn_dailymail`` dataset.
    """

    def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):
        super().__init__()
        self.train_csv = train_csv
        self.val_csv = val_csv
        self.test_csv = test_csv
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

        # Datasets are built lazily by setup(); None means "not built yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _build_dataset(self, csv_file):
        """Create the dataset for one split from its CSV file."""
        return cnn_dailymail(csv_file, self.tokenizer, self.max_length)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``stage`` is ``'fit'``, ``'validate'``, ``'test'`` or ``None`` (build
        everything). A split that already exists is not rebuilt, so calling
        ``setup()`` by hand and then letting the trainer call it again does
        not read the CSV files a second time.
        """
        if stage in ('fit', None) and self.train_dataset is None:
            self.train_dataset = self._build_dataset(self.train_csv)
        # The validation split is needed for a standalone validation run too,
        # not only during fitting.
        if stage in ('fit', 'validate', None) and self.val_dataset is None:
            self.val_dataset = self._build_dataset(self.val_csv)
        if stage in ('test', None) and self.test_dataset is None:
            self.test_dataset = self._build_dataset(self.test_csv)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)


In [None]:
class MyLightningModule(pl.LightningModule):
    """Seq2seq summarization model with ROUGE-1 evaluation on validation.

    Args:
        model_name: HuggingFace checkpoint identifier used for both the model
            and the tokenizer that decodes validation outputs.
        learning_rate: Learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.

    Notes:
        * Padding positions in ``labels`` are excluded from the loss by
          replacing them with ``-100`` before the forward pass.
        * ROUGE is computed from the argmax of the logits produced by the
          forward pass with labels supplied, not from ``generate()``.
    """

    # Label value ignored by the loss.
    _IGNORE_INDEX = -100

    def __init__(self, model_name, learning_rate, weight_decay):
        super().__init__()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Load the tokenizer here so decoding does not depend on a
        # notebook-level global being defined.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load the pre-trained model
        self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))

        # Load the ROUGE metric
        self.metric = load_metric("rouge")

        # Per-batch CPU tensors of token ids collected during validation.
        self._val_pred_ids = []
        self._val_label_ids = []

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the model and return ``(loss, logits)``.

        ``loss`` comes from the model output and is computed from ``labels``
        exactly as they are passed in.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits

    def _shared_step(self, batch):
        """Forward one batch with padding excluded from the loss.

        Returns ``(loss, logits, labels)`` where ``labels`` is the unmodified
        tensor from the batch.
        """
        labels = batch["labels"]
        # The dataset pads labels with the tokenizer's padding id; without
        # this mask the loss would also be taken over padding positions.
        loss_labels = labels.masked_fill(
            labels == self.tokenizer.pad_token_id, self._IGNORE_INDEX
        )
        loss, logits = self(batch["input_ids"], batch["attention_mask"], loss_labels)
        return loss, logits, labels

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, logits, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)
        return {'loss': loss, 'logits': logits}

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash token ids for ROUGE."""
        loss, logits, labels = self._shared_step(batch)
        self.log('val_loss', loss, on_epoch=True, on_step=False)

        pad_id = self.tokenizer.pad_token_id
        padded = (labels == pad_id) | (labels == self._IGNORE_INDEX)

        # Keep only the argmax ids, on CPU: holding the full
        # (batch, seq, vocab) logits for the whole epoch exhausts memory.
        # Positions where the reference is padding carry no signal, so they
        # are blanked out and dropped again when decoding.
        pred_ids = logits.argmax(dim=-1).masked_fill(padded, pad_id)
        label_ids = labels.masked_fill(labels == self._IGNORE_INDEX, pad_id)

        self._val_pred_ids.append(pred_ids.detach().cpu())
        self._val_label_ids.append(label_ids.detach().cpu())

        return {'loss': loss, 'logits': logits, "labels":labels}

    def on_validation_epoch_end(self):
        """Decode the collected ids, log ROUGE-1 and reset the buffers."""
        if not self._val_pred_ids:
            # No validation batch was seen; nothing to score.
            return

        decoded_preds = []
        decoded_labels = []
        for pred_ids, label_ids in zip(self._val_pred_ids, self._val_label_ids):
            decoded_preds.extend(self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True))
            decoded_labels.extend(self.tokenizer.batch_decode(label_ids, skip_special_tokens=True))

        # Reset before scoring so a failure below cannot leak this epoch's
        # samples into the next validation run.
        self._val_pred_ids.clear()
        self._val_label_ids.clear()

        # Compute ROUGE scores
        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"])["rouge1"].mid

        self.log('rouge1_precision', scores.precision, prog_bar=True)
        self.log('rouge1_recall', scores.recall, prog_bar=True)
        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer


In [None]:
# File paths
train_csv = "train.csv"
val_csv = "validation.csv"
test_csv = "test.csv"

# Seed Python, NumPy and PyTorch (and the dataloader workers) so shuffling and
# initialisation are repeatable across runs.
pl.seed_everything(42, workers=True)

# Create the data module
dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)
dm.setup()

# Reuse `model_name` so the model always comes from the same checkpoint as the
# tokenizer defined earlier in the notebook.
model = MyLightningModule(model_name=model_name, learning_rate=1e-4, weight_decay=1e-5)
trainer = pl.Trainer(accelerator="gpu", devices=[0], max_epochs=1, precision=16)
trainer.fit(model, datamodule=dm)

In [None]:
http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb

### next steps:
* Articles longer than 512 tokens are currently truncated; check whether this hurts summary quality when the article is much longer.

#### what we've done:
* Change the data loading so it's more general, meaning on the fly loading from disk
* add torch.compile
* 1. Clean up the code, make it into scripts instead of notebook -> Train for an epoch (add multi-gpu training?)
* add tensorboard visualization
* Train from scratch instead of from pretrained weights, to verify that the training setup works and the model is actually improving
* 2. Create an inference step, send in news article -> get summary, check that it works
