In [None]:
"""!pip install transformers sentence-transformers torch scikit-learn
!pip install tf-keras
"""

In [None]:
# Each policy document is stored as a single text chunk for retrieval.
remote_work_policy = """AcmeTech Remote Work Policy:
    AcmeTech allows full-time employees to work remotely up to three days per week.
    Employees must be available online between 10 AM and 3 PM EST for core collaboration hours.
    All remote work arrangements must be approved by a direct manager."""

security_policy = """AcmeTech Security Policy:
    Employees must use company-issued devices when accessing internal systems.
    Two-factor authentication is required for all internal applications.
    Sensitive data must not be stored on personal devices or cloud services."""

time_off_policy = """AcmeTech Time-Off Policy:
    Full-time employees receive 15 days of paid vacation per year.
    Vacation requests must be submitted at least two weeks in advance.
    Sick leave does not count against vacation days."""

# The list order is significant: retrieval looks entries up by their index here.
knowledge_base = [remote_work_policy, security_policy, time_off_policy]

In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Checkpoint identifier handed to SentenceTransformer for the retrieval encoder.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Embed the whole knowledge base once up front; retrieval then only has to
# embed the incoming query and compare it against these vectors.
kb_embeddings = embedding_model.encode(knowledge_base)

In [None]:
def retrieve_chunks(query, top_k=2):
    """Return the knowledge-base entries most similar to ``query``.

    The query is embedded with ``embedding_model`` and scored against the
    precomputed ``kb_embeddings`` using cosine similarity.

    Args:
        query: Question text to search the knowledge base with.
        top_k: Maximum number of entries to return. A value larger than the
            knowledge base returns every entry; zero or a negative value
            returns an empty list.

    Returns:
        A list of ``knowledge_base`` entries ordered from most to least
        similar to the query.
    """
    # Guard first: the slice ``[-0:]`` selects the *whole* array, so without
    # this check ``top_k=0`` would return every chunk instead of none (and a
    # negative ``top_k`` would produce an arbitrary slice).
    if top_k <= 0:
        return []

    query_embedding = embedding_model.encode([query])
    # cosine_similarity returns one row per query; there is exactly one query.
    similarities = cosine_similarity(query_embedding, kb_embeddings)[0]

    # argsort is ascending, so the last ``top_k`` positions hold the best
    # matches; reverse them so the most similar entry comes first.
    top_indices = similarities.argsort()[-top_k:][::-1]

    return [knowledge_base[i] for i in top_indices]

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Tokenizer and generator are loaded from the same checkpoint name.
GENERATOR_MODEL_NAME = "google/flan-t5-small"

tokenizer = T5Tokenizer.from_pretrained(GENERATOR_MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(GENERATOR_MODEL_NAME)

In [None]:
def rag_answer(query, top_k=2, max_new_tokens=150):
    """Answer ``query`` using retrieved knowledge-base context.

    The most similar chunks are fetched with ``retrieve_chunks``, joined into
    a context section of the prompt, and passed to the seq2seq ``model`` for
    generation.

    Args:
        query: The user's question.
        top_k: Number of knowledge-base chunks to retrieve as context.
        max_new_tokens: Upper bound on the number of tokens generated for
            the answer.

    Returns:
        A ``(answer, retrieved_chunks)`` tuple: the decoded model output and
        the list of chunks that were placed in the prompt.
    """
    retrieved_chunks = retrieve_chunks(query, top_k=top_k)

    context = "\n".join(retrieved_chunks)

    prompt = f"""
    Use the following context to answer the question.
    If the answer is not in the context, say you do not have enough information.

    Context:
    {context}

    Question:
    {query}
    """

    # NOTE(review): with truncation enabled an over-long prompt is presumably
    # cut from the end, which is where the question sits — confirm the
    # tokenizer's truncation side before raising ``top_k`` substantially.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns a batch; only one prompt was submitted.
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer, retrieved_chunks

In [None]:
# In-scope question: the remote-work policy states the allowed number of days.
question = "How many remote work days are allowed at AcmeTech?"
answer, chunks = rag_answer(question)

print("Retrieved Chunks:", chunks)
print("Answer:", answer)

In [None]:
# Out-of-scope question: none of the knowledge-base entries mention a CEO.
question = "Who is the CEO of AcmeTech?"
answer, chunks = rag_answer(question)

print("Retrieved Chunks:", chunks)
print("Answer:", answer)

In [None]:
# Question that touches both the remote-work and the security policy text.
question = "What rules apply if an employee works remotely and accesses internal systems?"
answer, chunks = rag_answer(question)

print("Retrieved Chunks:", chunks)
print("Answer:", answer)