In [1]:
# Cell 1: Imports and Configuration
import json
import os
import threading
import time # For potential delays or timeouts if needed
import uuid

import requests

# Configuration for your Flask API.
# Can be overridden without editing the notebook by setting the API_BASE_URL environment
# variable before starting the kernel. A trailing slash is stripped because the helper
# functions append paths such as "/models" and "/chat" themselves.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:5000").rstrip("/") # Adjust if your API is on a different host/port

In [2]:
# Cell 2: Helper Function to Get Available Models
def get_available_models(api_base_url, timeout=30):
    """Fetch the list of available models from the API and print a short summary of each.

    Args:
        api_base_url: Base URL of the Flask API (without a trailing slash),
            e.g. "http://localhost:5000".
        timeout: Seconds to wait for the connection/response. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        The decoded JSON body of ``GET /models`` (expected to be a list of model dicts),
        or ``[]`` if the request failed or the body was not valid JSON.
    """
    try:
        response = requests.get(f"{api_base_url}/models", timeout=timeout)
        response.raise_for_status()  # Raises an exception for HTTP errors
        models = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose JSON decode error
        # is not a RequestException subclass.
        print(f"Error fetching models: {e}")
        return []

    print("Available Models:")
    for model in models:
        print(f"- ID: {model.get('id')}, Name: {model.get('name')}, Type: {model.get('type')}, Source: {model.get('source_type')}")
        if model.get('source_type') == 'local':
            print(f"  Path: {model.get('path')}")
        elif model.get('source_type') == 'hub':
            # Hub GGUF models are addressed by repo + file; other hub models by their hub path.
            if model.get('type') == 'gguf':
                print(f"  Repo ID: {model.get('repo_id')}, Filename: {model.get('filename')}")
            else:
                print(f"  Hub ID/Path: {model.get('path')}")
        # print(f"  Default Params: {model.get('params')}") # Uncomment for more detail
    return models

# Query the API once when this cell runs; the helper prints each model and yields []
# if the request fails, which the main execution cell checks before starting any threads.
AVAILABLE_MODELS = get_available_models(api_base_url=API_BASE_URL)

