In [4]:
# Standard library
import linecache
import os
import time
import typing
import zipfile
from pathlib import Path
from typing import Dict, List, Tuple

# Third-party
import datasets
import faiss
import gdown
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig, AutoConfig, \
    RagTokenizer, BartForConditionalGeneration, AlbertModel, Trainer, TrainingArguments, BatchEncoding, \
    DataCollatorWithPadding, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutputWithPooling

# Project-local (the star import stays last so it keeps the same name precedence as before)
from reqs.lightning_base import BaseTransformer
from reqs.utils import *

## Downloading the Retrieval Index and Passages

In [5]:
url = "https://drive.google.com/uc?id=18xMA2wGPDXArwLyVWN3HXQaF0XnjtugF"
filepath = "data/gold"

# The archive is only fetched when the FAISS index is not on disk yet,
# so re-running this cell is cheap.
if os.path.isfile(filepath + "/index.faiss"):
    print("File already exists")
else:

    # Download zip file using gdown
    downloaded = gdown.download(url, "index.zip", quiet=False)

    # Some gdown versions report a failed download by returning None instead
    # of raising; fail here with a clear message rather than inside zipfile.
    if downloaded is None or not os.path.isfile("index.zip"):
        raise RuntimeError("Could not download the index archive from " + url)

    # exist_ok=True replaces the separate existence check (no check-then-create race)
    os.makedirs(filepath, exist_ok=True)

    try:
        # Unzip file
        with zipfile.ZipFile("index.zip", 'r') as zip_ref:
            zip_ref.extractall(filepath)
    finally:
        # Remove zip file even when extraction fails, so that a corrupt
        # archive is downloaded again on the next run.
        if os.path.exists("index.zip"):
            os.remove("index.zip")

File already exists


## Creating the Model

In [6]:
# Hub identifier and model family for each half of the RAG model.
encoder_model_name, encoder_model_type = "sentence-transformers/paraphrase-albert-base-v2", "albert"
generator_model_name, generator_model_type = "facebook/bart-base", "bart"

# Configurations resolved from the identifiers above; the encoder config
# additionally asks for the per-layer hidden states.
encoder_config = AutoConfig.from_pretrained(encoder_model_name, output_hidden_states=True)
generator_config = AutoConfig.from_pretrained(generator_model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
# Build the RAG configuration from the sub-model configurations loaded above.
#
# RagConfig expects `question_encoder` / `generator` to be plain config
# dictionaries. Passing {"model_type": ..., "config": <config object>} makes it
# create a default-valued ALBERT/BART config and store the loaded object under
# a stray `config` attribute, so the values loaded from the Hub are never used.
# The classmethod below serialises both configs correctly.
rag_config = RagConfig.from_question_encoder_generator_configs(
    encoder_config,
    generator_config,
    index_name="custom",
    passages_path=filepath + "/dataset",
    index_path=filepath + "/index.faiss",
)

In [8]:
# Retriever configured by `rag_config` (custom index, passages and index paths
# under `filepath`). It gets one tokenizer per sub-model: the question-encoder
# tokenizer and the generator tokenizer, both loaded by model name.
rag_retriever = RagRetriever(
    config=rag_config,
    question_encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_model_name),
    generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name),
)

