In [20]:
import os
import re

import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer




In [21]:
# Hugging Face access token used to download the checkpoint below. It is read
# from the HF_TOKEN environment variable so that no secret is pasted into the
# notebook; when the variable is unset this is None and `from_pretrained`
# falls back to the credentials cached by `huggingface-cli login`.
access_token = os.environ.get("HF_TOKEN")
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
# Use EOS as the padding token: the QA cell tokenizes with padding=True and
# passes tokenizer.pad_token_id to generate().
tokenizer.pad_token = tokenizer.eos_token
# Sentence-embedding model used for context retrieval (a lightweight one,
# 'all-MiniLM-L6-v2'). Loaded once here so it can be shared across calls.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [57]:
def context_retrieval(question, context, embedder=None):
    """Return the sentence of ``context`` most similar to ``question``.

    The context is split into sentences at full stops, the question and every
    sentence are embedded, and the sentence with the highest cosine similarity
    to the question is returned.

    Args:
        question: The (sub-)question to find supporting context for.
        context: Free text to search. It is split at every ``.``, so
            abbreviations and decimal numbers are split as well.
        embedder: Optional object exposing
            ``encode(texts, convert_to_tensor=True)``. Defaults to the
            notebook-level ``embedding_model`` so the embedding model is not
            reloaded on every call.

    Returns:
        The best-matching sentence with surrounding whitespace stripped, or an
        empty string when ``context`` contains no non-blank sentence.
    """
    if embedder is None:
        # Reuse the model loaded once in the setup cell instead of
        # instantiating a new SentenceTransformer per call.
        embedder = embedding_model

    # Capture sentences including their full stop.
    fragments = re.findall(r'[^.]+\.?', context)
    # Strip *before* filtering so whitespace-only fragments (e.g. trailing
    # spaces after the last full stop) never become empty candidates.
    sentences = [fragment.strip() for fragment in fragments]
    sentences = [sentence for sentence in sentences if sentence]
    if not sentences:
        # Nothing to rank; argmax over an empty similarity matrix would fail.
        return ""

    # Encode question and context sentences.
    question_embedding = embedder.encode(question, convert_to_tensor=True)
    sentence_embeddings = embedder.encode(sentences, convert_to_tensor=True)

    # Cosine similarity between the question and each sentence.
    similarities = util.pytorch_cos_sim(question_embedding, sentence_embeddings)

    # Index of the highest-scoring sentence.
    best_match_idx = torch.argmax(similarities).item()
    return sentences[best_match_idx]
    
    
def truncate_at_last_full_stop(text):
    """Cut ``text`` off right after its final full stop.

    Args:
        text: The string to truncate.

    Returns:
        Everything up to and including the last ``.`` in ``text``; ``text``
        unchanged when it contains no full stop.
    """
    # rpartition yields an empty separator when no "." is present.
    head, dot, _ = text.rpartition(".")
    if not dot:
        return text
    return head + dot

def question_splitter(question):
    """Break a compound question into sub-questions at connector words.

    Connectors (e.g. "and", "or", "as well as") are matched as whole words,
    case-insensitively.

    Args:
        question: The question text to split.

    Returns:
        A ``(parts, connectors)`` tuple. ``parts`` is the list of non-empty,
        whitespace-stripped fragments found between connectors; ``connectors``
        is the list of connector strings as they appear in ``question``, in
        order of appearance.
    """
    pattern = r'\b(?:and|or|also|as well as|besides|in addition to|moreover|furthermore|plus|along with|together with|what about|how about|while|whereas|likewise|similarly|on the other hand|not to mention|coupled with|additionally|then|next|after that|before that|simultaneously|just like|compared to|contrary to|beside that)\b'

    # Text between connectors; fragments that are blank after stripping are dropped.
    parts = []
    for fragment in re.split(pattern, question, flags=re.IGNORECASE):
        cleaned = fragment.strip()
        if cleaned:
            parts.append(cleaned)

    # Every connector occurrence, in the order it appears in the question.
    found = re.findall(pattern, question, flags=re.IGNORECASE)

    return parts, found

# Test Cases



In [60]:

# Source passage and the (possibly multi-part) question to answer from it.
whole_context = "The Eiffel Tower is located in Paris, France. Construction of eiffel tower was completed in 1889. Eiffel tower stands 330 meters tall."
# Alternative single-part questions to try:
# question = "Where is the Eiffel Tower located?"
# question = "how tall is Eiffel Tower?"
question = "where is eiffel tower located and when was Eiffel Tower constructed and how tall it is?"
answer_format = "Target:"
answer_list = []

# Answer each sub-question separately, prompting the model with only the
# context sentence most relevant to that sub-question.
question_list, connectors = question_splitter(question)
for q in question_list:
    relevant_context = context_retrieval(q, whole_context)
    input_text = f"Input: {relevant_context}\n{q}\n{answer_format}"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

    output = model.generate(
        **inputs,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=5,
    )
    answer = tokenizer.decode(output[0], skip_special_tokens=True)

    # Remove the prompt text from the decoded output (presumably the decoded
    # sequence echoes the prompt) and cut at the last complete sentence.
    final_answer = truncate_at_last_full_stop(answer.replace(input_text, ''))
    answer_list.append(final_answer)

# Stitch the partial answers back together with the connectors from the
# question. question_splitter returns at least len(question_list) - 1
# connectors, so connectors[idx] is always in range here; no exception
# swallowing is needed, and any real error surfaces instead of silently
# dropping an answer.
pieces = []
last_idx = len(answer_list) - 1
for idx, ans in enumerate(answer_list):
    pieces.append(ans)
    if idx < last_idx:
        pieces.append(f" {connectors[idx]} ")
actual_answer = "".join(pieces)

# Drop full stops so the joined answers read as one sentence.
actual_answer = actual_answer.replace('.', '')
print(actual_answer)

 The Eiffel Tower is located in Paris, France and  Construction of Eiffel Tower was completed in 1889 and  Eiffel tower stands 330 meters tall
