In [44]:
import gdown
import zipfile
import os
import faiss
import datasets

from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig, AutoConfig, AutoModel

## Loading the Dataset

In [2]:
url = "https://drive.google.com/uc?id=18xMA2wGPDXArwLyVWN3HXQaF0XnjtugF"
filepath = "data/gold"

# Check if index exists
if os.path.isfile(filepath + "/index.faiss"):
    print("File already exists")
else:
    
    # Download zip file using gdown
    gdown.download(url, "index.zip", quiet=False)

    # Create directory if it doesn't exist
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    # Unzip file
    with zipfile.ZipFile("index.zip", 'r') as zip_ref:
        zip_ref.extractall(filepath)

    # Remove zip file
    os.remove("index.zip")

File already exists


In [3]:
# Load index
dataset = datasets.load_from_disk("data/gold/dataset")

dataset.load_faiss_index("embeddings", filepath + "/index.faiss")

## Creating the Model

In [41]:
model_name = "sentence-transformers/paraphrase-albert-base-v2"
model_type = "albert"
model_config = AutoConfig.from_pretrained(model_name)

rag_config = RagConfig(
    question_encoder={
        "model_type": model_type,
        "config": model_config,
    },
    generator = {
        "model_type": model_type,
        "config": model_config
    },
    index_name="custom",
    passages_path=filepath + "/dataset",
    index_path=filepath + "/index.faiss",
)

In [43]:
retriever = RagRetriever(
    config=rag_config,
    question_encoder_tokenizer = AutoTokenizer.from_pretrained(model_name),
    generator_tokenizer = AutoTokenizer.from_pretrained(model_name),
)

In [45]:
rag_model = RagSequenceForGeneration(
    config=rag_config,
    retriever=retriever,
    question_encoder=AutoModel.from_pretrained(model_name),
    generator=AutoModel.from_pretrained(model_name),
)

In [46]:
# Test model
input_text = "What is the capital of France?"
output_text = "The capital of France is Paris."

AssertionError:  At least one of input_ids or context_input_ids must be given