In [None]:
import sys
# Extend the import search path two directories up from the working directory.
# NOTE(review): presumably this is where the project-local `Utils` package
# imported in the next cell lives — confirm.
sys.path.append("../../.")

In [None]:
from Utils.DbLoadUtils import getMongoClient

In [None]:
# The connection string embeds the database user and password, so it is read
# from the environment instead of being committed with the notebook.
# NOTE(review): the credentials that were previously hardcoded here should be
# rotated, since they have been stored in the notebook file.
import os

uri = os.environ.get("MONGODB_URI")
if not uri:
    raise RuntimeError(
        "Set the MONGODB_URI environment variable to the MongoDB Atlas "
        "connection string before running this notebook."
    )
mongoClient = getMongoClient(uri)

# Database and collection handles used by the cells below.
db = mongoClient["BigData"]
collection = db["MedicalLLM"]

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
# Embedding model shared by the vector store and by get_embedding().
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    model_kwargs=dict(device="cuda"),
    encode_kwargs=dict(normalize_embeddings=True),
    multi_process=True,
)

In [None]:
from langchain.vectorstores import MongoDBAtlasVectorSearch

# LangChain vector store over the same collection; the namespace argument is
# "<database name>.<collection name>".
vectorStore = MongoDBAtlasVectorSearch.from_connection_string(
    uri,
    f"{db.name}.{collection.name}",
    embedding_model,
    relevance_score_fn="cosine",
)


In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login from inside the notebook.
# NOTE(review): presumably required by the Hub model configured in the next
# cell — confirm.
notebook_login()

In [None]:
from langchain_community.llms import HuggingFaceHub

# Hugging Face Hub model used by answer() and answerNoRag().
modelName = "google/gemma-1.1-7b-it"

hf = HuggingFaceHub(
    repo_id=modelName,
    model_kwargs=dict(temperature=0.5, max_length=500),
)

In [None]:
def get_embedding(text: str) -> list[float]:
    """Embed ``text`` with the module-level ``embedding_model``.

    Args:
        text: The text to embed. ``None`` is tolerated and treated like
            blank text instead of raising on ``.strip()``.

    Returns:
        The embedding returned by ``embedding_model.embed_query``, or an
        empty list when ``text`` is ``None``, empty or whitespace-only.
        Callers must therefore test for emptiness, not for ``None``.
    """
    if text is None or not text.strip():
        print("Attempted to get embedding for empty text.")
        return []

    return embedding_model.embed_query(text)


