# ENCODER

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import AutoModel, AutoTokenizer



In [None]:
# Tokenizer and encoder weights are both loaded from the same checkpoint.
_CHECKPOINT = 'ai4bharat/indic-bert'
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT)

In [None]:
import torch
import torch.nn.functional as F

# Mean pooling: average the token embeddings, weighted by the attention mask
def mean_pooling(model_output, attention_mask):
    """Return the attention-masked mean of ``model_output.last_hidden_state``.

    Args:
        model_output: object exposing a ``last_hidden_state`` tensor
            (assumed to be (batch, tokens, hidden) -- TODO confirm with caller).
        attention_mask: tensor with one entry per token; it is broadcast over
            the last dimension of the hidden states before weighting.

    Returns:
        Tensor obtained by summing the weighted embeddings over dimension 1
        and dividing by the mask total, which is clamped to at least 1e-9 so a
        fully masked row does not divide by zero.
    """
    hidden = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_sum = torch.sum(hidden * weights, 1)
    weight_total = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / weight_total
# Encode text into unit-length embedding vectors
def encode(texts):
    """Tokenize ``texts``, run the encoder and return normalised embeddings.

    The tokenizer is called with ``max_length=512``, a stride of 128 and
    ``return_overflowing_tokens=True``, so an input longer than the limit can
    produce more than one row in the output.

    Args:
        texts: value passed straight to the module-level ``tokenizer``.

    Returns:
        Nested Python list (``Tensor.tolist()``) with one L2-normalised,
        mean-pooled vector per tokenizer output row.
    """
    window_overlap = 128
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=512,
        stride=window_overlap,
        return_overflowing_tokens=True,
    )
    # Removed before the forward pass: the batch is unpacked into model(**...)
    batch.pop("overflow_to_sample_mapping")

    # Inference only, so gradient tracking is disabled
    with torch.no_grad():
        outputs = model(**batch, return_dict=True)

    pooled = mean_pooling(outputs, batch['attention_mask'])
    unit_vectors = F.normalize(pooled, p=2, dim=1)
    return unit_vectors.tolist()

# MYSQL

In [None]:
import os

import pymysql

# Connection settings can be overridden through ODQA_MYSQL_* environment
# variables so credentials do not have to be edited into the source. The
# fallbacks are the original development values.
# NOTE(review): the fallback password is still committed here; rotate it and
# drop the default once ODQA_MYSQL_PASSWORD is set wherever this runs.
conn = pymysql.connect(
    host=os.environ.get('ODQA_MYSQL_HOST', 'localhost'),
    user=os.environ.get('ODQA_MYSQL_USER', 'aswin'),
    port=int(os.environ.get('ODQA_MYSQL_PORT', '3306')),
    password=os.environ.get('ODQA_MYSQL_PASSWORD', 'Mysql@123'),
    database=os.environ.get('ODQA_MYSQL_DATABASE', 'ODQA'),
    local_infile=True,
)
# Single shared cursor used by every helper in this section.
cursor = conn.cursor()
print("odqa")
TABLE_NAME = 'QA_DATASET'

In [None]:
def create_context_table():
    """Drop ``TABLE_NAME`` if it exists and recreate it empty.

    The table has an auto-increment ``id`` primary key plus ``question``,
    ``context``, ``answer`` and ``answer_start`` columns. The text columns use
    the ``utf8mb4_bin`` collation so they match the table's utf8mb4 default;
    a column-level ``COLLATE utf8_bin`` would instead select the 3-byte utf8
    character set for those columns.

    Errors raised by the DROP statement propagate to the caller. Errors raised
    by the CREATE statement are printed and swallowed.
    """
    # Delete any previously stored table for a clean run.
    drop_table = "DROP TABLE IF EXISTS " + TABLE_NAME + ";"
    cursor.execute(drop_table)
    try:
        sql = f"""
                CREATE TABLE if not exists {TABLE_NAME} (
                    id int(10) NOT NULL AUTO_INCREMENT,
                    question TEXT COLLATE utf8mb4_bin NOT NULL,
                    context MEDIUMTEXT COLLATE utf8mb4_bin NOT NULL,
                    answer  TEXT COLLATE utf8mb4_bin NOT NULL,
                    answer_start int(5) NOT NULL,
                    PRIMARY KEY (id)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
                AUTO_INCREMENT=1 ;"""
        cursor.execute(sql)
        print(f"{TABLE_NAME} table created successfully!")
    except Exception as e:
        print("can't create a MySQL table: ", e)



