In [1]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer, T5Tokenizer
from transformers import BertTokenizer
from datasets import load_dataset
import torch

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
def initialize_model():
    model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "t5-small")
    question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    generator_tokenizer = T5Tokenizer.from_pretrained("t5-small") 
    # this had to replaced to a smaller model compared to the original BART-large, 
    # probably that was causing CUDA out of memory on a small GPU

    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
    model.config.use_dummy_dataset = True # use dummy dataset for POC
    model.config.index_name = "exact"
    retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)
    
    model.set_retriever(retriever)
    
    model.to(device)
    
    return (model, tokenizer, retriever)

def infer(model, tokenizer, retriever):
    input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt").to(model.device)

    outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

    loss = outputs.loss
    print("loss: ", loss)
    
def save_model(model, tokenizer, retriever, path="./rag_model_custom"):
    model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    retriever.save_pretrained(path)

In [4]:
def initialize_model_custom():
    model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-ctx_encoder-single-nq-base", "t5-small")
    question_encoder_tokenizer = BertTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    generator_tokenizer = T5Tokenizer.from_pretrained("t5-small")

    tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
    model.config.use_dummy_dataset = True # use dummy dataset for POC
    model.config.index_name = "exact"
    retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)
    
    model.set_retriever(retriever)
    
    model.to(device)
    
    return (model, tokenizer, retriever)

In [None]:
(model, tokenizer, retriever) = initialize_model()

In [6]:
infer(model, tokenizer, retriever)

2023-04-28 06:03:28.384254: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


loss:  tensor([28.8410], device='cuda:0', grad_fn=<AddBackward0>)


In [6]:
'''
model1 = RagSequenceForGeneration.from_pretrained("./rag_model_custom")

question_encoder_tokenizer1 = AutoTokenizer.from_pretrained("./rag_model_custom")
generator_tokenizer1 = T5.from_pretrained("./rag_model_custom")

tokenizer1 = RagTokenizer(question_encoder_tokenizer1, generator_tokenizer1)
retriever1 = RagRetriever(model.config, question_encoder_tokenizer1, generator_tokenizer1)
'''

'\nmodel1 = RagSequenceForGeneration.from_pretrained("./rag_model_custom")\n\nquestion_encoder_tokenizer1 = AutoTokenizer.from_pretrained("./rag_model_custom")\ngenerator_tokenizer1 = T5.from_pretrained("./rag_model_custom")\n\ntokenizer1 = RagTokenizer(question_encoder_tokenizer1, generator_tokenizer1)\nretriever1 = RagRetriever(model.config, question_encoder_tokenizer1, generator_tokenizer1)\n'

In [7]:
'''
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
model = model.to(device)
'''

'\ntokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")\nretriever = RagRetriever.from_pretrained("facebook/rag-sequence-base", index_name="exact", use_dummy_dataset=True)\nmodel = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)\nmodel = model.to(device)\n'