In [1]:
import random
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import transformers
import pandas as pd
from pymongo import MongoClient
import numpy as np
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput
from torch import Tensor

from typing import Tuple, List, Union
import os
from dotenv import load_dotenv
from collections import OrderedDict

load_dotenv()


  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Reproducibility: pin the Python and PyTorch RNGs to one shared seed.
SEED = 2024
random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Helper Function
def enableMultiGPU(model: AutoModel, multi_gpu: bool):
    """Optionally wrap ``model`` in ``nn.DataParallel``.

    Args:
        model: the encoder to (possibly) wrap.
        multi_gpu: when false, ``model`` is returned untouched.

    Returns:
        ``nn.DataParallel(model)`` when ``multi_gpu`` is true, else ``model`` itself.
        When wrapping, TOKENIZERS_PARALLELISM is also set to "false" in the process
        environment (presumably to avoid tokenizer-parallelism warnings; confirm).
    """
    if not multi_gpu:
        return model

    wrapped = nn.DataParallel(model)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    return wrapped

# Load weights from a checkpoint produced during multi-GPU training: its keys may carry a
# 'module.' prefix, which is stripped so they match the bare (unwrapped) model's parameter names.
def load_model_from_gpu(model_path: str, model_name: str, device : torch.device = 'cpu'):
    """Build ``AutoModel.from_pretrained(model_name)`` initialised from ``model_path``.

    Args:
        model_path: path to a saved state dict (keys optionally prefixed with 'module.').
        model_name: Hugging Face model id whose architecture the weights belong to.
        device: map location for the checkpoint and target device for the model.

    Returns:
        The model, with the de-prefixed weights, moved to ``device``.
    """
    prefix = 'module.'
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted
    # source (consider weights_only=True on torch versions that support it).
    raw_state = torch.load(model_path, map_location = device)

    cleaned_state = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, tensor)
        for key, tensor in raw_state.items()
    )

    return AutoModel.from_pretrained(model_name, state_dict=cleaned_state).to(device)

def load_model(model_path: str, model_name: str, device: torch.device = 'cpu'):
    """Build ``model_name``'s architecture on ``device`` and load ``model_path`` into it.

    ``load_state_dict`` is called with its default (strict) matching, so the checkpoint keys
    must match the bare model exactly; for checkpoints whose keys carry a 'module.' prefix
    use ``load_model_from_gpu`` instead.
    """
    net = transformers.AutoModel.from_pretrained(model_name)
    net = net.to(device)
    weights = torch.load(model_path, map_location = device)
    net.load_state_dict(weights)
    return net

def load_tokenizer(tokenizer_name: str):
    """Return the Hugging Face tokenizer for ``tokenizer_name`` via ``AutoTokenizer``."""
    tokenizer_cls = transformers.AutoTokenizer
    return tokenizer_cls.from_pretrained(tokenizer_name)

In [4]:
# --- Model / checkpoint configuration ---
multi_gpu = False
BASE_MODEL = 'google/electra-small-discriminator'
# Answer encoder, question encoder and tokenizer all share the same base model.
a_name = q_name = t_name = BASE_MODEL
a_path = "model/a_encoder_model.bin"
q_path = "model/q_encoder_model.bin"

# Answer corpus (raw text; it is cleaned further down before batching).
answers = pd.read_csv('qa/answers.csv')

# The checkpoints come from multi-GPU training, so go through the prefix-stripping loader.
load_a = load_model_from_gpu(a_path, a_name, device)
load_q = load_model_from_gpu(q_path, q_name, device)

a_enc, q_enc = enableMultiGPU(load_a, multi_gpu), enableMultiGPU(load_q, multi_gpu)

tokenizer = load_tokenizer(t_name)

