In [1]:
import gdown
import zipfile
import os
import faiss
import datasets
import torch
import typing

from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig, AutoConfig, AutoModel, \
    RagTokenizer, BartForConditionalGeneration, AlbertModel

from transformers.modeling_outputs import BaseModelOutputWithPooling

## Loading the Dataset

In [2]:
url = "https://drive.google.com/uc?id=18xMA2wGPDXArwLyVWN3HXQaF0XnjtugF"
filepath = "data/gold"
archive_path = "index.zip"

# The FAISS index file doubles as the "already downloaded" marker: when it is
# present, the archive is assumed to have been fetched and extracted earlier.
if os.path.isfile(filepath + "/index.faiss"):
    print("File already exists")
else:
    # Download zip file using gdown.
    # NOTE(review): some gdown releases signal a failed download by returning
    # None instead of raising -- confirm against the pinned version. Failing
    # here gives a clearer error than the ZipFile call would further down.
    downloaded = gdown.download(url, archive_path, quiet=False)
    if downloaded is None:
        raise RuntimeError(f"Could not download the index archive from {url}")

    try:
        # exist_ok=True replaces the separate existence check, so the
        # directory is created only when needed and without a race.
        os.makedirs(filepath, exist_ok=True)

        # Unzip file
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(filepath)
    finally:
        # Remove the zip file even when extraction fails, so a corrupt or
        # partial archive is never left behind in the working directory.
        if os.path.exists(archive_path):
            os.remove(archive_path)

File already exists


## Creating the Model

In [3]:
# --- Question encoder: ALBERT checkpoint published under sentence-transformers ---
encoder_model_type = "albert"
encoder_model_name = "sentence-transformers/paraphrase-albert-base-v2"
encoder_config = AutoConfig.from_pretrained(
    encoder_model_name,
    output_hidden_states=True,
)

# --- Generator: BART base checkpoint ---
generator_model_type = "bart"
generator_model_name = "facebook/bart-base"
generator_config = AutoConfig.from_pretrained(generator_model_name)

In [4]:
# Build the composite RAG configuration from the two loaded sub-model configs.
#
# RagConfig(question_encoder={...}, generator={...}) treats each dict as the
# sub-config's own fields: it pops "model_type" and forwards every remaining
# key to AutoConfig.for_model(...). Passing the loaded config object under a
# "config" key therefore leaves both nested configs at their library defaults
# (with a stray `config` attribute) instead of describing the checkpoints
# actually used. from_question_encoder_generator_configs serializes the loaded
# configs into the expected form.
rag_config = RagConfig.from_question_encoder_generator_configs(
    encoder_config,
    generator_config,
    # "custom" tells the retriever to load a user-provided dataset and index
    # from the two paths below rather than a hosted one.
    index_name="custom",
    passages_path=filepath + "/dataset",
    index_path=filepath + "/index.faiss",
)

In [5]:
# Load the two tokenizers up front, in the same order as before, so that the
# retriever construction below reads as a plain wiring step.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)

# The retriever is configured from rag_config and given both tokenizers.
retriever = RagRetriever(
    config=rag_config,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
)

In [6]:
class CustomQuestionEncoder(AlbertModel):
    """ALBERT encoder whose ``pooler_output`` is a masked mean of the token embeddings.

    The base model's forward pass is run unchanged; its token embeddings are
    then averaged over the sequence dimension, weighting each position by the
    attention mask so that padded positions do not contribute.
    """

    def forward(self, *args, **kwargs):
        """Run ALBERT and return the attention-masked mean embedding as ``pooler_output``.

        Accepts the same positional and keyword arguments as
        ``AlbertModel.forward``. The attention mask is picked up whether it is
        passed by keyword or as the second positional argument; when absent,
        every position is treated as a real token.

        Returns:
            ``BaseModelOutputWithPooling`` carrying ``pooler_output``,
            ``hidden_states`` and ``attentions``. ``last_hidden_state`` is
            left unset, so the pooled vector is the first populated field.
            When the caller passes ``return_dict=False`` the same fields are
            returned as a plain tuple.
        """
        # hidden_states / attentions are read by attribute below, which needs a
        # ModelOutput. If the caller explicitly asked for a tuple, run the base
        # model in dict mode and convert at the very end instead of crashing.
        wants_tuple = kwargs.get('return_dict') is False
        if wants_tuple:
            kwargs['return_dict'] = True

        # Call the original forward method
        outputs = super().forward(*args, **kwargs)

        # attention_mask may arrive by keyword or positionally; it is the
        # second positional parameter of AlbertModel.forward (after input_ids).
        attention_mask = kwargs.get('attention_mask', None)
        if attention_mask is None and len(args) > 1:
            attention_mask = args[1]

        token_embeddings = outputs[0]  # First element of model_output contains all token embeddings

        if attention_mask is None:
            # Assume all 1s if not given; the mask must be two-dimensional
            # (first two dimensions of the token embeddings).
            attention_mask = torch.ones(token_embeddings.shape[:2], device=token_embeddings.device)

        # Broadcast the mask over the embedding dimension, sum the unmasked
        # token vectors and divide by the number of unmasked tokens. The clamp
        # guards against division by zero for an all-zero mask.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        pooler_output = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        # Return pooler output, hidden states and attentions.
        # NOTE(review): last_hidden_state is deliberately not populated. The
        # RAG model appears to read the question embedding via index 0 of this
        # output, which resolves to the first populated field -- confirm
        # before adding last_hidden_state here.
        pooled = BaseModelOutputWithPooling(pooler_output=pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
        return pooled.to_tuple() if wants_tuple else pooled


# Use the custom question encoder
question_encoder_model = CustomQuestionEncoder.from_pretrained(encoder_model_name)

In [7]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Assemble the RAG model from the parts built in the previous cells plus a
# freshly loaded BART generator.
rag_model = RagSequenceForGeneration(
    config=rag_config,
    question_encoder=question_encoder_model,
    generator=BartForConditionalGeneration.from_pretrained(generator_model_name),
    retriever=retriever,
)

# Paired tokenizer: one tokenizer for the question encoder and one for the
# generator, loaded from the same checkpoints as the models.
rag_tokenizer = RagTokenizer(
    question_encoder=AutoTokenizer.from_pretrained(encoder_model_name),
    generator=AutoTokenizer.from_pretrained(generator_model_name),
)

## Testing the Model

In [8]:
# Move the RAG model onto the selected device before running inference.
rag_model.to(device)

question = "What is the capital of the Netherlands"

# Tokenize with the question-encoder tokenizer and place the tensors on the
# same device as the model.
inputs = rag_tokenizer.question_encoder(question, return_tensors="pt").to(device)

# Only input_ids and attention_mask are forwarded to generate(); any other
# field the tokenizer may produce is left out, as in the original call.
generated = rag_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    num_beams=4,
    early_stopping=False,
)

# Decode the batch and keep its first entry.
generated_string = rag_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

print("Question:", question)
print("Answer:", generated_string)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Question: What is the capital of the Netherlands
Answer: Netherlands / The Netherlands, informally Holland, is a country located in northwestern Europe with overseas territories in the Caribbean. It is the largest of the four constituent countries of the Kingdom of the Netherlands. The Netherlands consists of twelve provinces;
