In [1]:
import json
import os
import uuid
from transformers import AutoTokenizer,AutoModel
import torch
import pickle
from sentence_transformers import SentenceTransformer
from transformers.pipelines import pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from bertopic.vectorizers import ClassTfidfTransformer
import nltk
import numpy as np
import re

In [None]:
# Hugging Face model used both to size the chunks (tokenizer) and to embed them.
model_name = 'sentence-transformers/all-mpnet-base-v2' #"BAAI/bge-small-en-v1.5"
# Directory holding one JSON file per report.
datout = 'refertijson'

# Report fields, in the order in which they are concatenated into the report text.
# Every key is listed exactly once: a repeated key would append the same field
# to the report twice and therefore produce duplicated chunks and embeddings.
jsonkeys = ['AAID',
            'indicazioni al test',
            'risultato',
            'inviante',
            'materiale', 'materiale inviato', 'materiale ricevuto',
            'test', 'test eseguito', 'test richiesto',
            'dettagli',
            'interpretazione',
            'suggerimenti',
            'metodo',
            'limiti']


tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="cuda" places the model on the GPU, so a CUDA device is required.
model = AutoModel.from_pretrained(model_name, device_map="cuda")

# os.listdir returns entries in arbitrary order; sorting makes runs reproducible.
fnames = sorted(os.listdir(datout))
jdata = []
reports = []

for fname in fnames:
    # Explicit UTF-8 so that decoding does not depend on the platform locale.
    with open(os.path.join(datout, fname), 'r', encoding='utf-8') as f:
        jf = json.load(f)
        jdata.append(jf)

for j in jdata:
    # Each present field becomes a "<key>: <value>" paragraph terminated by the
    # '\n#####\n' marker, which document_chunker later uses as paragraph separator.
    report = ''
    for jk in jsonkeys:
        if jk in j:
            report += jk + ': ' + j[jk] + '\n#####\n'

    # Final synthetic paragraph: coarse outcome label derived from 'risultato'.
    if 'risultato' not in j:
        report += 'rescode: esito sconosciuto'
    elif j['risultato'].startswith('non'):
        report += 'rescode: caso negativo'
    else:
        report += 'rescode: caso positivo'

    reports.append(report)


print('Dataset length:', len(jdata))


Chunk the report text to optimize it for retrieval.

In [3]:
def _split_after_matches(text, pattern):
    """Cut `text` right after every match of `pattern`, keeping all characters."""
    pieces = []
    start = 0
    for match in re.finditer(pattern, text):
        # Skip zero-width / non-advancing matches so no empty piece is produced.
        if match.end() > start:
            pieces.append(text[start:match.end()])
            start = match.end()
    if start < len(text):
        pieces.append(text[start:])
    return pieces