In [5]:
class InferencePipeline:
    """Question -> answer retrieval with separate question and answer encoders.

    * ``embed_passage`` embeds every answer from ``answer_loader`` with ``a_model`` and stores
      ``{'_id': i, 'embedding': [...]}`` documents through the embedding manager.
    * ``insertanswersdb`` stores the raw texts as ``{'_id': 'answer_{i}', 'ans_list': text}``
      through the answer manager.
    * ``inference`` embeds questions with ``q_model``, scores them against all stored
      embeddings by dot product, and returns the texts of the top-k answers.

    ``connection_db`` must be called before any database-backed method.
    """

    def __init__(self, q_model: AutoModel, a_model: AutoModel, tokenizer: AutoTokenizer, answer_loader: DataLoader, device: torch.device = "cpu"):
        self.q_model = q_model
        self.a_model = a_model
        self.device = device
        self.tokenizer = tokenizer
        self.answer_loader = answer_loader
        # Populated by connection_db(); None until then so misuse fails with a clear message.
        self.database_manager_embedding = None
        self.database_manager_answer = None

    def _require_connection(self):
        # Fail fast if connection_db() has not attached both database managers.
        if (getattr(self, "database_manager_embedding", None) is None
                or getattr(self, "database_manager_answer", None) is None):
            raise RuntimeError("No database managers attached; call connection_db() first.")

    # Embed the entire answer corpus, and post them to a MongoDB database.
    def embed_passage(self, max_length: int = 512):
        """Embed every answer in ``answer_loader`` and insert the vectors batch by batch.

        Args:
            max_length: token budget per answer (longer answers are truncated, shorter padded).

        The ``_id`` of each embedding is the answer's 0-based position in ``answer_loader``. It
        must line up with the ``answer_{i}`` ids written by ``insertanswersdb``, so the loader
        must iterate in a fixed order.
        """
        self._require_connection()
        self.a_model.eval()
        global_idx = 0
        with torch.no_grad():
            for answers in self.answer_loader:
                encoded_batch = self.tokenizer(
                    text = answers,
                    max_length = max_length,
                    truncation = True,
                    padding="max_length",
                    return_tensors = 'pt'
                )
                encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
                outputs = self.a_model(**encoded_batch)
                # First-token hidden state (presumably [CLS]) is the passage embedding.
                batch_embedding = outputs.last_hidden_state[:, 0, :]

                # One device->host copy per batch instead of one per row.
                documents = []
                for embedding in batch_embedding.cpu().tolist():
                    documents.append({'_id': global_idx, 'embedding': embedding})
                    global_idx += 1
                if documents:
                    self.database_manager_embedding.insert_embeddings_batch(documents)

    def insertanswersdb(self):
        """Store every raw answer text as ``{'_id': 'answer_{i}', 'ans_list': text}``.

        ``i`` is the answer's 0-based position in ``answer_loader``, matching the ``_id``
        numbering used by ``embed_passage``.
        """
        self._require_connection()
        index = 0
        for batch in self.answer_loader:
            documents = []
            for text in batch:
                documents.append({'_id': f"answer_{index}", 'ans_list': text})
                index += 1
            if documents:
                self.database_manager_answer.insert_embeddings_batch(documents)

    def embed_question(self, title: List[str], body: List[str], max_length: int = 512) -> Tensor:
        """Encode (title, body) pairs with ``q_model`` and return one embedding row per question.

        Args:
            title: question titles, one per question.
            body: question bodies, aligned with ``title``.
            max_length: token budget for the concatenated pair.

        Returns:
            Tensor of shape (num_questions, hidden_size) on ``self.device``.
        """
        self.q_model.eval()
        # The forward pass must also sit under no_grad, otherwise an autograd graph is built.
        with torch.no_grad():
            encoded_batch = self.tokenizer(
                text=title, text_pair=body,
                max_length=max_length, truncation=True,
                padding='max_length', return_tensors='pt'
            )
            encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
            outputs = self.q_model(**encoded_batch)
            batch_embedding = outputs.last_hidden_state[:, 0, :]
        return batch_embedding

    def inbatch_negative_sampling(self, Q: Tensor, P: Tensor) -> Tensor:
        """Return the dot-product score matrix ``S = Q @ P.T``.

        With Q of shape (num_questions, dim) and P of shape (num_passages, dim), S has shape
        (num_questions, num_passages).
        """
        S = (Q @ P.transpose(0, 1)).to(self.device)
        return S

    def get_topk_indices(self, Q: Tensor, P: Tensor, k: int=None) -> Tuple[Tensor, Tensor]:
        """Rank passages for each question.

        Args:
            Q: question embeddings, (num_questions, dim).
            P: passage embeddings, (num_passages, dim).
            k: passages to keep per question; None (or anything larger than the number of
               passages) keeps all of them.

        Returns:
            ``(indices, scores)``, both (num_questions, k), best first. Note the order is the
            reverse of ``torch.topk``'s ``(values, indices)``.
        """
        S = self.inbatch_negative_sampling(Q, P)
        # topk runs over the last (passage) dimension, so k is bounded by passages, not questions.
        num_passages = S.size(-1)
        if k is None or k > num_passages:
            k = num_passages
        scores, indices = torch.topk(S, k)
        return indices, scores

    def inference(self, title: List[str], body: List[str], k: int = 5) -> Union[List[str], List[List[str]]]:
        """Return the stored answer texts of the top-``k`` passages for each question.

        Args:
            title: question titles.
            body: question bodies, aligned with ``title``.
            k: answers to return per question (default 5).

        Returns:
            For a single question, a flat list of up to ``k`` answers (as looked up by the
            answer manager's ``find_element``). For several questions, one such list per
            question.
        """
        self._require_connection()
        Q = self.embed_question(title, body)
        # NOTE(review): the full embedding collection is reloaded on every call.
        P = self.database_manager_embedding.load_embeddings().to(self.device)
        idx, _scores = self.get_topk_indices(Q, P, k=k)
        results = [
            [self.database_manager_answer.find_element(ix) for ix in row]
            for row in idx.tolist()
        ]
        # Keep the original flat return shape for the single-question case.
        return results[0] if len(results) == 1 else results

    def connection_db(self, database_manager_embedding, database_manager_answer):
        """Attach the embedding-store and answer-store managers and open both connections."""
        self.database_manager_embedding = database_manager_embedding
        self.database_manager_answer = database_manager_answer

        self.database_manager_embedding.establish_connection()
        self.database_manager_answer.establish_connection()

    def disconnect_db(self):
        """Close both managers' connections, best-effort.

        A failure while closing one manager is printed and does not prevent closing the other.
        Calling this before ``connection_db`` is a no-op.
        """
        managers = (
            getattr(self, "database_manager_embedding", None),
            getattr(self, "database_manager_answer", None),
        )
        for manager in managers:
            if manager is None:
                continue
            try:
                manager.close_connection()
            except Exception as e:
                print(f"An error occurred: {e}")

