In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline

In [3]:
# Hugging Face checkpoint name; the QA model and tokenizer are both loaded from it.
model_name = 'distilbert/distilbert-base-uncased-distilled-squad'
# A stronger QA checkpoint can be substituted here for better answers; another
# option is a hosted LLM such as ChatGPT, which requires an API key.

In [4]:
# Load the tokenizer and the question-answering model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name,
    output_hidden_states=True,  # get_embeddings reads outputs.hidden_states
)

In [5]:
import pandas as pd
# Load the raw dataset; later cells read its 'query' and 'response' columns.
df = pd.read_csv('main_txt.csv')

In [6]:
# 'Unnamed: 0' is presumably an index column left over from an earlier CSV
# export — TODO confirm. Rebinding `df` (instead of inplace=True) together with
# errors='ignore' keeps this cell safe to re-run: a second execution no longer
# raises KeyError because the column is already gone.
df = df.drop(columns=['Unnamed: 0'], errors='ignore')

In [7]:
df_sampled = df.sample(n=10000, random_state=5)
# The dataset contains nearly 27,000 queries and responses. A fixed-seed random
# subset is used to keep the embedding step affordable; use more rows for
# better results.
# NOTE(review): this note originally said 2000 rows were sampled, but n is
# 10000 here — confirm the intended sample size.

In [8]:
import re

def preprocess_text(text):
    """Normalise a raw query/response string for embedding.

    Steps, in order: strip curly braces, remove URLs (``http...``/``www...``),
    remove e-mail-like tokens, drop every character that is not an ASCII
    letter or whitespace, lowercase, repair the ``oorder`` typo, and collapse
    runs of whitespace into single spaces.

    Args:
        text: The raw text. Non-string values (e.g. ``None`` or the float NaN
            that pandas uses for missing cells) are treated as missing.

    Returns:
        str: The cleaned text, or ``''`` when ``text`` is not a string.
    """
    # Missing cells arrive as None/NaN when this is applied to a pandas column;
    # re.sub would raise TypeError on them, so map them to an empty string.
    if not isinstance(text, str):
        return ''

    text = re.sub(r'[{}]', '', text)
    # `http\S+` already covers https URLs, so no separate alternative is needed.
    text = re.sub(r'http\S+|www\S+', '', text)
    text = re.sub(r'\S+@\S+', '', text)
    text = re.sub(r'[^A-Za-z\s]', '', text)
    text = text.lower()
    text = text.replace('oorder', 'order')
    return re.sub(r'\s+', ' ', text).strip()

In [9]:
# Clean both text columns with the same normaliser before embedding them.
for text_column in ('query', 'response'):
    df_sampled[text_column] = df_sampled[text_column].apply(preprocess_text)

In [12]:
def get_embeddings(texts):
    """Return one embedding vector per input text as a NumPy array.

    The texts are tokenised as one padded, truncated batch with the
    notebook-level ``tokenizer`` and passed through the notebook-level
    ``model`` without gradient tracking. Each text's embedding is the
    last entry of ``hidden_states`` taken at sequence position 0.
    """
    encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)

    with torch.no_grad():
        model_output = model(**encoded)

    # hidden_states[-1] is the final layer; keep only position 0 of every sequence.
    final_layer = model_output.hidden_states[-1]
    first_position = final_layer[:, 0, :]
    return first_position.numpy()