In [None]:
def vector_search(user_query, collection, sito=None):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
    user_query (str): The user's query string.
    collection (MongoCollection): The MongoDB collection to search.
    sito (str, optional): When truthy, keep only hits whose ``sito`` field
        equals this value. The filter runs after the search has already been
        cut to its top 10 hits, so fewer than 10 documents (possibly none)
        may be returned.

    Returns:
    list: A list of matching documents carrying ``sito``, ``dottore`` and
        ``score`` (plus MongoDB's default ``_id``). An empty list is returned
        when no embedding could be produced for the query.
    """

    query_embedding = get_embedding(user_query)

    # get_embedding reports failure with an empty list (not None), so test for
    # emptiness. Return a list here as well, so that callers iterating over
    # the result keep working.
    if not query_embedding:
        return []

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 5000,
                "limit": 10,
            }
        },
        {
            "$project": {
                "sito": 1,
                "dottore": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]

    if sito:
        # Post-filter on the projected `sito` field.
        pipeline.append({"$match": {"sito": sito}})

    return list(collection.aggregate(pipeline))

In [None]:
from bson import json_util
import json

def create_time_series(collection):
    """Count documents per ``sito`` and per calendar month.

    The string field ``data`` of every document is parsed as a date and
    reduced to a ``"YYYY-MM"`` label; documents are then counted per
    (sito, month) pair and regrouped into one entry per sito.

    Args:
        collection: Object exposing ``aggregate(pipeline)`` (a MongoDB
            collection).

    Returns:
        A list with one dict per sito, ordered by sito:
        ``{"_id": <sito>, "counts": [{"month": "YYYY-MM", "count": <n>}, ...]}``.
    """
    # Parse `data` and keep only its year-month label next to `sito`.
    to_year_month = {
        "$project": {
            "sito": 1,
            "yearMonth": {
                "$dateToString": {
                    "format": "%Y-%m",
                    "date": {"$dateFromString": {"dateString": "$data"}},
                }
            },
        }
    }
    # One row per (sito, month) with its number of documents.
    count_per_site_month = {
        "$group": {
            "_id": {"sito": "$sito", "yearMonth": "$yearMonth"},
            "count": {"$sum": 1},
        }
    }
    # Order rows by sito, then month, before they are pushed into `counts`.
    order_by_site_month = {"$sort": {"_id.sito": 1, "_id.yearMonth": 1}}
    # Collapse to one document per sito holding its list of monthly counts.
    collect_per_site = {
        "$group": {
            "_id": "$_id.sito",
            "counts": {"$push": {"month": "$_id.yearMonth", "count": "$count"}},
        }
    }
    order_by_site = {"$sort": {"_id": 1}}

    cursor = collection.aggregate(
        [
            to_year_month,
            count_per_site_month,
            order_by_site_month,
            collect_per_site,
            order_by_site,
        ]
    )
    # Serialise with bson's json_util and parse back with json so the caller
    # receives plain lists and dicts rather than a cursor.
    serialised = json_util.dumps(cursor)
    return json.loads(serialised)

# Call the function with your collection
# print(create_time_series(your_collection))


In [None]:
def getCount(collection) -> dict:
    """Return the number of documents for each ``sito`` value.

    Args:
        collection: Object exposing ``aggregate(pipeline)`` (a MongoDB
            collection).

    Returns:
        A dict mapping each ``sito`` to its document count, inserted in the
        order produced by the pipeline (sorted by descending count).
    """
    group_by_site = {"$group": {"_id": "$sito", "count": {"$sum": 1}}}
    largest_first = {"$sort": {"count": -1}}

    counts = {}
    for row in collection.aggregate([group_by_site, largest_first]):
        counts[row["_id"]] = row["count"]
    return counts


In [None]:
# Computed once when this cell runs; plot() reads this module-level `data`,
# so the cell must be executed before the plot button is used.
data = create_time_series(collection)
data

In [None]:
# Sample retrieval call with an Italian question and no `sito` filter.
vector_search("Ciao, cosa devo fare per capire se sono celiaco?", collection)

In [None]:
def createPrompt(query: str, site=None):
    """Build the prompt sent to the LLM for ``query``.

    Documents are retrieved with ``vector_search`` from the module-level
    ``collection`` (optionally restricted to ``site``); the stripped
    ``dottore`` text of each hit, followed by a blank line, forms the context.

    Args:
        query: The patient's question.
        site: Optional site filter forwarded to ``vector_search``.

    Returns:
        A string with the CONTESTO, DOMANDA and ISTRUZIONI sections, ending
        with an open ``RISPOSTA:`` line.
    """
    retrieved = vector_search(query, collection, site)
    context = "".join(hit["dottore"].strip() + "\n\n" for hit in retrieved)

    # Fixed Italian instructions for the model; the text is part of the
    # prompt and is kept exactly as is.
    istruction = (
        "Sei un dottore che deve rispondere alle domande di un paziente. Unisci la tua conoscenza pregressa a queste risposte fornite da medici ad altri pazienti con problemi simili ma non citarle direttamente. \n"
        "Non inventare. Genera una risposta rapida e concisa, senza ripetizioni. Usa un tono professionale e senza errori grammaticali. Indica unicamente la riposta alla domanda.\n"
        "Non rispondere con il tuo nome e non identificarti. Elenca delle possibili soluzione."
    )

    return (
        f"CONTESTO: {context}\n"
        f"DOMANDA: {query}\n"
        f"ISTRUZIONI: {istruction}\n"
        "RISPOSTA:\n"
    )

In [None]:
# Module-level selection state, initialised to "Entrambi".
global_variable = "Entrambi"


def update_global_variable(selection):
    """Store ``selection`` in the module-level ``global_variable``.

    Args:
        selection: The new value to remember.

    Returns:
        A confirmation message containing the stored value.
    """
    global global_variable
    global_variable = selection
    confirmation = f"Global variable updated to: {global_variable}"
    return confirmation

def answer(query: str, site = None):
    """Answer ``query`` using retrieved context plus the LLM.

    Args:
        query: The patient's question.
        site: ``"Dire"`` or ``"Medic"`` to restrict retrieval to that site;
            any other value (including ``None``) disables the filter.

    Returns:
        The generated text that follows the last ``RISPOSTA:`` marker, with
        surrounding whitespace removed.
    """
    site_filter = site if site in ("Dire", "Medic") else None
    prompt = createPrompt(query, site_filter)

    llm_result = hf.generate([prompt], max_new_tokens=1000, do_sample=True)
    generated = llm_result.generations[0][0].text

    # Keep only what comes after the final "RISPOSTA:" marker.
    # NOTE(review): presumably the model output can echo the prompt — confirm.
    return generated.split("RISPOSTA:")[-1].strip()


def answerNoRag(query: str):
    """Send ``query`` to the LLM as-is, without any retrieved context.

    Returns:
        The text of the first generation, stripped of surrounding whitespace.
    """
    llm_result = hf.generate([query], max_new_tokens=1000, do_sample=True)
    return llm_result.generations[0][0].text.strip()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime

def plot_time_series(results):
    """Plot one line per site showing its monthly counts.

    Args:
        results: Iterable of dicts of the form
            ``{"_id": <site>, "counts": [{"month": "YYYY-MM", "count": <n>}, ...]}``
            (the shape returned by ``create_time_series``).

    Returns:
        The matplotlib ``Figure``. It is detached from pyplot before being
        returned, so repeated calls (e.g. from the UI button) do not
        accumulate open figures; the returned object can still be rendered
        or saved.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for site_data in results:
        counts = site_data['counts']
        dates = [datetime.strptime(entry['month'], '%Y-%m') for entry in counts]
        values = [entry['count'] for entry in counts]
        ax.plot(dates, values, marker='o', label=site_data['_id'])

    # X axis: one labelled tick every 10 months, formatted as YYYY-MM.
    # (The original configured the locator twice; only this later setting was
    # ever effective.)
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=10))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    # Rotate through the Axes API so the setting applies to this figure
    # rather than to whatever pyplot considers the current one.
    ax.tick_params(axis='x', labelrotation=45)

    ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax.set_xlabel('Month', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title('Time Series of Counts by Site', fontsize=14)
    ax.legend()

    fig.tight_layout()
    # Remove the figure from pyplot's registry to avoid leaking one figure
    # per call.
    plt.close(fig)
    return fig

def plot():
    """Build the time-series figure from the module-level ``data``.

    Takes no arguments so it can be used directly as a UI callback.
    """
    figure = plot_time_series(data)
    return figure



In [None]:
import gradio as gr
# Gradio UI: a question/answer row and a plot row. Components are laid out in
# the order they are created, so statement order here is significant.
with gr.Blocks() as demo:
    
    with gr.Row():
        
        with gr.Column():
            # Site selector; answer() treats any value other than "Medic" or
            # "Dire" (including no selection) as "no site filter".
            inputs=gr.Dropdown(choices=['Medic', 'Dire', "Entrambi"], label="Seleziona un sito di origine")
        with gr.Column():
            input_text = gr.Textbox(label="Enter a question")
            output_text = gr.Textbox(label="Output")
            button_compute = gr.Button("Compute")
            # Calls answer(query, site) with the textbox and dropdown values.
            button_compute.click(answer, inputs=[input_text, inputs], outputs=output_text)
            
    with gr.Row():
        plot_component = gr.Plot()
        button_plot = gr.Button("Generate Plot")
        # plot() takes no inputs; it renders the precomputed module-level `data`.
        button_plot.click(plot, outputs=plot_component)
    
    
    # NOTE(review): launch() is called inside the Blocks context; Gradio
    # examples usually call it after the with-block — confirm this is intended.
    demo.launch()