In [6]:
# MongoDB targets read from the environment (.env). The first pair backs the embedding
# store and the second pair backs the raw-answer store (see the connection_db call below).
database, collection = os.getenv("DATABASE_NAME"), os.getenv("COLLECTION_NAME")
database_1, collection_1 = os.getenv("DATABASE_NAME_1"), os.getenv("COLLECTION_NAME_1")


Passages


In [12]:
class DatabaseManager:
    """Thin wrapper around a single MongoDB collection.

    Connection credentials come from the MONGO_USERNAME / MONGO_PASSWORD / MONGO_LINK
    environment variables; the database and collection names are given at construction.
    """

    def __init__(self, database: str, collection: str):

        self.database_name = database
        self.collection_name = collection
        self.client = None
        self.db = None
        self.collection = None

    def establish_connection(self):
        """Create the MongoClient and bind ``self.db`` and ``self.collection``.

        Raises:
            ValueError: if a required MONGO_* environment variable is unset.
            Exception: any error from creating the client is printed and re-raised.
        """
        from urllib.parse import quote_plus

        # Use environment variables for sensitive information
        user = os.getenv('MONGO_USERNAME')
        pw = os.getenv('MONGO_PASSWORD')
        link = os.getenv("MONGO_LINK")
        missing = [name for name, value in (("MONGO_USERNAME", user),
                                            ("MONGO_PASSWORD", pw),
                                            ("MONGO_LINK", link)) if value is None]
        if missing:
            raise ValueError(f"Missing MongoDB environment variables: {', '.join(missing)}")

        # Credentials must be percent-escaped, or characters like '@', ':' or '/' corrupt the URI.
        CONNECTION_STRING = f"mongodb+srv://{quote_plus(user)}:{quote_plus(pw)}@{link}"

        # Reconnecting must not leak the previous client.
        self.close_connection()
        try:
            self.client = MongoClient(CONNECTION_STRING)
            self.db = self.client[self.database_name]
            self.collection = self.db[self.collection_name]
        except Exception as e:
            # Log the error, drop any half-initialised client, and propagate.
            print(f"Error connecting to Database: {e}")
            self.close_connection()
            raise
        # NOTE(review): MongoClient may connect lazily, so auth/network problems might only
        # surface on the first query.

    def _require_collection(self):
        # Raise if establish_connection() has not bound a collection. Compare with None:
        # recent pymongo Collection objects raise NotImplementedError from bool().
        if self.collection is None:
            raise Exception("Database connection is not established.")

    def insert_embeddings_batch(self, embeddings: List[dict]):
        """Insert a batch of documents with ``insert_many(..., ordered=False)``.

        An empty batch is a no-op, because ``insert_many`` rejects empty lists.
        """
        self._require_collection()
        if not embeddings:
            return
        try:
            self.collection.insert_many(embeddings, ordered=False)
        except Exception as e:
            # Log the error
            print(f"Error inserting embeddings batch: {e}")
            raise

    def load_embeddings(self) -> Tensor:
        """Return every stored ``embedding`` as a float32 tensor, one row per document.

        Rows are ordered by ``_id``, so row i is the document with the i-th smallest id.
        Retrieval indices are mapped back to answers by position, and MongoDB does not
        guarantee insertion order without an explicit sort.
        """
        self._require_collection()
        cursor = self.collection.find({}, {'_id': 0, 'embedding': 1}).sort('_id', 1)
        return torch.tensor([doc['embedding'] for doc in cursor], dtype=torch.float32)

    def find_element(self, idx: int) -> List:
        """Return the ``ans_list`` field of document ``answer_{idx}``, or None if it is absent."""
        self._require_collection()
        ans = self.collection.find_one({'_id': f"answer_{idx}"}, {'_id': 0, 'ans_list': 1})
        if ans:
            return ans['ans_list']
        # No document with that id.
        return None

    def close_connection(self):
        """Close the client (if any) and clear the db/collection handles."""
        if self.client:
            self.client.close()
            self.client = None
            self.db = None
            self.collection = None

        