def document_chunker(reports,
                     model_name,
                     paragraph_separator='\n#####\n',
                     chunk_size=1024,
                     separator=' ',
                     secondary_chunking_regex=r'\S+?[\.,;!?]',
                     chunk_overlap=0,
                     tokenizer=None):
    """Split every report into token-bounded chunks, grouped by report id.

    Each report is expected to start with an ``AAID: <name>.txt`` paragraph: the
    id is the text between character 6 and the first ``.txt``. The first
    paragraph is then skipped and every remaining paragraph is chunked on its own.

    Parameters
    ----------
    reports : iterable of str
        Report texts whose paragraphs are delimited by `paragraph_separator`.
    model_name : str
        Name passed to ``AutoTokenizer.from_pretrained`` when `tokenizer` is None.
    paragraph_separator : str
        Passed to ``re.split``, i.e. interpreted as a regular expression.
    chunk_size : int
        Maximum number of tokens (as counted by ``tokenizer.tokenize``) per chunk.
    separator : str
        String on which a paragraph is split into words.
    secondary_chunking_regex : str
        Pattern used to cut a chunk that is still longer than `chunk_size` after
        word-level chunking; the chunk is cut right after each match.
    chunk_overlap : int
        Measured in characters. When > 0, an extra chunk made of the tail of one
        chunk and the head of the next is inserted between consecutive chunks.
    tokenizer : object, optional
        Any object exposing ``tokenize(str) -> list``. Lets callers reuse an
        already loaded tokenizer instead of loading it again.

    Returns
    -------
    dict
        ``{aaid: {chunk_uuid: {"text": str, "metadata": {"file_name": aaid,
        "field": <text before the first ':' of the paragraph>}}}}``

    Raises
    ------
    ValueError
        If a report does not contain ``.txt``.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    def token_count(piece):
        return len(tokenizer.tokenize(piece))

    documents = {}

    for text in reports:

        ext_pos = text.find('.txt')
        if ext_pos < 0:
            raise ValueError(
                "report has no 'AAID: <name>.txt' header; it starts with: " + repr(text[:40])
            )
        # 6 == len('AAID: '): the id is what follows the field name.
        aaid = text[6:ext_pos]

        # The first paragraph is the AAID header and is not chunked.
        paragraphs = re.split(paragraph_separator, text)[1:]
        all_chunks = {}
        for paragraph in paragraphs:
            field = paragraph.split(':')[0]

            # Pass 1: greedily pack whole words until the token budget is reached.
            chunks = []
            current_chunk = ""
            for word in paragraph.split(separator):
                new_chunk = current_chunk + (separator if current_chunk else '') + word
                if token_count(new_chunk) <= chunk_size:
                    current_chunk = new_chunk
                else:
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = word
            if current_chunk:
                chunks.append(current_chunk)

            # Pass 2: a chunk can still be too long when a single word exceeds the
            # budget. Cut it after each regex match and re-pack the pieces.
            # re.split would discard the matched text, hence the dedicated helper.
            refined_chunks = []
            for chunk in chunks:
                if token_count(chunk) <= chunk_size:
                    refined_chunks.append(chunk)
                    continue
                sub_chunk_accum = ""
                for sub_chunk in _split_after_matches(chunk, secondary_chunking_regex):
                    if sub_chunk_accum and token_count(sub_chunk_accum + sub_chunk) > chunk_size:
                        if sub_chunk_accum.strip():
                            refined_chunks.append(sub_chunk_accum.strip())
                        sub_chunk_accum = sub_chunk
                    else:
                        sub_chunk_accum += sub_chunk
                if sub_chunk_accum.strip():
                    refined_chunks.append(sub_chunk_accum.strip())

            # Pass 3 (optional): interleave character-based overlap chunks.
            final_chunks = []
            if chunk_overlap > 0 and len(refined_chunks) > 1:
                for i in range(len(refined_chunks) - 1):
                    final_chunks.append(refined_chunks[i])
                    overlap_start = max(0, len(refined_chunks[i]) - chunk_overlap)
                    overlap_end = min(chunk_overlap, len(refined_chunks[i+1]))
                    overlap_chunk = refined_chunks[i][overlap_start:] + ' ' + refined_chunks[i+1][:overlap_end]
                    final_chunks.append(overlap_chunk)
                final_chunks.append(refined_chunks[-1])
            else:
                final_chunks = refined_chunks

            # Assign a UUID to each chunk and attach its text and metadata.
            for chunk in final_chunks:
                chunk_id = str(uuid.uuid4())
                all_chunks[chunk_id] = {"text": chunk, "metadata": {"file_name": aaid, "field": field}}

        # Map the report id (AAID) to its chunk dictionary.
        documents[aaid] = all_chunks

    return documents

In [None]:
# Chunk store cache: load it when present, otherwise build it and save it.
if os.path.exists('ragembeddings/docs-v2'):
    # NOTE: pickle.load can execute arbitrary code; only load caches produced by
    # this notebook or obtained from a trusted source.
    with open('ragembeddings/docs-v2','br') as f:
        docs = pickle.load(f)
else:
    print('Document chunks not found. Computing them (may take several minutes)...')
    docs = document_chunker(reports=reports,
                        model_name=model_name,
                        chunk_size=256)
    # Create the cache directory first: without it the open() below fails and
    # the chunks that were just computed are lost.
    os.makedirs('ragembeddings', exist_ok=True)
    with open('ragembeddings/docs-v2','bw') as f:
        pickle.dump(docs,f)

keys = list(docs.keys())
print(len(docs))
print(keys)
# Guard against an empty store (no reports): keys[0] would raise IndexError.
if keys:
    print(docs[keys[0]])

In [5]:
def compute_embeddings(text):
    """Embed `text` with the notebook-level `tokenizer` and `model`.

    The embedding is the mean of the last hidden states over the real tokens
    only: positions whose attention mask is 0 (padding added when a list of
    texts of different lengths is passed) are excluded from the average.

    Parameters
    ----------
    text : str or list of str
        Input forwarded to the tokenizer (padded and truncated).

    Returns
    -------
    list
        A flat list of floats for a single text; the result of ``squeeze()``
        converted with ``tolist()`` otherwise (one row per text for a batch).
    """
    # Send the inputs to wherever the model lives instead of assuming 'cuda'.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Generate the token embeddings without tracking gradients.
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state

    # Attention-mask-weighted mean pooling.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # clamp avoids a division by zero if a row had no attended token.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    embeddings = (summed / counts).squeeze()

    return embeddings.tolist()

In [6]:
def create_vector_store(doc_store):
    """Build the embedding store that mirrors `doc_store`.

    `doc_store` maps a document id to a dict of ``chunk_id -> chunk_dict``. The
    result has the same two-level layout, with every chunk dict replaced by the
    embedding of its "text" entry as returned by ``compute_embeddings``.
    """
    return {
        doc_id: {
            chunk_id: compute_embeddings(chunk_dict.get("text"))
            for chunk_id, chunk_dict in chunks.items()
        }
        for doc_id, chunks in doc_store.items()
    }

In [7]:
def compute_matches(vector_store, query_str, top_k):
    """
    Rank every chunk of `vector_store` against `query_str` by cosine similarity.

    The query is embedded with ``compute_embeddings`` and compared with each
    stored chunk embedding. A zero-norm query or chunk scores 0. The result is
    a list of ``(doc_id, chunk_id, score)`` tuples holding the `top_k` highest
    scores, best first.
    """
    query_vec = np.array(compute_embeddings(query_str))
    query_norm = np.linalg.norm(query_vec)

    def cosine_to_query(embedding):
        # Cosine similarity, with 0 as the fallback for degenerate vectors.
        chunk_vec = np.array(embedding)
        chunk_norm = np.linalg.norm(chunk_vec)
        if query_norm == 0 or chunk_norm == 0:
            return 0
        return np.dot(chunk_vec, query_vec) / (query_norm * chunk_norm)

    scored = [
        (doc_id, chunk_id, cosine_to_query(embedding))
        for doc_id, chunks in vector_store.items()
        for chunk_id, embedding in chunks.items()
    ]

    # Stable sort: chunks with equal scores keep their store order.
    ranked = sorted(scored, key=lambda hit: hit[2], reverse=True)
    return ranked[:top_k]

In [None]:
def retrieve_doc(vector_store, docs, query_str, threshold=0.7):
    """
    Retrieve the chunks of the documents that satisfy every atom of a query.

    `query_str` is split into atoms on the whole word ``AND``. Each chunk is
    scored against each atom:

    * score 1 when a keyword rule matches the chunk text:
        - sender: the atom contains 'dr.', 'dott.', 'dottor' or 'invia', the
          chunk text starts with 'inviante' and contains the last word of the atom;
        - variant: the atom contains 'variante' and 'p.', and the chunk text
          contains (case-insensitively) the text following 'p.' (at most 8
          characters, cut at the first whitespace);
        - gene: the atom contains the word 'gene' and the chunk text contains
          the word that follows it;
    * otherwise the cosine similarity between the atom embedding and the chunk
      embedding (0 when either vector has zero norm).

    Only scores strictly greater than `threshold` are kept.

    Parameters
    ----------
    vector_store : dict
        ``{doc_id: {chunk_id: embedding}}``.
    docs : dict
        ``{doc_id: {chunk_id: {"text": str, ...}}}``; read only when a keyword
        rule applies to the atom.
    query_str : str
        One atom, or several atoms joined by ``AND``.
    threshold : float
        Minimum (exclusive) score for a chunk to count as a hit.

    Returns
    -------
    list of (doc_id, chunk_id, score)
        With a single atom: all hits, best first. With several atoms: for each
        hit of the first atom whose document also has a hit for every other
        atom, that hit followed by the best hit of the same document for each
        other atom. An empty query returns an empty list.
    """
    # Split on the whole word only, so that names such as 'ANDREA' stay intact.
    atoms = [atom.strip() for atom in re.split(r'\bAND\b', query_str)]
    atoms = [atom for atom in atoms if atom]
    if not atoms:
        return []

    alltop = []
    for atom in atoms:
        atom_lower = atom.lower()
        words = atom.split()

        # The keyword rules depend only on the atom: resolve them once here
        # instead of once per chunk.
        sender_name = None
        if any(marker in atom_lower for marker in ('dr.', 'dott.', 'dottor', 'invia')):
            sender_name = words[-1]

        variant = None
        if 'variante' in atom_lower and 'p.' in atom:
            start = atom.index('p.') + 2
            # Keep the 8-character window but stop at the first whitespace so
            # the needle never spills into the following words.
            window = atom[start:start + 8].split()
            if window:
                variant = window[0].lower()

        gene_name = None
        if 'gene' in atom_lower:
            lowered_words = [word.lower() for word in words]
            # 'gene' may only be a substring (e.g. 'genetico'): the rule applies
            # only when it is a word of its own and is followed by another word.
            if 'gene' in lowered_words:
                position = lowered_words.index('gene')
                if position + 1 < len(words):
                    gene_name = words[position + 1]

        uses_text = sender_name is not None or variant is not None or gene_name is not None

        query_str_embedding = np.array(compute_embeddings(atom))
        norm_query = np.linalg.norm(query_str_embedding)

        scores = {}
        for doc_id, chunks in vector_store.items():
            for chunk_id, chunk_embedding in chunks.items():

                if uses_text:
                    text = docs[doc_id][chunk_id]['text']

                    # inviante
                    if sender_name is not None and text.startswith('inviante') and sender_name in text:
                        scores[(doc_id, chunk_id)] = 1
                        continue

                    # variante (case-insensitive on both sides)
                    if variant is not None and variant in text.lower():
                        scores[(doc_id, chunk_id)] = 1
                        continue

                    # gene
                    if gene_name is not None and gene_name in text:
                        scores[(doc_id, chunk_id)] = 1
                        continue

                # Fall back to cosine similarity between atom and chunk.
                chunk_embedding_array = np.array(chunk_embedding)
                norm_chunk = np.linalg.norm(chunk_embedding_array)
                if norm_query == 0 or norm_chunk == 0:
                    # Avoid division by zero
                    score = 0
                else:
                    score = np.dot(chunk_embedding_array, query_str_embedding) / (norm_query * norm_chunk)

                scores[(doc_id, chunk_id)] = score

        sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        top_results = [(doc_id, chunk_id, score) for ((doc_id, chunk_id), score) in sorted_scores if score > threshold]
        # Diagnostic: number of scored chunks and number of hits for this atom.
        print(len(sorted_scores), len(top_results))
        alltop.append(top_results)

    if len(alltop) == 1:
        return alltop[0]

    # Conjunction: keep a hit of the first atom only when its document also has
    # a hit for every other atom.
    retrieved = []
    for first_hit in alltop[0]:
        curchunks = [first_hit]
        for other_hits in alltop[1:]:
            partner = next((hit for hit in other_hits if hit[0] == first_hit[0]), None)
            if partner is None:
                break
            curchunks.append(partner)
        else:
            retrieved += curchunks

    return retrieved

Load the vector store if it exists; otherwise create it and save it.

In [None]:
# Embedding cache: load it when present, otherwise build it and save it.
# NOTE(review): the cache file name is fixed while `model_name` is configurable;
# after switching model, delete or rename the cache so stale vectors are not reused.
if os.path.exists('ragembeddings/mpnet-v2'):
    print('Vector store for this model already exists. Loading vectors.')
    # NOTE: pickle.load can execute arbitrary code; only load caches produced by
    # this notebook or obtained from a trusted source.
    with open('ragembeddings/mpnet-v2','br') as f:
        vec_store = pickle.load(f)
else:
    print('Vector store for this model doesn\'t exists. Creating vectors (may take long time)...')
    vec_store = create_vector_store(docs)
    # Create the cache directory first: without it the open() below fails and
    # the embeddings that were just computed are lost.
    os.makedirs('ragembeddings', exist_ok=True)
    with open('ragembeddings/mpnet-v2','bw') as f:
        pickle.dump(vec_store,f)

In [None]:
# Example retrieval. Other queries tried so far: "variante p.His269Arg", and
# compute_matches(vector_store=vec_store,
#                 query_str="casi con una mutazione nel gene SMARCB1", top_k=10).
matches = retrieve_doc(
    vector_store=vec_store,
    docs=docs,
    query_str="test intero esoma",
    threshold=0.7,
)

# One entry per hit: the (doc_id, chunk_id, score) tuple, then the matched chunk.
for rep in matches:
    print(rep, docs[rep[0]][rep[1]], '-----', sep='\n')

# Full content of every document that has at least one hit.
doc_ids = {m[0] for m in matches}
for d in doc_ids:
    print(f'{d} {docs[d]}')

print(f'{len(doc_ids)} retrieved documents')


Experimental setup