In [3]:
import fitz  # PyMuPDF

# Here we load the PDF file
pdf_file = "HolisticApplications.pdf"
doc = fitz.open(pdf_file)

# Let's extract text from each page
text = ""
for page in doc:
    text += page.get_text()

# Save the text to a file with UTF-8 encoding
with open("HolisticApplications.txt", "w", encoding="utf-8") as file:
    file.write(text)


### Step 1: Text Preprocessing and Dataset Preparation

In [13]:
from datasets import Dataset, load_from_disk

# Read and prepare text chunks
text_chunks = [
    {"text": chunk} for chunk in open("HolisticApplications.txt", "r", encoding="utf-8").read().split('\n\n')
]

# Convert the list of text chunks into a Hugging Face Dataset
dataset = Dataset.from_dict({"text": [item['text'] for item in text_chunks]})

# Save the dataset locally for indexing
dataset.save_to_disk("rag_dataset")

# Load the dataset back using load_from_disk
dataset = load_from_disk("rag_dataset")


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

### Step 2: FAISS Indexing and Retrieval

In [16]:
# Load the dataset
dataset = load_from_disk("rag_dataset")

# Initialize the tokenizer and model for the RAG model
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

# Tokenize the text data
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)

# Apply tokenization
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Initialize FAISS index
d = 768  # Let's take 768 as the dimension of the embeddings
faiss_index = faiss.IndexFlatL2(d)

# Initialize the RAG model
rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base")
encoder = rag_model.question_encoder

# Convert the tokenized dataset to vectors using the encoder
vectors = []

for i, example in enumerate(tokenized_dataset):
    input_ids = torch.tensor(example["input_ids"]).unsqueeze(0)  # Add batch dimension
    attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0)

    # Encode the input and get the embeddings
    with torch.no_grad():
        outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]  # Access the first element of the output tuple

        # Pick and use the hidden state as the embedding
        embeddings = last_hidden_state

        # Check if the embeddings have the correct shape
        assert embeddings.shape == (1, d), f"Unexpected embedding shape: {embeddings.shape}"
        vectors.append(embeddings.numpy())

# Convert list of vectors to a numpy array and flatten the list of arrays
vectors = np.vstack(vectors).astype("float32")

# Debugging: Print vector shapes
print(f"Shape of the final vector array: {vectors.shape}")

# Convert list of vectors to a FAISS index
try:
    faiss_index.add(vectors)
    print("FAISS index created successfully.")
except Exception as e:
    print(f"Error during FAISS indexing: {e}")

# Save the FAISS index for later use
faiss.write_index(faiss_index, "holistic_applications_index.faiss")

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class 

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Shape of the final vector array: (1, 768)
FAISS index created successfully.


### Step 3: Set Up the RAG Model and Retriever

In [None]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Load the FAISS index
index = faiss.read_index("holistic_applications_index.faiss")

# Initialize the retriever
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index=index,
    passages_dataset=tokenized_dataset
)

# Load the RAG model for sequence generation
rag_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-token-base",
    retriever=retriever
)

### Step 4: Build the Question-Answering Pipeline

In [None]:
from transformers import pipeline

# Initialize the QA pipeline
qa_pipeline = pipeline("text2text-generation", model=rag_model, tokenizer=tokenizer)

# Example question
question = "What are the benefits of holistic treatments in veterinary medicine?"

# Generate an answer
answer = qa_pipeline(question)
print(answer[0]['generated_text'])

### Step 5: Fine-Tuning

In [None]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir="./rag_model",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Initialize the Trainer
trainer = Trainer(
    model=rag_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./fine_tuned_rag_model")

### Step 6: Deployment and Inference

In [None]:
# Load the fine-tuned model
rag_model = RagSequenceForGeneration.from_pretrained("./fine_tuned_rag_model")

# Use the QA pipeline again with the fine-tuned model
qa_pipeline = pipeline("text2text-generation", model=rag_model, tokenizer=tokenizer)

# Example question from a standard pet owner
question = "How effective is acupuncture for treating chronic pain in animals?"

# Generate an answer
answer = qa_pipeline(question)
print(answer[0]['generated_text'])