In [36]:
import sys
sys.path.append("../../.")

In [37]:
from Utils.DbLoadUtils import getMongoClient

In [38]:
uri = "mongodb+srv://admin:admin@bigdata.em7viry.mongodb.net/?retryWrites=true&w=majority&appName=BigData"
mongoClient = getMongoClient(uri)

db = mongoClient["BigData"]
collection = db["MedicalLLM"]

Connection to MongoDB successful


In [39]:
from langchain_huggingface import HuggingFaceEmbeddings
embedding_model = HuggingFaceEmbeddings(
    model_name='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    multi_process=True,
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": True},
)



In [40]:
from langchain.vectorstores import MongoDBAtlasVectorSearch

vectorStore = MongoDBAtlasVectorSearch.from_connection_string(
    uri,
    db.name + "." + collection.name,
    embedding_model,
    relevance_score_fn = "cosine"
)


In [41]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [42]:
from langchain_community.llms import HuggingFaceHub

modelName = "google/gemma-1.1-7b-it"

hf = HuggingFaceHub(
    repo_id=modelName,
    model_kwargs={"temperature":0.5, "max_length":500})

In [43]:
def get_embedding(text: str) -> list[float]:
    if not text.strip():
        print("Attempted to get embedding for empty text.")
        return []

    embedding = embedding_model.embed_query(text)

    return embedding


In [44]:
def vector_search(user_query, collection, sito=None):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
    user_query (str): The user's query string.
    collection (MongoCollection): The MongoDB collection to search.

    Returns:
    list: A list of matching documents.
    """

    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."


    pipeline = [
    {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            "numCandidates": 5000,
            "limit": 10
        }
    },
    {
        "$project": {
            "sito": 1,
            "dottore": 1,
            "score": {"$meta": "vectorSearchScore"}
        }
    }
];

    if sito:
        pipeline.append( {
        "$match": {
            "sito": sito
        }
    })

    results = collection.aggregate(pipeline)
    return list(results)

In [45]:
from bson import json_util
import json

def create_time_series(collection):
    pipeline = [
        {
            # Convert the date string to a date object and extract the year and month
            "$project": {
                "sito": 1,
                "yearMonth": {
                    "$dateToString": {
                        "format": "%Y-%m",
                        "date": {"$dateFromString": {"dateString": "$data"}}
                    }
                }
            }
        },
        {
            # Group by sito and year-month
            "$group": {
                "_id": {
                    "sito": "$sito",
                    "yearMonth": "$yearMonth"
                },
                "count": {"$sum": 1}
            }
        },
        {
            # Sort by sito and year-month
            "$sort": {
                "_id.sito": 1,
                "_id.yearMonth": 1
            }
        },
        {
            # Reshape the documents to have sito as the main document with counts per month
            "$group": {
                "_id": "$_id.sito",
                "counts": {
                    "$push": {
                        "month": "$_id.yearMonth",
                        "count": "$count"
                    }
                }
            }
        },
        {
            # Sort the results by sito
            "$sort": {"_id": 1}
        }
    ]
    
    results = collection.aggregate(pipeline)
    # Convert the aggregation cursor to a list of dictionaries
    return json.loads(json_util.dumps(results))

# Call the function with your collection
# print(create_time_series(your_collection))


In [46]:
def getCount(collection) -> dict:
    pipeline = [
        {
            "$group": {
                "_id": "$sito",  # Use "$sito" to reference the field name
                "count": {"$sum": 1}  # Correct the syntax for $sum
            }
        },
        {
            "$sort": {"count": -1}  # Correct the syntax for $sort
        }
    ]
    results = collection.aggregate(pipeline)
    return {result['_id']: result['count'] for result in results}  # Convert cursor to dictionary


In [47]:
data = create_time_series(collection)
data

[{'_id': 'DIRE',
  'counts': [{'month': '2001-02', 'count': 1},
   {'month': '2001-08', 'count': 1},
   {'month': '2001-09', 'count': 2},
   {'month': '2001-11', 'count': 1},
   {'month': '2002-04', 'count': 1},
   {'month': '2002-05', 'count': 6},
   {'month': '2002-06', 'count': 2},
   {'month': '2002-07', 'count': 1},
   {'month': '2002-09', 'count': 3},
   {'month': '2002-10', 'count': 1},
   {'month': '2003-01', 'count': 1},
   {'month': '2003-02', 'count': 1},
   {'month': '2003-03', 'count': 2},
   {'month': '2003-04', 'count': 1},
   {'month': '2003-05', 'count': 1},
   {'month': '2003-06', 'count': 4},
   {'month': '2003-07', 'count': 4},
   {'month': '2003-09', 'count': 1},
   {'month': '2003-12', 'count': 1},
   {'month': '2004-06', 'count': 8},
   {'month': '2004-07', 'count': 2},
   {'month': '2004-08', 'count': 9},
   {'month': '2004-09', 'count': 30},
   {'month': '2004-10', 'count': 9},
   {'month': '2004-11', 'count': 15},
   {'month': '2004-12', 'count': 6},
   {'mont

In [48]:
vector_search("Ciao, cosa devo fare per capire se sono celiaco?", collection)

[{'_id': ObjectId('666e2473275c3b68247bf041'),
  'sito': 'DIRE',
  'dottore': 'TITTOBELLOE\' probabile che in lei si attivino qelli che noi chiamiamo " riflessi vagali " che in lei compaiono quando ingerisce quando ingerisce alcuni alimenti freddi oppure gasati e possono interessare sia lo stomaco che il colon. Ne deve parlare con il suo medico di famiglia. Prof.             Alberto Tittobello Casa di cura privata Universitario Specialista in Gastroenterologia Milano (MI)',
  'score': 0.8295849561691284},
 {'_id': ObjectId('666e26b8275c3b68247c3d35'),
  'sito': 'DIRE',
  'dottore': 'FRANCAVILLABuongiorno Carlotta,             il sintomo che riferisce si chiama depersonalizzazione,             tipico di un disturbo ansioso,             e consiste proprio nella sensazione di essere come "staccati dal corpo". ritengo che il dosaggio della Paroxetina (EUTIMIL) sia insufficiente per il controllo della sua sintomatologia,             pertanto il mio consiglio è quello di aumentare la posolog

In [49]:
def createPrompt(query: str, site=None):
    
    docs = vector_search(query, collection, site)
    context = ""
    for doc in docs:
        context += doc["dottore"].strip() + "\n\n"
        
    istruction = """Sei un dottore che deve rispondere alle domande di un paziente. Unisci la tua conoscenza pregressa a queste risposte fornite da medici ad altri pazienti con problemi simili ma non citarle direttamente. 