In [None]:
def execute_query(query, params=None):
    """Run ``query`` on the shared cursor and return all fetched rows.

    Args:
        query: SQL text to execute.
        params: optional values bound to ``%s`` placeholders in ``query``;
            prefer this over formatting values into the SQL string.

    Returns:
        The result of ``cursor.fetchall()``, or ``None`` when any exception is
        raised (the error is printed, not re-raised), so callers must be
        prepared for ``None``.
    """
    try:
        cursor.execute(query, params)
        return cursor.fetchall()
    except Exception as e:
        print("can't execute MySQL query: ", e)
        return None



In [None]:
def insert_data(dataset):
    """Insert question/answer records into ``QA_DATASET`` and commit.

    Args:
        dataset: iterable of mappings, each providing the keys ``"question"``,
            ``"context"``, ``"answer"`` and ``"answer_start"``.

    Raises:
        KeyError: if a record lacks one of the required keys (raised before
            anything is sent to the database).
        Exception: any database error is re-raised after the transaction has
            been rolled back, so a failed call leaves no pending partial rows.
    """
    sql = "INSERT INTO QA_DATASET (question, context, answer, answer_start) VALUES (%s, %s, %s, %s)"
    rows = [
        (data["question"], data["context"], data["answer"], data["answer_start"])
        for data in dataset
    ]
    try:
        if rows:
            cursor.executemany(sql, rows)
        conn.commit()
    except Exception:
        # Do not leave half a batch waiting to be committed by a later call.
        conn.rollback()
        raise

def extract_context(id):
    """Return the first result row for ``select context ... where id = <id>``.

    Args:
        id: row identifier; coerced with ``int()`` before it is placed in the
            SQL text so arbitrary strings cannot alter the statement.

    Returns:
        The first row returned by ``execute_query`` (a one-column row holding
        the context, not the bare string).

    Raises:
        ValueError / TypeError: if ``id`` cannot be converted to an integer.
        IndexError: if no row matches.
        TypeError: if ``execute_query`` returned ``None`` because the query failed.
    """
    row_id = int(id)
    q = f"select context from QA_DATASET where id = {row_id}"
    res = execute_query(q)
    return res[0]

# MILVUS

In [2]:
from pymilvus import (
    connections,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    utility,
)

# Connect to Milvus with pymilvus' default settings (no arguments are passed).
connections.connect()

# Project helper modules for the MySQL store and the sentence encoder.
import odqa_mysql
import odqa_encoder

odqa


Some weights of the model checkpoint at ai4bharat/indic-bert were not used when initializing AlbertModel: ['sop_classifier.classifier.bias', 'predictions.dense.weight', 'predictions.decoder.bias', 'predictions.LayerNorm.weight', 'predictions.decoder.weight', 'sop_classifier.classifier.weight', 'predictions.LayerNorm.bias', 'predictions.dense.bias', 'predictions.bias']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
from tqdm.autonotebook import tqdm

In [3]:
import json

# Name of the Milvus collection used by the helpers in this section.
TABLE_NAME = 'question_answering'
# Module-level collection handle; stays None unless the collection already
# exists when the has_collection() check further down runs.
collection = None



In [4]:
def create_mqa(dim=768):
    """Drop and recreate the Milvus collection ``TABLE_NAME`` with its index.

    The new collection has an INT64 primary key ``ind``, an INT64 ``id`` field
    and a FLOAT_VECTOR ``embedding`` field, and an IVF_FLAT inner-product
    index is created on ``embedding``.

    The module-level ``collection`` handle is rebound to the new collection so
    the search/insert helpers work right after this call, including on a
    fresh Milvus instance where no collection existed beforehand.

    Args:
        dim: dimensionality of the ``embedding`` vectors (default 768).
    """
    global collection

    # Delete any previously stored collection for a clean run.
    if utility.has_collection(TABLE_NAME):
        Collection(name=TABLE_NAME).drop()

    field1 = FieldSchema(name="ind", dtype=DataType.INT64, description="int64", is_primary=True)
    field2 = FieldSchema(name="id", dtype=DataType.INT64, description="int64", is_primary=False)
    field3 = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, description="float vector", dim=dim, is_primary=False)
    schema = CollectionSchema(fields=[field1, field2, field3], description="collection description")
    collection = Collection(name=TABLE_NAME, schema=schema)

    default_index = {"index_type": "IVF_FLAT", "metric_type": 'IP', "params": {"nlist": 200}}
    collection.create_index(field_name="embedding", index_params=default_index)

# Bind the module-level handle when the collection already exists in Milvus;
# otherwise `collection` keeps its previous value.
if utility.has_collection(TABLE_NAME):
    collection = Collection(name=TABLE_NAME)