Error fetching models: HTTPConnectionPool(host='localhost', port=5000): Max retries exceeded with url: /models (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x00000173BE95F100>: Failed to establish a new connection: [WinError 10061] No connection could be made because the target machine actively refused it'))


In [3]:
# Cell 3: Helper Function to Create a New Session
def create_new_session(api_base_url, timeout=30):
    """Ask the API to open a new chat session and return its id.

    Args:
        api_base_url: Base URL of the Flask API (without a trailing slash).
        timeout: Seconds to wait for the connection/response before giving up.

    Returns:
        The ``session_id`` on success. ``None`` if the request failed, the body was not
        valid JSON, or the API did not report ``status == 'success'`` with a non-empty
        ``session_id``.
    """
    try:
        response = requests.post(f"{api_base_url}/create-session", timeout=timeout)
        response.raise_for_status()
        session_data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose JSON decode error
        # is not a RequestException subclass.
        print(f"Error creating session: {e}")
        return None

    session_id = session_data.get('session_id')
    if session_data.get('status') == 'success' and session_id:
        print(f"Created new session: {session_id}")
        return session_id

    print(f"Failed to create session: {session_data.get('message', 'Unknown error')}")
    return None

In [4]:
# Cell 4: Helper Function to Ask a Question on a Session (Handles SSE)
def _sse_data_payload(decoded_line):
    """Return the value of an SSE ``data`` field line, or ``None`` for any other line.

    The event-stream format accepts both ``data: value`` and ``data:value``; a single
    space directly after the colon is not part of the value and is removed.
    """
    if not decoded_line.startswith('data:'):
        return None
    payload = decoded_line[len('data:'):]
    return payload[1:] if payload.startswith(' ') else payload


def ask_question_on_session(api_base_url, model_id, session_id, user_prompt, system_prompt,
                            generation_params=None, model_load_params=None, temperature=0.7,
                            timeout=300):
    """
    Send one question to ``POST /chat`` for a session and collect the streamed (SSE) answer.

    Args:
        api_base_url: Base URL of the Flask API (without a trailing slash).
        model_id: ID of the model to answer with.
        session_id: Session whose server-side history the question belongs to.
        user_prompt: The question text.
        system_prompt: System prompt sent with the request (may be empty).
        generation_params: Sent as ``model_specific_params`` (``{}`` when falsy).
        model_load_params: Sent as ``model_load_params`` (``{}`` when falsy).
        temperature: Base sampling temperature.
        timeout: Seconds for the connection and for each read of the stream.

    Returns:
        The complete assistant response, stripped of surrounding whitespace. On failure,
        a bracketed error string (API error event, timeout, or request error) is returned
        instead of raising, so callers can store it like a normal answer.
    """
    payload = {
        "session_id": session_id,
        "prompt": user_prompt,
        "system_prompt": system_prompt,
        "model_id": model_id,
        "temperature": temperature,  # Base temperature
        "model_specific_params": generation_params if generation_params else {},
        "model_load_params": model_load_params if model_load_params else {}
    }

    full_response_text = ""
    print(f"\n[Session: {session_id}, Model: {model_id}] Asking: {user_prompt[:100]}...")
    if system_prompt:
        print(f"System Prompt: {system_prompt[:100]}...")

    try:
        # With stream=True the connection stays checked out until the body is fully read or
        # the response is closed. The `with` block closes it on every exit path (error
        # return, `break` on the final event, or an exception).
        with requests.post(f"{api_base_url}/chat", json=payload, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            for line in response.iter_lines():
                if not line:
                    continue  # blank lines separate SSE events
                decoded_line = line.decode('utf-8')
                data_json_str = _sse_data_payload(decoded_line)
                if data_json_str is None:
                    continue  # SSE comments (": ...") and non-data fields carry no payload

                try:
                    data_json = json.loads(data_json_str)
                except json.JSONDecodeError:
                    print(f"\nWarning: Could not decode JSON from stream: {decoded_line}")
                    continue
                if not isinstance(data_json, dict):
                    print(f"\nWarning: Could not decode JSON from stream: {decoded_line}")
                    continue

                if data_json.get('error'):
                    error_msg = f"[API Error for session {session_id}]: {data_json['error']}"
                    print(error_msg)
                    return error_msg  # Return error message as response

                # NOTE(review): a text_chunk carried on the final event is ignored unless the
                # server also sends full_response — confirm the server never does that.
                if 'text_chunk' in data_json and not data_json.get('is_final'):
                    full_response_text += data_json['text_chunk']
                    # print(data_json['text_chunk'], end='', flush=True) # For live streaming in notebook

                if data_json.get('is_final'):
                    if 'full_response' in data_json:  # Prefer the server's assembled response
                        full_response_text = data_json['full_response']
                    break

        print(f"[Session: {session_id}] Full Response: {full_response_text[:100]}...")
        return full_response_text.strip()

    except requests.exceptions.Timeout:
        error_msg = f"[API Timeout for session {session_id} while asking: {user_prompt[:50]}...]"
        print(error_msg)
        return error_msg
    except requests.exceptions.RequestException as e:
        error_msg = f"[API Request Error for session {session_id}]: {e}"
        print(error_msg)
        return error_msg

In [5]:
# Cell 5: Worker Function to Process a Query Group
def worker_process_query_group(query_group_config, model_id, api_base_url, all_results_list, results_lock):
    """
    Thread target: ask one query group's questions sequentially in a fresh API session.

    Args:
        query_group_config: dict with required keys ``group_id`` and ``questions_and_keys``
            (a list of ``{"question": ..., "key": ...}`` dicts) and optional keys
            ``system_prompt``, ``generation_params``, ``model_load_params`` and
            ``temperature`` (defaults to 0.7).
        model_id: ID of the model every question in the group is sent to.
        api_base_url: Base URL of the Flask API.
        all_results_list: shared list that receives this group's summary dict.
        results_lock: lock guarding ``all_results_list`` across worker threads.

    The summary dict has ``group_id``, ``session_id``, ``status`` and ``results``
    (answer key -> answer text). ``status`` is ``"completed"``,
    ``"failed_session_creation"``, or ``"failed"`` when an unexpected exception stops the
    group part-way. In the ``"failed"`` case, ``results`` holds the answers gathered so far
    and an extra ``error`` key describes the exception.
    """
    group_id = query_group_config["group_id"]
    print(f"Thread started for Query Group: {group_id}")

    session_id = create_new_session(api_base_url)
    if not session_id:
        print(f"Failed to create session for group {group_id}. Aborting this group.")
        with results_lock:
            all_results_list.append({
                "group_id": group_id,
                "session_id": None,
                "status": "failed_session_creation",
                "results": {}
            })
        return

    system_prompt = query_group_config.get("system_prompt", "")
    generation_params = query_group_config.get("generation_params", {})
    model_load_params = query_group_config.get("model_load_params", {})
    default_temperature = query_group_config.get("temperature", 0.7)  # Can be set per group

    group_results_data = {}
    status = "completed"
    error_text = None

    try:
        # Questions share one session_id, so they are asked in order. The system prompt is
        # sent with every request; how the backend combines it with the session history
        # is decided server-side.
        for q_item in query_group_config["questions_and_keys"]:
            user_question = q_item["question"]
            answer_key = q_item["key"]

            answer = ask_question_on_session(
                api_base_url,
                model_id,
                session_id,
                user_question,
                system_prompt,
                generation_params,
                model_load_params,
                temperature=default_temperature
            )
            group_results_data[answer_key] = answer
            # time.sleep(1) # Optional: small delay between questions in the same session if needed
    except Exception as exc:
        # An uncaught exception would end the thread with only a traceback from threading's
        # excepthook, and the group would be missing from the collected results entirely.
        status = "failed"
        error_text = f"{type(exc).__name__}: {exc}"
        print(f"Error while processing Query Group {group_id}: {error_text}")

    group_summary = {
        "group_id": group_id,
        "session_id": session_id,
        "status": status,
        "results": group_results_data
    }
    if error_text is not None:
        group_summary["error"] = error_text

    with results_lock:
        all_results_list.append(group_summary)
    print(f"Thread finished for Query Group: {group_id}")

In [7]:
# Sample driver-license text; the noise ("DRIVER LICENSe", "Cordhslde",
# "LNCARDHOLDER FNIMA") suggests raw OCR output, so it is kept byte-for-byte
# (including the leading indentation on each line).
# NOTE(review): this name shadows Python's built-in `license` helper, and no cell
# visible here reads it yet. Consider a descriptive name (e.g. LICENSE_OCR_TEXT)
# once all consumers are known; confirm whether later cells rely on `license`.
license='''
        California
        DRIVER LICENSe
        dl 11234568
        CLASS C
        EXP 08/31/2014
        END NONE
        LNCARDHOLDER FNIMA
        2570 24TH STREET ANYTOWN, CA 95818
        doB 08/31/1977 RSTR NONE
        08311977
        VETERAN
        Cordhslde
        SEX F HGT 5'-05"
        HAIR BRN WGT 125 lb
        EYES BRN
        DD 00/00/0000NNNAN/ANFD/YY
        ISS 08/31/2009
'''

In [None]:
# Cell 6: Main Execution Block

# --- Configuration ---
# Choose your model ID from the list printed by Cell 2
# Example: If you have a local GGUF model named "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
# its ID might be "gguf_local_mistral-7b-instruct-v0_2_Q4_K_M_gguf"
# Or if you defined a Hub model in online_models.json with id "zephyr-7b-gguf"
CHOSEN_MODEL_ID = "gguf_local_Llama-3_2-1B-Instruct-Q8_0" # <--- !!! SET YOUR MODEL ID HERE !!!
# You can also get one from AVAILABLE_MODELS if it's not empty:
# if AVAILABLE_MODELS:
#    CHOSEN_MODEL_ID = AVAILABLE_MODELS[0]['id'] # Example: use the first available model
# else:
#    print("No models available from API. Please check server and config.")
#    CHOSEN_MODEL_ID = "default_model_id_placeholder" # Fallback if no models listed

# Define your query groups
# Each item in QUERY_GROUPS will be processed in a separate thread, each with a new session.
# Questions within a "questions_and_keys" list for a single group are asked sequentially *within the same session*.
# Groups without "temperature" fall back to 0.7 in the worker.
# NOTE(review): all groups run concurrently against the same model ID but with different
# model_load_params ({} vs {"n_gpu_layers": -1}); if the backend reloads a model when these
# differ, groups may trigger reloads of each other — confirm against the server.
QUERY_GROUPS = [
    {
        "group_id": "california_license_info_set1",
        "system_prompt": (
            "You are an expert AI assistant specializing in California driving licenses. "
            "Please answer the questions based on the following context. "
            "Context: The California Department of Motor Vehicles (DMV) states that the minimum age "
            "to apply for a learner's permit is 15 years and 6 months. Applicants for a REAL ID "
            "must provide proof of identity, their Social Security number (if eligible), and two "
            "proofs of California residency. The primary physical address for the CA DMV headquarters "
            "is in Sacramento, CA, USA. For a standard Class C license, vision screening is required."
        ),
        "questions_and_keys": [
            {"question": "What is the minimum age to get a learner's permit in California?", "key": "ca_permit_min_age"},
            {"question": "List the categories of documents needed for a REAL ID application in California.", "key": "ca_realid_docs"},
            {"question": "What is the city and state of the CA DMV headquarters address?", "key": "ca_dmv_hq_address"}
        ],
        "temperature": 0.01, # Override default temperature for this group
        "generation_params": {"max_tokens": 150}, # For GGUF: max_tokens; For HF: max_new_tokens
        "model_load_params": {"n_gpu_layers": -1} # Example for GGUF model
        # "model_load_params": {"use_bnb_4bit": True} # Example for Regular HF model if you want 4-bit
    },
    {
        "group_id": "new_york_license_info_set1",
        "system_prompt": (
            "You are an expert AI assistant for New York State driving licenses. "
            "Context: Standard New York State driver licenses (Class D) are typically valid for 8 years. "
            "Renewals can often be done online, by mail, or in person at a DMV office. A vision test "
            "is required for renewal, which can be done at the DMV or by an approved provider."
        ),
        "questions_and_keys": [
            {"question": "How long is a standard Class D driver's license valid in New York?", "key": "ny_license_validity"},
            {"question": "Is a vision test mandatory for renewing a NY driver's license?", "key": "ny_vision_test_renewal"}
        ],
        "temperature": 0.01,
        "generation_params": {"max_tokens": 100},
        # model_load_params can be omitted if defaults are fine or model is already loaded with desired settings
    },
    {
        "group_id": "california_license_info_set2_new_chat", # Simulates new chat on existing topic
        "system_prompt": (
            "You are an expert AI assistant specializing in California driving licenses. "
            "Please answer the questions based on the following context. "
            "Context: The California Driver Handbook outlines various traffic violations. A first-time "
            "DUI conviction can result in mandatory Ignition Interlock Device (IID) installation, "
            "license suspension, fines, and DUI program enrollment. The specific penalties can vary."
        ),
        "questions_and_keys": [
            {"question": "What are some potential penalties for a first-time DUI in California according to the handbook?", "key": "ca_dui_penalties_first"}
        ],
        "generation_params": {"max_tokens": 200},
        "model_load_params": {"n_gpu_layers": -1}
    }
]

# --- Pre-flight checks ---
# Cell 2 yields [] when the /models request fails, and the chosen ID is meant to come from
# that listing. Checking both up front avoids spawning threads that can only fail (a
# placeholder-only check never triggers once CHOSEN_MODEL_ID is set to a real-looking value).
available_model_ids = {model.get('id') for model in AVAILABLE_MODELS}

if not AVAILABLE_MODELS:
    print("CRITICAL: No models were fetched from the API.")
    print("Please ensure your Flask server is running, configured with models, and re-run Cell 2.")
elif CHOSEN_MODEL_ID not in available_model_ids:
    print(f"CRITICAL: CHOSEN_MODEL_ID '{CHOSEN_MODEL_ID}' is not among the models reported by the API.")
    print(f"Available model IDs: {sorted(str(model_id) for model_id in available_model_ids)}")
else:
    # --- Execution ---
    print(f"Using Model ID: {CHOSEN_MODEL_ID}")

    collected_results = []
    threads = []
    results_lock = threading.Lock()

    start_time = time.time()

    for group_config in QUERY_GROUPS:
        thread = threading.Thread(
            target=worker_process_query_group,
            args=(group_config, CHOSEN_MODEL_ID, API_BASE_URL, collected_results, results_lock),
            name=f"query-group-{group_config['group_id']}",  # readable name in tracebacks/logs
        )
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join() # Wait for all threads to complete

    end_time = time.time()
    print(f"\n--- All threads completed in {end_time - start_time:.2f} seconds ---")

    # Workers append in completion order, which varies run to run; restore the
    # QUERY_GROUPS order so the report (and any saved JSON) is stable.
    group_order = {group["group_id"]: index for index, group in enumerate(QUERY_GROUPS)}
    collected_results.sort(key=lambda item: group_order.get(item["group_id"], len(group_order)))

    # --- Display Results ---
    print("\n--- Collected Results ---")
    for item in collected_results:
        print(f"\nQuery Group ID: {item['group_id']}")
        print(f"Session ID: {item['session_id']}")
        print(f"Status: {item.get('status', 'N/A')}")
        if item.get("error"):
            print(f"Error: {item['error']}")
        if item.get("results"):
            for key, value in item["results"].items():
                print(f"  '{key}': '{value}'")
        else:
            print("  No results for this group.")

    # You can also save `collected_results` to a JSON file
    # with open("batch_qa_results.json", "w") as f:
    #     json.dump(collected_results, f, indent=2)
    # print("\nResults saved to batch_qa_results.json")