In [1]:
from transformers import DistilBertTokenizer, DistilBertModel
import torch
from sklearn.metrics.pairwise import cosine_similarity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the DistilBERT tokenizer and encoder from a single checkpoint name so the
# two cannot drift apart. A tokenizer whose vocabulary does not match the model
# would feed it wrong token ids.
MODEL_NAME = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertModel.from_pretrained(MODEL_NAME)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


In [3]:
# Stored SQL queries keyed by a natural-language description. The chatbot embeds
# the *keys* (not the SQL), so each key should describe its query clearly.
# NOTE(review): ``%s`` looks like DB-API "format"/"pyformat" parameter
# placeholders filled in at execution time; these queries are never executed in
# this notebook, so confirm against the code that runs them.
sql_queries = {
    "Get event count per host and event name": """
        SELECT h.host, t.description AS event_name, COUNT(DISTINCT e.eventid) AS event_count
        FROM events e
        JOIN triggers t ON e.objectid = t.triggerid
        JOIN functions f ON t.triggerid = f.triggerid
        JOIN items i ON f.itemid = i.itemid
        JOIN hosts h ON i.hostid = h.hostid
        WHERE e.source = 0 AND e.value = 1
        AND e.clock >= %s AND e.clock <= %s
        GROUP BY h.host, t.description;
    """,

    # String literals use single quotes. Standard SQL (and PostgreSQL) treats
    # double-quoted text as an identifier, and MySQL does too in ANSI_QUOTES
    # mode, so single quotes are the portable form.
    # NOTE(review): LIMIT 1 without ORDER BY leaves the choice among several
    # matching items up to the database; add an ORDER BY if a preference exists.
    "Get agent availability item for a host": """
        SELECT itemid, name, value_type
        FROM items
        WHERE hostid = %s
        AND name IN ('agent availability', 'Meraki: status', 'ICMP ping', 'ICMP Check')
        LIMIT 1;
    """,

    # NOTE(review): nothing restricts the join to a particular item, so this
    # aggregates every history_uint row in the time window rather than one
    # interface's traffic. Adding an item filter would add a placeholder and
    # change the parameter count callers pass, so confirm before changing it.
    "Get interface bandwidth usage": """
        SELECT ROUND(AVG(hu.value)/1000,2) AS average_value, ROUND(SUM(hu.value)/1000,2) AS Total_value
        FROM history_uint hu
        JOIN items i ON hu.itemid = i.itemid
        WHERE hu.clock >= %s AND hu.clock <= %s;
    """,
}

In [4]:
# Turn a piece of text into a single fixed-size vector for similarity search.
def get_embedding(text, *, pooling="mean", tok=None, encoder=None):
    """Embed ``text`` as one vector using the notebook's DistilBERT model.

    Args:
        text: Input string to embed.
        pooling: ``"mean"`` (default) averages the final hidden states over the
            non-padding tokens given by the attention mask. ``"cls"`` returns
            the first-token vector, which was this function's original
            behaviour. DistilBERT's first token is not trained as a sentence
            summary, so mean pooling is usually the better choice for
            comparing texts.
        tok: Optional tokenizer override; defaults to the global ``tokenizer``.
        encoder: Optional model override; defaults to the global ``model``.

    Returns:
        A NumPy array with one row per input, i.e. shape ``(1, hidden_size)``
        for a single string. This 2-D shape is what ``cosine_similarity``
        expects.

    Raises:
        ValueError: if ``pooling`` is not ``"mean"`` or ``"cls"``.
    """
    if pooling not in ("mean", "cls"):
        raise ValueError(f"pooling must be 'mean' or 'cls', got {pooling!r}")
    tok = tokenizer if tok is None else tok
    encoder = model if encoder is None else encoder

    inputs = tok(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = encoder(**inputs)
    hidden = outputs.last_hidden_state

    if pooling == "cls":
        pooled = hidden[:, 0, :]
    else:
        # Zero out padding positions, then divide by the number of real tokens.
        # clamp guards against division by zero for an all-padding row.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    # .cpu() lets this work when the model lives on a GPU.
    return pooled.cpu().numpy()

In [5]:
# Precompute one embedding per stored description. Both lists follow the dict's
# key order, so position i in one corresponds to position i in the other.
query_descriptions = [*sql_queries]
query_embeddings = list(map(get_embedding, query_descriptions))

In [6]:
# Route a natural-language question to the closest stored SQL query.
def chatbot(user_query):
    """Print the stored SQL query whose description is closest to ``user_query``.

    The question is embedded with ``get_embedding`` and scored by cosine
    similarity against each entry of ``query_embeddings``. The highest-scoring
    description wins, and ties go to the earliest entry. Prints the matched
    description and its SQL; returns ``None``.
    """
    question_vec = get_embedding(user_query)
    scores = [
        cosine_similarity(question_vec, candidate)[0][0]
        for candidate in query_embeddings
    ]
    # max() over indices keeps the first index among equal scores, the same
    # result as scores.index(max(scores)).
    winner = max(range(len(scores)), key=scores.__getitem__)
    description = query_descriptions[winner]
    sql_text = sql_queries[description]
    print(f"\nMatched Query Description:\n{description}")
    print(f"\nSQL Query:\n{sql_text}")

In [7]:
# Example question used to exercise the chatbot in the next cell.
user_input = (
    "What is the average bandwidth usage "
    "for the last 7 days?"
)

In [8]:
# Prints the best-matching description and its SQL text.
chatbot(user_query=user_input)


Matched Query Description:
Get interface bandwidth usage

SQL Query:

        SELECT ROUND(AVG(hu.value)/1000,2) AS average_value, ROUND(SUM(hu.value)/1000,2) AS Total_value
        FROM history_uint hu
        JOIN items i ON hu.itemid = i.itemid
        WHERE hu.clock >= %s AND hu.clock <= %s;
    
