In [None]:
import chromadb
from transformers import AutoTokenizer, AutoModel
import torch
import json

In [None]:
# Load your JSON data
with open("prompt.json") as json_file:
    data = json.load(json_file)

In [None]:

# Extract prompts and completions
prompts = [entry['prompt'] for entry in data]
completions = [entry['completion'] for entry in data]

In [None]:
# Initialize the tokenizer and model for embeddings
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

In [None]:
# Function to get embeddings for text
def get_embeddings(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1)
    return embeddings

In [None]:
# Generate embeddings for prompts and completions
prompt_embeddings = get_embeddings(prompts)
completion_embeddings = get_embeddings(completions)

In [1]:
# Initialize Chroma DB client
client = chromadb.Client()




In [None]:
# Create a single collection for both prompts and completions
collection = client.create_collection("prompt_completion_collection")

In [2]:
collection.add(
    ids=[f"prompt_{i}" for i in range(len(data))],  # Assigning unique IDs like 'prompt_1', 'prompt_2', etc.
    documents=prompts,  # Adding prompt texts
    embeddings=prompt_embeddings.numpy().tolist(),
    metadatas=[{"type": "prompt", "id": f"prompt_{i}"} for i in range(len(data))]  # Metadata with 'type' and 'id'
)

collection.add(
    ids=[f"completion_{i}" for i in range(len(data))],  # Similarly, assign unique IDs for completions
    documents=completions,  # Adding completion texts
    embeddings=completion_embeddings.numpy().tolist(),
    metadatas=[{"type": "completion", "id": f"completion_{i}"} for i in range(len(data))]  # Metadata with 'type' and 'id'
)



In [6]:
# Function to query the prompt_completion_collection for a prompt and return the related completion
def query_prompt_for_completion(query, collection, top_k=3):
    # Generate the embedding for the query (the prompt)
    query_embedding = get_embeddings([query])
    
    # Query the collection to find the most similar prompt
    results = collection.query(query_embeddings=query_embedding.numpy().tolist(), n_results=1)
    
    # Initialize an empty list to store completions
    completions = []
    
    # Process each result (document and metadata)
    for result, metadata in zip(results['documents'], results['metadatas']):
        print("Result:", result)  # Print the result (document)
        print("Metadata:", metadata)  # Print metadata to check its structure
        
        # Check if the metadata is a list (it might be a list of dicts)
        if isinstance(metadata, list):
            # Iterate through the metadata list and look for the prompt type
            for item in metadata:
                if item.get('type') == 'prompt':  # If it's a prompt
                    # Extract the index i from the prompt_id (e.g., prompt_1 -> i = 1)
                    prompt_id = item.get('id', '')  # Fetch ID if it exists
                    if prompt_id: 
                        index = int(prompt_id.split('_')[1])  # Extract index from 'prompt_i'
                        # Fetch the corresponding completion based on the same index i
                        completion_id = f"completion_{index}"  # corresponding completion_id
                        # Retrieve the corresponding completion from the collection
                        completion = collection.get(ids=[completion_id])  # Fetch the completion
                        # Append the retrieved completion text to the completions list
                        completions.append(completion['documents'][0])  # Assuming one result for the ID
                        break  # Once found, no need to check other metadata items
        else:
            # Handle non-list metadata
            if metadata.get('type') == 'prompt':  # Handle case where metadata is not a list
                prompt_id = metadata.get('id', '')
                if prompt_id:
                    index = int(prompt_id.split('_')[1])  # Extract index from 'prompt_i'
                    completion_id = f"completion_{index}"
                    completion = collection.get(ids=[completion_id])  # Fetch the completion
                    completions.append(completion['documents'][0])  # Assuming one result for the ID
    
    return completions

Retrieved Document (Completion):
Document: risk or severity of bleeding
Metadata: {'id': 'completion_1', 'type': 'completion'}
Retrieved Document (Completion):
Document: What is the interaction between Bivalirudin and Acemetacin?
Metadata: {'id': 'prompt_1', 'type': 'prompt'}


In [10]:
# Example query to search for the related completion
query = "What is the interaction between Bivalirudin and Acemetacin?"
completion_results = query_prompt_for_completion(query, collection)

# Print the retrieved completions (answers)
print("Retrieved Completions (Answers):")
for result in completion_results:
    print(result)


Result: ['What is the interaction between Bivalirudin and Acemetacin?']
Metadata: [{'id': 'prompt_1', 'type': 'prompt'}]
Retrieved Completions (Answers):
risk or severity of bleeding


In [12]:
from transformers import pipeline

# Initialize the question-answering pipeline with a pre-trained model
qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")

# Function to generate an answer using the QA model
def generate_answer(query, context):
    qa_input = {
        "question": query,
        "context": context
    }
    # Get the result from the model
    result = qa_model(qa_input)
    return result['answer']

# Example query to search for the related completion
query = "What is the interaction between Bivalirudin and Acemetacin?"

# Retrieve the completion(s) corresponding to the prompt (this part assumes you've already fetched the completion)
completion_results = query_prompt_for_completion(query, collection)

# Combine the retrieved completions into context
context = " ".join(completion_results)  # Combine completions into the context for QA

# Generate the answer based on the context
answer = generate_answer(query, context)

# Print the generated answer
print("\nGenerated Answer:")
print(answer)


Result: ['What is the interaction between Bivalirudin and Acemetacin?']
Metadata: [{'id': 'prompt_1', 'type': 'prompt'}]





Generated Answer:
risk or severity of bleeding