In [9]:
class CustomQuestionEncoder(AlbertModel):
    """ALBERT encoder whose ``pooler_output`` is the mask-weighted mean of the token embeddings.

    NOTE(review): ``last_hidden_state`` is not populated in the returned
    output; presumably this is so that indexing the output with ``[0]`` yields
    the pooled vector -- confirm against how the RAG model reads it.
    """

    def forward(self, *args, **kwargs):
        """Run ALBERT and mean-pool its token embeddings.

        Args:
            *args: Positional arguments forwarded to ``AlbertModel.forward``
                (``input_ids`` first, ``attention_mask`` second).
            **kwargs: Keyword arguments forwarded to ``AlbertModel.forward``.

        Returns:
            BaseModelOutputWithPooling: ``pooler_output`` holds the pooled
            sentence embedding; ``hidden_states`` and ``attentions`` are passed
            through when the underlying model returned them, else ``None``.
        """
        # Call the original forward method
        outputs = super().forward(*args, **kwargs)

        # First element of the model output contains all token embeddings
        # (true for both the ModelOutput and the plain-tuple return form).
        token_embeddings = outputs[0]

        # The mask may arrive by keyword or as the second positional argument;
        # looking only at kwargs would silently average over padding tokens.
        attention_mask = kwargs.get('attention_mask', None)
        if attention_mask is None and len(args) > 1:
            attention_mask = args[1]

        if attention_mask is None:
            # Assume all 1s if not given, use output to get mask. The mask must be two-dimensional
            attention_mask = torch.ones(token_embeddings.shape[:2], device=token_embeddings.device)

        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Sum of unmasked embeddings divided by the number of unmasked tokens;
        # the clamp guards against division by zero for an all-zero mask.
        pooler_output = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        # Return pooler output, hidden states and attentions. getattr keeps this
        # working when the caller asked for tuple outputs (return_dict=False),
        # in which case these attributes do not exist.
        return BaseModelOutputWithPooling(
            pooler_output=pooler_output,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

# Use the custom question encoder
question_encoder_model = CustomQuestionEncoder.from_pretrained(encoder_model_name)

In [10]:
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer pair: one tokenizer for the questions, one for the generated text.
rag_tokenizer = RagTokenizer(
    question_encoder=AutoTokenizer.from_pretrained(encoder_model_name),
    generator=AutoTokenizer.from_pretrained(generator_model_name),
)

# Full RAG model assembled from the retriever, the custom question encoder
# and a BART generator loaded by name.
rag_model = RagSequenceForGeneration(
    config=rag_config,
    retriever=rag_retriever,
    question_encoder=question_encoder_model,
    generator=BartForConditionalGeneration.from_pretrained(generator_model_name),
)

## Create dataset

In [11]:
# Create pytorch dataset
class QuestionDataset(Dataset):
    """Dataset that tokenizes the ``question`` column of a table on access.

    Args:
        data: Table supporting ``len()`` and ``.iloc[idx]['question']``
            (e.g. a pandas DataFrame).
        tokenizer: Object exposing a ``question_encoder`` tokenizer.
        max_length: Maximum number of tokens per question; longer questions
            are truncated. Defaults to 512.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenizer output for row ``idx`` as a plain dict."""
        question = self.data.iloc[idx]['question']
        # max_length used to be stored but never applied, so over-long
        # questions were passed to the encoder untruncated.
        question_encoding = self.tokenizer.question_encoder(
            question,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return dict(question_encoding)

In [12]:
# Load the "web_questions" dataset by name.
questions_dataset = load_dataset("web_questions")

# Print the first training record, then let the notebook display the
# dataset object itself as the cell result.
print(questions_dataset["train"][0])
questions_dataset

{'url': 'http://www.freebase.com/view/en/justin_bieber', 'question': 'what is the name of justin bieber brother?', 'answers': ['Jazmyn Bieber', 'Jaxon Bieber']}


DatasetDict({
    train: Dataset({
        features: ['url', 'question', 'answers'],
        num_rows: 3778
    })
    test: Dataset({
        features: ['url', 'question', 'answers'],
        num_rows: 2032
    })
})

In [14]:
def tokenize_function(dataset):
    """Tokenize the ``question`` column of a batch with the question-encoder tokenizer.

    Args:
        dataset: Batch mapping with a ``"question"`` entry (a list of strings
            when used with ``map(..., batched=True)``).

    Returns:
        The tokenizer output, padded to the maximum length and truncated.
    """
    # Call the question-encoder tokenizer explicitly, as the other cells do,
    # instead of relying on RagTokenizer's current input/target mode.
    return rag_tokenizer.question_encoder(dataset["question"], padding="max_length", truncation=True)

tokenized_datasets = questions_dataset.map(tokenize_function, batched=True)
# DataCollatorWithPadding pads through `tokenizer.pad(...)`. RagTokenizer only
# wraps two tokenizers and has no `pad` method, so the collator must be given
# the underlying question-encoder tokenizer.
data_collator = DataCollatorWithPadding(rag_tokenizer.question_encoder)

Fine tuning

## Initial Testing

In [15]:
# Move the model to the selected device before running a single smoke-test question.
rag_model.to(device)

question = "What is the capital of the Romania"
# Tokenize with the question-encoder tokenizer and move the tensors to the same device.
inputs = rag_tokenizer.question_encoder(question, return_tensors="pt").to(device)

# Beam search with 4 beams, at most 50 newly generated tokens.
generated = rag_model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=50, num_beams=4, early_stopping=False)
# Decode the generated ids and keep the first decoded string.
generated_string = rag_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

print("Question:", question)
print("Answer:", generated_string)

KeyboardInterrupt: 

In [16]:
# Adapted from https://github.com/huggingface/transformers/blob/main/examples/research_projects/rag/utils_rag.py ( To be adapted )

class Seq2SeqDataset(Dataset):
    """Line-aligned source/target text dataset read from ``<type_path>.source`` / ``<type_path>.target``.

    Args:
        tokenizer: Either a ``RagTokenizer`` (its ``question_encoder`` encodes
            sources and its ``generator`` encodes targets) or a single
            tokenizer used for both.
        data_dir: Directory containing the ``.source`` and ``.target`` files.
        max_source_length: Length passed to ``encode_line`` for source lines.
        max_target_length: Length passed to ``encode_line`` for target lines.
        type_path: File stem of the split, e.g. ``"train"``.
        n_obs: If not ``None``, only the first ``n_obs`` examples are used.
        src_lang: Stored on the instance; not used by this class.
        tgt_lang: Stored on the instance; not used by this class.
        prefix: Text prepended to every source line; ``None`` is treated as ``""``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        # The prefix is concatenated with each source line in __getitem__; a
        # config-derived prefix can be None, which would raise TypeError there.
        self.prefix = prefix or ""
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        """Return the number of (possibly ``n_obs``-limited) source lines."""
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Encode the source/target line pair at ``index``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` for the source and
            ``decoder_input_ids`` holding the encoded target.
        """
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"

        # Need to add eos token manually for T5
        if isinstance(self.tokenizer, T5Tokenizer):
            source_line += self.tokenizer.eos_token
            tgt_line += self.tokenizer.eos_token

        # Pad source and target to the right
        source_tokenizer = (
            self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
        )
        target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer

        source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
        target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        """Return the character length of every line in ``data_file`` (line endings included)."""
        # The context manager closes the handle; the bare open().readlines() leaked it.
        with Path(data_file).open() as handle:
            return [len(line) for line in handle]

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack the examples of ``batch`` and pass them through ``trim_batch``.

        The pad ids come from the generator (targets) and question encoder
        (sources) when the tokenizer is a ``RagTokenizer``, otherwise from the
        single tokenizer.
        """
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        y = trim_batch(target_ids, tgt_pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }
        return batch

## Fine-tuning Code

In [17]:
# Adapted from https://github.com/huggingface/transformers/blob/main/examples/research_projects/rag/finetune_rag.py#L97

class GenerativeQAModule(BaseTransformer):
    """Lightning-style module that fine-tunes the notebook's RAG model for generative QA.

    It wraps the notebook-level ``rag_model``, ``rag_tokenizer``, ``rag_config``
    and ``rag_retriever`` and uses exact match (``em``) as validation metric.
    """

    mode = "generative_qa"
    loss_names = ["loss"]
    metric_names = ["em"] # exact match
    val_metric = "em"

    def __init__(self, hparams, **kwargs):
        """Set up the module from ``hparams``.

        Args:
            hparams: Hyper-parameters; ``max_source_length`` and ``data_dir``
                are read here through ``self.hparams``.
            **kwargs: Accepted and ignored.
        """
        # Global variables used here from simplicity

        self.retriever = rag_retriever
        prefix = rag_config.question_encoder.prefix
        super().__init__(hparams, config=rag_config, tokenizer=rag_tokenizer, model=rag_model)
        self.dataset_kwargs = {
            # Seq2SeqDataset concatenates the prefix with every source line,
            # so an unset (None) config prefix has to become an empty string.
            "prefix": prefix or "",
            "max_source_length": self.hparams.max_source_length,
            "data_dir": self.hparams.data_dir,
        }
        self._ensure_data_attributes()

    def _ensure_data_attributes(self):
        """Provide the attributes read by the data/eval methods if the base class did not.

        ``step_count``, ``n_obs``, ``target_lens`` and ``num_workers`` are read
        by other methods of this class but never assigned in it. BaseTransformer
        is not visible from here, so each one is only filled in when missing.
        """
        def hparam(name, default):
            return getattr(self.hparams, name, default)

        if not hasattr(self, "step_count"):
            self.step_count = 0
        if not hasattr(self, "n_obs"):
            # A missing or negative n_<split> becomes None, which Seq2SeqDataset
            # treats as "no limit".
            sizes = {split: hparam(f"n_{split}", -1) for split in ("train", "val", "test")}
            self.n_obs = {split: (size if size >= 0 else None) for split, size in sizes.items()}
        if not hasattr(self, "target_lens"):
            train_len = hparam("max_target_length", None)
            self.target_lens = {
                "train": train_len,
                "val": hparam("val_max_target_length", train_len),
                "test": hparam("test_max_target_length", train_len),
            }
        if not hasattr(self, "num_workers"):
            self.num_workers = hparam("num_workers", 0)

    def forward(self, input_ids, **kwargs):
        """Delegate to the wrapped model."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids):
        """Decode token ids to stripped strings, skipping special tokens."""
        preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return lmap(str.strip, preds)

    def _step(self, batch: dict) -> Tuple:
        """Run one forward pass with labels and return ``(loss,)``."""
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]

        rag_kwargs = {}
        # The encoded target serves both as decoder input and as labels.
        decoder_input_ids = target_ids
        lm_labels = decoder_input_ids
        rag_kwargs["reduce_loss"] = True

        assert decoder_input_ids is not None

        outputs = self(
            source_ids,
            attention_mask=source_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
            labels=lm_labels,
            **rag_kwargs,
        )

        loss = outputs["loss"]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        """Compute the training loss and a log dict for one batch."""
        loss_tensors = self._step(batch)

        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}

        # "tpb": number of non-padding tokens in the batch (source + target)
        logs["tpb"] = (
            batch["input_ids"].ne(self.tokenizer.question_encoder.pad_token_id).sum() +
            batch["decoder_input_ids"].ne(self.tokenizer.generator.pad_token_id).sum()
        )

        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate and score one validation batch."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Average the per-batch losses and generation metrics of an epoch.

        Args:
            outputs: List of dicts returned by ``_generative_step``.
            prefix: Key prefix for the aggregated metrics (``"val"`` or ``"test"``).
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]

        gen_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }

        metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
        # Losses are logged as Python floats next to the generation metrics.
        gen_metrics.update({k: v.item() for k, v in losses.items()})

        losses.update(gen_metrics)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        preds = flatten_list([x["preds"] for x in outputs])

        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score predictions against targets with ``calculate_exact_match``."""
        return calculate_exact_match(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate answers for ``batch`` and collect loss, timing, length and metrics."""

        start_time = time.time()
        batch = BatchEncoding(batch).to(device=self.model.device)
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_deduplication=False,  # rag specific parameter
            use_cache=True,
            min_length=1,
            max_length=self.target_lens["val"],
        )

        # Average generation time per example in the batch
        gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
        loss_tensors = self._step(batch)
        base_metrics = dict(zip(self.loss_names, loss_tensors))
        gen_metrics: Dict = self.calc_generative_metrics(preds, target)

        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Generate and score one test batch."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs with the validation logic under the ``test`` prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Build the ``Seq2SeqDataset`` for the split named ``type_path``."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = Seq2SeqDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Wrap the split's dataset in a ``DataLoader`` using the dataset's own ``collate_fn``."""
        dataset = self.get_dataset(type_path)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the ``train`` split."""
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Loader over the ``val`` split."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Loader over the ``test`` split."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)