In [8]:
# Search settings shared by the lookups below; the metric matches the "IP"
# metric the embedding index is created with.
search_params = {"metric_type": "IP", "params": {"nprobe": 10}}

def find_similar(emb, limit=10):
    """Search the ``embedding`` field of the collection for vectors close to ``emb``.

    Args:
        emb: query vectors, passed unchanged as ``data`` to
            ``collection.search`` (presumably a list of float lists such as
            the encoder returns -- confirm with caller).
        limit: maximum number of hits requested per query vector (default 10).

    Returns:
        Whatever ``collection.search`` returns for the request.
    """
    # The collection must be loaded before it can be searched.
    collection.load()
    return collection.search(
        data=emb,
        anns_field="embedding",
        param=search_params,
        limit=limit,
        expr=None,
        consistency_level="Strong",
    )



def push_context_to_milvus():
    """Embed the next batch of MySQL contexts and insert them into Milvus.

    Progress is tracked in ``database_handler.json``:
      * ``milvus_rows``  -- id of the last MySQL row already processed
        (assumes ids are assigned from 1 upwards -- TODO confirm),
      * ``batch``        -- how many ids to advance per call,
      * ``milvus_index`` -- next free value for the Milvus primary key ``ind``.

    Each context may be encoded into several vectors; every vector gets its
    own ``ind`` and carries the MySQL ``id`` of the row it came from.

    Returns:
        A two-line string reporting the QA_DATASET row count and the number of
        entities in the Milvus collection.

    Raises:
        RuntimeError: if the batch query failed (``execute_query`` returned
            ``None``); the progress file is left untouched in that case.
    """
    print("\n\n")
    db_fp = r"database_handler.json"
    with open(db_fp) as file:
        database_handler = json.load(file)

    start = database_handler['milvus_rows']
    end = start + database_handler["batch"]
    # Continue primary keys from the persisted counter. A context can yield
    # several vectors, so restarting from `milvus_rows` would reuse `ind`
    # values that an earlier batch already inserted.
    index = database_handler.get('milvus_index', start)

    # Half-open window (start, end]: an inclusive BETWEEN would process row
    # `end` again at the start of the next batch.
    # NOTE(review): rows are read from `context` while the size report below
    # counts QA_DATASET -- confirm which table is intended.
    query = f"select * from context where id > {start} and id <= {end} ;"
    res = odqa_mysql.execute_query(query)
    if res is None:
        raise RuntimeError("could not read the context batch from MySQL")

    for id, context in tqdm(res):
        emb = odqa_encoder.encode(context)
        indexs = list(range(index, index + len(emb)))
        ids = [id] * len(emb)
        index += len(emb)
        collection.insert([indexs, ids, emb])

    database_handler['milvus_rows'] = end
    database_handler['milvus_index'] = index

    with open(db_fp, "w") as file:
        json.dump(database_handler, file)

    mysql_size = odqa_mysql.execute_query("select count(*) from QA_DATASET")[0][0]
    return f"mysql : {mysql_size}\nmilvus : {collection.num_entities}"

In [6]:
# Rebuild the Milvus collection from scratch (an existing one is dropped first).
create_mqa()

In [9]:
# Push the next batch of contexts to Milvus; the returned size summary is shown as the cell output.
push_context_to_milvus()






'mysql : 368\nmilvus : 229'

In [27]:
def query(expr="id == 6", limit=10):
    """Fetch entities from the collection that match a boolean expression.

    Args:
        expr: Milvus filter expression (default ``"id == 6"``, the sample
            lookup this helper was written for).
        limit: maximum number of entities requested (default 10).

    Returns:
        Whatever ``collection.query`` returns for the request.
    """
    collection.load()
    # query() is a scalar filter, so the vector-search arguments
    # (anns_field / param) do not apply and are not passed.
    return collection.query(
        expr=expr,
        limit=limit,
        consistency_level="Strong",
    )

In [28]:
# Run the sample lookup; the matching entities are shown as the cell output.
query()

[{'ind': 36},
 {'ind': 37},
 {'ind': 38},
 {'ind': 41},
 {'ind': 33},
 {'ind': 34},
 {'ind': 35},
 {'ind': 39},
 {'ind': 40}]

# EXTRACTOR

In [None]:
from transformers import  pipeline

# Checkpoint used for extractive question answering.
# Alternative checkpoint: "deepset/xlm-roberta-large-squad2"
model_name = "AswiN037/xlm-roberta-squad-tamil"

# The same checkpoint supplies both the model and its tokenizer.
answer_extract = pipeline(
    task='question-answering',
    model=model_name,
    tokenizer=model_name,
)