In [8]:
class AnswerDataset(Dataset):
    """Map-style ``Dataset`` that serves answer strings one index at a time.

    Wrapped in a ``DataLoader`` it yields batches of raw answer texts.
    """

    def __init__(self, answer: List[str]):
        # Keep a reference to the caller's list; items are returned unchanged.
        self.answer = answer

    def __len__(self) -> int:
        total = len(self.answer)
        return total

    def __getitem__(self, index) -> str:
        item = self.answer[index]
        return item

In [9]:
batch_size = 64
# Normalise answer text: NaN -> '', every character outside [a-zA-Z0-9.!,] -> space, then
# collapse whitespace runs. Raw-string patterns avoid the invalid '\s' escape warning.
answers['Answer'] = (
    answers['Answer']
    .fillna('')
    .str.replace(r'[^a-zA-Z0-9.!,]', ' ', regex=True)
    .replace(r'\s+', ' ', regex=True)
)
answer = answers['Answer'].tolist()

# shuffle=False keeps loader order equal to row order, which the embedding `_id` and
# `answer_{i}` numbering in the pipeline depend on.
answerDataset = AnswerDataset(answer)
answer_loader = DataLoader(answerDataset, batch_size = batch_size, shuffle = False)

In [14]:
# Wire the encoders into the pipeline and attach the two MongoDB stores:
# `database_manager` holds embeddings, `answer_database_manager` holds raw answer texts.
testpipeline = InferencePipeline(
    q_model=q_enc,
    a_model=a_enc,
    tokenizer=tokenizer,
    answer_loader=answer_loader,
    device=device,
)
database_manager = DatabaseManager(database=database, collection=collection)
answer_database_manager = DatabaseManager(database=database_1, collection=collection_1)

testpipeline.connection_db(database_manager, answer_database_manager)

In [15]:
# Example query: one title/body pair; returns the texts of the top-ranked stored answers.
title1 = "Making Creamy Mashed Potatoes"
body1 = "I want to make creamy mashed potatoes, but they always come out too lumpy or dry. What's the secret to getting them smooth and creamy?"
title, body = [title1], [body1]
testpipeline.inference(title, body)

['I would advise against it. Cooking in acid makes vegetables firm. Sometimes this is a good thing, but if you want to mash your potatoes afterwards, you want them soft. Else you ll get the wrong texture. There are numerous recipes for adding dairy products to the mashed potatoes after they have been cooked and mashed, and they taste well. You can look for them, or experiment yourself. It s a matter of what taste texture you like, there aren t any physics involved which you d throw off with a wrong ratio. ',
 'Consider instant mashed potatoes. Mashed potatoes from dried potato flakes are a lot better than most people given them credit for, and would probably be superior to real mashed potatoes made with poor tools in a hurry. More importantly for you, the process of cooking them scales up to any reasonable quantity you just add the correct ratio of flakes, butter, and milk https idahoan.com products idahoan original mashed potatoes 26oz on the stove and you can make up to 8 liters at a

In [16]:
# Release both MongoDB clients now that the demo queries are finished.
testpipeline.disconnect_db()