Non inventare. Genera una risposta rapida e concisa, senza ripetizioni. Usa un tono professionale e senza errori grammaticali. Indica unicamente la riposta alla domanda.
Non rispondere con il tuo nome e non identificarti. Elenca delle possibili soluzione."""
    
    return f"""CONTESTO: {context}
DOMANDA: {query}
ISTRUZIONI: {istruction}
RISPOSTA:
"""

In [71]:
# Define a global variable
global_variable = "Entrambi"
# Function to update the global variable based on the dropdown selection
def update_global_variable(selection):
    global global_variable
    global_variable = selection
    return f"Global variable updated to: {global_variable}"

def answer(query: str, site = None):
    if site != "Dire" and site != "Medic":
        site = None
    prompt  = createPrompt(query, site)
    print(site)
    response = hf.generate([prompt], max_new_tokens=1000, do_sample = True)
    splitted = response.generations[0][0].text.split("RISPOSTA:")[-1]
    return splitted.strip()


def answerNoRag(query: str):
    response = hf.generate([query], max_new_tokens=1000, do_sample = True)
    splitted = response.generations[0][0].text
    return splitted.strip()

In [72]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime

def plot_time_series(results):
    fig, ax = plt.subplots(figsize=(10, 6))  # Adjust the figure size as needed

    for site_data in results:
        site = site_data['_id']
        counts = site_data['counts']
        
        dates = [datetime.strptime(month_count['month'], '%Y-%m') for month_count in counts]
        values = [month_count['count'] for month_count in counts]
        
        ax.plot(dates, values, marker='o', label=site)  # Adjust marker style and size as needed

    # Improve the x-axis labels
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))  # Show a label every month
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    plt.xticks(rotation=45)  # Rotate the x-axis labels for better readability

    # Add grid lines
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Add labels and title
    ax.set_xlabel('Month', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title('Time Series of Counts by Site', fontsize=14)

    # Add a legend
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=10))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax.legend()

    # Show the plot
    plt.tight_layout()  # Adjust the layout to fit all elements
    return fig

def plot():
    return plot_time_series(data)



In [73]:
import gradio as gr
with gr.Blocks() as demo:
    
    with gr.Row():
        
        with gr.Column():
            inputs=gr.Dropdown(choices=['Medic', 'Dire', "Entrambi"], label="Select an option")
        with gr.Column():
            input_text = gr.Textbox(label="Enter a question")
            output_text = gr.Textbox(label="Output")
            button_compute = gr.Button("Compute")
            button_compute.click(answer, inputs=[input_text, inputs], outputs=output_text)
            
    with gr.Row():
        plot_component = gr.Plot()
        button_plot = gr.Button("Generate Plot")
        button_plot.click(plot, outputs=plot_component)
    
    
    demo.launch()

Running on local URL:  http://127.0.0.1:7872

To create a public link, set `share=True` in `launch()`.


None
Dire
None
Medic
Dire
Medic