In [18]:
# Hyper-parameters come from "fine_tune_rag.sh" via parse_sh_args
# (presumably it parses the script's arguments -- confirm in reqs.utils),
# and are used to build the fine-tuning module.
hparams = parse_sh_args("fine_tune_rag.sh")
QAModule = GenerativeQAModule(hparams)

{'data_dir': 'data/gold', 'output_dir': 'outputs', 'cache_dir': 'outputs/cache', 'model_name_or_path': 'bart', 'model_type': 'rag_sequence', 'fp16': 1, 'gpus': 1, 'profile': 1, 'do_train': 1, 'do_predict': 1, 'n_val': -1, 'train_batch_size': 8, 'eval_batch_size': 1, 'max_source_length': 128, 'max_target_length': 25, 'val_max_target_length': 25, 'test_max_target_length': 25, 'label_smoothing': 0.1, 'weight_decay': 0.001, 'adam_epsilon': 1e-08, 'max_grad_norm': 0.1, 'lr_scheduler': 'polynomial', 'learning_rate': 3e-05, 'num_train_epochs': 100, 'warmup_steps': 500, 'gradient_accumulation_steps': 1}
here
here


In [23]:
training_args = TrainingArguments(
    output_dir="train",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    evaluation_strategy="epoch",
)

# Instantiate the Trainer
# `args` has to be passed explicitly: without it the Trainer builds default
# TrainingArguments and the batch size, epoch count and evaluation strategy
# configured above are silently ignored.
trainer = Trainer(
    model=QAModule,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

trainer.train()

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


KeyboardInterrupt: 