In [14]:
import numpy as np
import os 
def process_in_batches(texts, batch_size=32, save_path='embeddings', embed_fn=None):
    """Embed ``texts`` batch by batch, checkpointing every batch to disk.

    Each batch is written to ``save_path`` as ``embeddings_batch_<n>.npy`` and
    ``checkpoint.npy`` records the offset of the next text to embed, so an
    interrupted run continues where it stopped instead of starting over.

    Args:
        texts: Sequence of strings to embed (must support ``len`` and slicing).
        batch_size: Number of texts per batch. It must match the batch size of
            any run being resumed, because batch files are numbered by
            ``offset // batch_size``.
        save_path: Directory for the batch files and the checkpoint; created
            when missing.
        embed_fn: Callable mapping a list of texts to a 2-D array with one row
            per text. Defaults to the notebook's ``get_embeddings``.

    Returns:
        numpy.ndarray: Embeddings for all of ``texts``, in input order,
        including batches restored from a previous run.

    Raises:
        ValueError: If ``batch_size`` is not positive, ``texts`` is empty, or
            the files in ``save_path`` do not match the checkpoint.
        FileNotFoundError: If a batch file the checkpoint relies on is missing.
    """
    total = len(texts)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if total == 0:
        raise ValueError("texts is empty; there is nothing to embed")
    if embed_fn is None:
        embed_fn = get_embeddings

    os.makedirs(save_path, exist_ok=True)
    checkpoint_file = os.path.join(save_path, 'checkpoint.npy')

    def batch_file(batch_number):
        return os.path.join(save_path, f'embeddings_batch_{batch_number}.npy')

    start = 0
    if os.path.exists(checkpoint_file):
        # The checkpoint is stored as a 1-element array; range() needs a plain int.
        start = int(np.load(checkpoint_file).reshape(-1)[0])
        print(f"Resuming from batch {start // batch_size}")

    # Restore already-finished batches by number. Sorting file names would be
    # lexicographic and place batch_10 before batch_2.
    total_batches = -(-total // batch_size)
    finished_batches = min(start // batch_size, total_batches)
    embeddings = [np.load(batch_file(n)) for n in range(finished_batches)]

    # Guard against files left behind by a run over different texts or with a
    # different batch size, which would silently misalign rows and texts.
    if sum(len(batch) for batch in embeddings) != min(start, total):
        raise ValueError(
            f"Saved batches in {save_path!r} do not match the checkpoint; "
            "clear the directory or use the original texts and batch_size"
        )

    for i in range(start, total, batch_size):
        batch_embeddings = embed_fn(texts[i:i + batch_size])
        embeddings.append(batch_embeddings)

        np.save(batch_file(i // batch_size), batch_embeddings)
        np.save(checkpoint_file, np.array([i + batch_size]))

    return np.concatenate(embeddings, axis=0)

In [15]:
def process_in_batches_res(texts, batch_size=32, save_path='embeddings_res', embed_fn=None):
    """Embed response ``texts`` batch by batch, checkpointing every batch to disk.

    Each batch is written to ``save_path`` as ``embeddings_batch_<n>.npy`` and
    ``checkpoint_res.npy`` records the offset of the next text to embed, so an
    interrupted run continues where it stopped instead of starting over.

    Args:
        texts: Sequence of strings to embed (must support ``len`` and slicing).
        batch_size: Number of texts per batch. It must match the batch size of
            any run being resumed, because batch files are numbered by
            ``offset // batch_size``.
        save_path: Directory for the batch files and the checkpoint; created
            when missing.
        embed_fn: Callable mapping a list of texts to a 2-D array with one row
            per text. Defaults to the notebook's ``get_embeddings``.

    Returns:
        numpy.ndarray: Embeddings for all of ``texts``, in input order,
        including batches restored from a previous run.

    Raises:
        ValueError: If ``batch_size`` is not positive, ``texts`` is empty, or
            the files in ``save_path`` do not match the checkpoint.
        FileNotFoundError: If a batch file the checkpoint relies on is missing.
    """
    total = len(texts)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if total == 0:
        raise ValueError("texts is empty; there is nothing to embed")
    if embed_fn is None:
        embed_fn = get_embeddings

    os.makedirs(save_path, exist_ok=True)
    checkpoint_file = os.path.join(save_path, 'checkpoint_res.npy')

    def batch_file(batch_number):
        return os.path.join(save_path, f'embeddings_batch_{batch_number}.npy')

    start = 0
    if os.path.exists(checkpoint_file):
        # The checkpoint is stored as a 1-element array; range() needs a plain int.
        start = int(np.load(checkpoint_file).reshape(-1)[0])
        print(f"Resuming from batch {start // batch_size}")

    # Restore already-finished batches by number. Sorting file names would be
    # lexicographic and place batch_10 before batch_2.
    total_batches = -(-total // batch_size)
    finished_batches = min(start // batch_size, total_batches)
    embeddings = [np.load(batch_file(n)) for n in range(finished_batches)]

    # Guard against files left behind by a run over different texts or with a
    # different batch size, which would silently misalign rows and texts.
    if sum(len(batch) for batch in embeddings) != min(start, total):
        raise ValueError(
            f"Saved batches in {save_path!r} do not match the checkpoint; "
            "clear the directory or use the original texts and batch_size"
        )

    for i in range(start, total, batch_size):
        batch_embeddings = embed_fn(texts[i:i + batch_size])
        embeddings.append(batch_embeddings)

        np.save(batch_file(i // batch_size), batch_embeddings)
        np.save(checkpoint_file, np.array([i + batch_size]))

    return np.concatenate(embeddings, axis=0)

**The embedding files in the repo were generated from a 2,000-row sample of the dataset.**

In [18]:
# Embed every cleaned query, 32 texts per batch; batches are checkpointed on disk.
query_embeddings = process_in_batches(
    df_sampled['query'].tolist(),
    batch_size=32,
)

In [19]:
# Embed every cleaned response, 32 texts per batch; batches are checkpointed on disk.
response_embeddings = process_in_batches_res(
    df_sampled['response'].tolist(),
    batch_size=32,
)

In [39]:
import faiss

# Aliases for the embedding matrices; the response-index cell reads
# `responses_embeddings_np` and `dimension` defined here.
queries_embeddings_np, responses_embeddings_np = query_embeddings, response_embeddings

# Flat L2 index over the query embeddings, sized by the embedding width.
dimension = queries_embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(queries_embeddings_np)

# Persist the index to disk as "queries.index".
faiss.write_index(index, "queries.index")


In [40]:
# Flat L2 index over the response embeddings; it reuses `dimension`, which was
# computed from the query embeddings when the query index was built.
response_index = faiss.IndexFlatL2(dimension)

response_index.add(responses_embeddings_np)

# Persist the index to disk as "responses.index".
faiss.write_index(response_index, "responses.index")


In [41]:
def retrieve_responses(query_embedding, response_index, k=10):
    """Search ``response_index`` for the ``k`` nearest entries to ``query_embedding``.

    ``response_index.search`` returns a pair; only its second element (the
    indices of the matches) is returned and the first one is discarded.
    """
    _distances, neighbour_ids = response_index.search(query_embedding, k)
    return neighbour_ids


In [43]:
def predict_answer(query, model, tokenizer, response_index, responses, k=10, verbose=False):
    """Answer ``query`` by extractive QA over the ``k`` nearest stored responses.

    The query is embedded with ``get_embeddings`` (which uses the notebook-level
    ``model``/``tokenizer``, not the arguments passed here), its nearest
    neighbours are looked up in ``response_index``, their texts are joined into
    one context, and ``model`` selects an answer span from that context.

    Args:
        query: The question text.
        model: Question-answering model whose output exposes ``start_logits``
            and ``end_logits``.
        tokenizer: Tokenizer matching ``model``.
        response_index: Index searched through ``retrieve_responses``.
        responses: Response texts, positionally aligned with the vectors stored
            in ``response_index``.
        k: Number of neighbours to retrieve.
        verbose: When true, print the model input tokens and the raw start/end
            logits for debugging.

    Returns:
        tuple: ``(answer, closest_responses_indices)`` where ``answer`` is the
        decoded span and ``closest_responses_indices`` is the unmodified result
        of ``retrieve_responses``.
    """
    query_embedding = get_embeddings([query])[0].reshape(1, -1)
    closest_responses_indices = retrieve_responses(query_embedding, response_index, k)

    # FAISS reports missing neighbours as -1 (e.g. when k exceeds the number of
    # indexed vectors); as a Python index that would silently select the last
    # response, so those entries are skipped.
    retrieved_contexts = [responses[i] for i in closest_responses_indices[0] if i >= 0]
    combined_context = " ".join(retrieved_contexts)

    inputs = tokenizer(query, combined_context, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"][0]

    if verbose:
        print("Inputs to the model:", tokenizer.convert_ids_to_tokens(input_ids))

    with torch.no_grad():
        outputs = model(**inputs)

    # A single sequence is scored, so take row 0 of the (1, seq_len) logits.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    answer_start = int(torch.argmax(start_logits))
    # Pick the best end position at or after the start. Choosing it
    # independently can place the end before the start and yield an empty span.
    answer_end = answer_start + int(torch.argmax(end_logits[answer_start:])) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )

    if verbose:
        print("Start logits:", outputs.start_logits)
        print("End logits:", outputs.end_logits)

    return answer, closest_responses_indices

# Try the pipeline on one question, using the cleaned responses as the
# candidate contexts.
responses = df_sampled['response'].tolist()
question = "How to cancel order?"

answer, retrieved_responses_indices = predict_answer(
    question, model, tokenizer, response_index, responses
)

print(f"Question: {question}\nAnswer: {answer}")
print(f"Retrieved Responses Indices: {retrieved_responses_indices}")