In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
from transformers import DistilBertTokenizer

# Initialize DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [4]:
import os

# Specify the directory containing your text files
data_directory = '../en'

def load_data(directory):
    documents = {}
    for filename in os.listdir(directory):
        if filename.endswith(".txt"):  
            with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
                doc_title = filename[:-4]  
                doc_text = file.read()
                documents[doc_title] = doc_text
    return documents

documents = load_data(data_directory)

# Converting to data format needed
input_ids = []
attention_masks = []

for doc_title, doc_text in documents.items():
    inputs = tokenizer.encode_plus(
        doc_text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='tf'
    )
    input_ids.append(inputs['input_ids'])
    attention_masks.append(inputs['attention_mask'])

input_ids = tf.concat(input_ids, axis=0)
attention_masks = tf.concat(attention_masks, axis=0)

2024-06-20 16:14:57.501206: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [1,736101] vs. shape[1] = [1,332004]


InvalidArgumentError: {{function_node __wrapped__ConcatV2_N_18_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [1,736101] vs. shape[1] = [1,332004] [Op:ConcatV2] name: concat

In [None]:
from transformers import TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

In [None]:
def retrieve(query):
    query_input = tokenizer.encode_plus(
        query,
        add_special_tokens=True,
        max_length=128,
        pad_to_max_length=True,
        return_attention_mask=True,
        return_tensors='tf'
    )
    query_output = model(
        input_ids=query_input['input_ids'],
        attention_mask=query_input['attention_mask']
    )[0]  # Index to get the output logits
    
    # Compute cosine similarity between query and documents
    doc_outputs = model(input_ids=input_ids, attention_mask=attention_masks)[0]
    similarity_scores = tf.keras.losses.cosine_similarity(query_output, doc_outputs, axis=1)
    
    # Get the index of the most relevant document
    retrieved_doc_index = tf.argmax(similarity_scores).numpy() # Convert tensor to integer
    
    # Convert the keys of documents to a list, then index with retrieved_doc_index
    doc_keys = list(documents.keys())
    return documents[doc_keys[retrieved_doc_index]]

In [None]:
from transformers import TFGPT2LMHeadModel

gen_model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")

In [None]:
def rag(query):
    retrieved_doc = retrieve(query)
    
    # Combine query and retrieved document for generation input
    generation_input = query + ' ' + retrieved_doc
    
    # Encode input for generation
    gen_input = tokenizer.encode_plus(
        generation_input,
        add_special_tokens=True,
        max_length=512,  # Adjusted to match the actual input length
        pad_to_max_length=True,
        return_attention_mask=True,
        return_tensors='tf'
    )
    
    # Generate a response
    generated_response = gen_model.generate(
        input_ids=gen_input['input_ids'],
        attention_mask=gen_input['attention_mask'],
        max_length=562,  # Adjusted to allow room for new tokens to be generated
        num_beams=5,
        temperature=0.7,
        top_k=50,
        do_sample=True  # Added to allow sampling given temperature is set
    )
    
    return tokenizer.decode(generated_response[0], skip_special_tokens=True)

In [None]:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

# Define your query
query = "What is the mechanism of action of Lisinopril?"

# Call the rag function with your query
generated_response = rag(query)

# Output the generated response
print(generated_response)