# In [None]:
# ============================================================================
# QUESTION → CONCEPT TAGGING FOR JUPYTER NOTEBOOK (CLEAN + ROBUST + FIXED)
# ============================================================================
# NOTE:
# - Structured Outputs JSON schema MUST have a ROOT object (not a root array).
#   So we return: { "concepts": [ ... ] }
# - Do NOT commit your API key to git. If your key was exposed, rotate it.

# ----------------------------------------------------------------------------
# STEP 1: Install dependencies (run once)
# ----------------------------------------------------------------------------
# !pip install -U openai numpy

# ----------------------------------------------------------------------------
# STEP 2: Imports + API key
# ----------------------------------------------------------------------------
import os
import json
from typing import Optional, List, Dict, Any

import numpy as np
from openai import OpenAI

# Put your key here locally (do NOT commit), or leave it empty and export
# OPENAI_API_KEY in your shell / notebook environment instead.
_LOCAL_OPENAI_API_KEY = ""

# Only override the environment when a local key was actually provided:
# unconditionally assigning "" would clobber a key that is already exported.
if _LOCAL_OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = _LOCAL_OPENAI_API_KEY

# Fail fast with a clear message instead of a later authentication error
# from an empty key.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set: fill in _LOCAL_OPENAI_API_KEY above or export it."
    )

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# ----------------------------------------------------------------------------
# STEP 3: Configuration
# ----------------------------------------------------------------------------
# Models used by the pipeline.
TAGGER_MODEL: str = "gpt-4o-mini"            # model that picks the concepts
EMBED_MODEL: str = "text-embedding-3-small"  # embedding model for the optional prefilter

# Concepts kept per question:
#   1   -> a single primary concept (stored as "primary_concept")
#   2/3 -> the top 2–3 concepts (stored as a "concepts" list)
NUM_CONCEPTS: int = 3
MIN_CONFIDENCE: float = 0.70   # tags scored below this by the model are discarded

# Large KGs: rank concepts by embedding similarity and show the LLM only the
# best CANDIDATE_POOL_SIZE of them instead of the whole list.
USE_EMBEDDING_PREFILTER: bool = True
CANDIDATE_POOL_SIZE: int = 60                     # candidates shown to the LLM per question
EMBED_BATCH_SIZE: int = 256                       # texts per embeddings API request
EMBED_CACHE_PATH: str = "concept_embeddings.npz"  # on-disk cache of concept embeddings

# ----------------------------------------------------------------------------
# STEP 4: Prompt (aligned with ROOT OBJECT schema)
# ----------------------------------------------------------------------------
# Instruction preamble for the tagging model. tag_question() appends the SQL
# question and the candidate concept list after this text. It also enforces the
# response shape with a strict JSON schema whose ROOT is an object.
PROMPT = """
You are an expert database teaching assistant mapping SQL questions to the MOST relevant concepts from the knowledge graph.

You will be given:
1) A SQL problem/question statement
2) A list of candidate concept names that already exist in the knowledge graph

Your job:
- Choose the MOST relevant concepts from the candidate list
- These should be the PRIMARY SQL concepts/skills needed to solve this question
- Copy each concept name EXACTLY as it appears (preserve case, spaces, underscores)
- Do NOT invent new concepts - only use concepts from the provided list
- Prefer HIGH-LEVEL concepts (JOIN, SUBQUERY, GROUP_BY, HAVING, etc.) over specific values

Return ONLY a valid JSON OBJECT with this exact structure:
{
  "concepts": [
    {
      "concept_name": "EXACT_CONCEPT_NAME_FROM_LIST",
      "confidence": 0.90,
      "explanation": "Brief reason why this concept is needed"
    }
  ]
}

Important:
- Order by importance (most important first)
- If the required format allows more than one concept, return 2–3 concepts if possible
- Return ONLY the JSON object, no other text
""".strip()

# ----------------------------------------------------------------------------
# STEP 5: KG concept extraction
# ----------------------------------------------------------------------------
def extract_concepts_from_kg(kg_jsonl_path: str) -> List[str]:
    """
    Extract all unique concept names from a KG JSONL file.

    Each non-blank line is expected to be a JSON object (an edge) whose "A"
    and/or "B" values are dicts carrying a "name" string. Names are stripped,
    and names shorter than 2 characters after stripping are ignored.

    Lines that are not valid JSON, or that decode to something other than an
    object, are skipped and counted (reported once at the end) instead of
    crashing the run. A leading UTF-8 BOM is tolerated.

    Args:
        kg_jsonl_path: Path to the UTF-8 encoded JSONL knowledge-graph file.

    Returns:
        Sorted list of unique concept names.
    """
    concepts = set()
    skipped = 0
    # utf-8-sig reads plain UTF-8 identically but also drops a leading BOM,
    # which would otherwise make the first line fail to parse.
    with open(kg_jsonl_path, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # A valid JSON line may still be a number/string/list/null. A
            # membership test or subscript on those raises TypeError, so
            # require an object.
            if not isinstance(obj, dict):
                skipped += 1
                continue

            for key in ("A", "B"):
                node = obj.get(key)
                if isinstance(node, dict):
                    name = node.get("name", "")
                    if isinstance(name, str) and len(name.strip()) > 1:
                        concepts.add(name.strip())

    concept_list = sorted(concepts)
    print(f"✓ Extracted {len(concept_list)} unique concepts from KG")
    if skipped:
        print(f"⚠️  Skipped {skipped} malformed / non-object line(s) in {kg_jsonl_path}")
    return concept_list

# ----------------------------------------------------------------------------
# STEP 6: Embedding helpers (optional prefilter)
# ----------------------------------------------------------------------------
def _l2_normalize(mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale every row of a 2-D array to (approximately) unit Euclidean length.

    `eps` is added to each row norm so that all-zero rows come back as zeros
    rather than triggering a division by zero.
    """
    safe_norms = np.add(np.linalg.norm(mat, axis=1, keepdims=True), eps)
    return np.divide(mat, safe_norms)

def embed_texts(texts: List[str], model: str = EMBED_MODEL, batch_size: int = EMBED_BATCH_SIZE) -> np.ndarray:
    """
    Embed a list of texts using the OpenAI embeddings endpoint.

    Args:
        texts: Strings to embed (the API documents that inputs must not be
            empty strings).
        model: Embedding model name.
        batch_size: Maximum number of texts per API request; values < 1 are
            treated as 1.

    Returns:
        float32 array of shape (len(texts), dim) whose row i embeds texts[i].
        For an empty `texts`, an empty (0, 0) float32 array is returned
        without calling the API, because the dimension is unknown then.
    """
    if not texts:
        return np.empty((0, 0), dtype=np.float32)

    step = max(1, batch_size)  # a zero/negative step would make range() raise
    all_vecs = []
    for start in range(0, len(texts), step):
        batch = texts[start:start + step]
        resp = client.embeddings.create(model=model, input=batch)
        # Each returned item carries its input position in `index`; sort on it
        # instead of relying on the list order of the response.
        ordered = sorted(resp.data, key=lambda d: d.index)
        all_vecs.append(np.array([d.embedding for d in ordered], dtype=np.float32))
    return np.vstack(all_vecs)

def load_or_build_concept_embeddings(
    concepts: List[str],
    cache_path: str = EMBED_CACHE_PATH,
    model: str = EMBED_MODEL,
) -> np.ndarray:
    """
    Build (or load cached) L2-normalized embeddings for concept strings.

    The cache is reused only if it was built for exactly the same concept list
    (same order) AND the same embedding model. Otherwise it is rebuilt and
    overwritten. Concept names are stored as a fixed-width unicode array, so
    the cache loads with allow_pickle=False and no pickled objects are
    deserialized from disk. Caches written in the older object-array format
    fail to load that way and are simply rebuilt.

    Args:
        concepts: Concept names; row i of the result embeds concepts[i].
        cache_path: Location of the .npz cache file.
        model: Embedding model name. It is part of the cache key because
            different models produce incompatible vectors (and dimensions).

    Returns:
        float32 array of shape (len(concepts), dim) with unit-length rows.
    """
    if os.path.exists(cache_path):
        try:
            # The context manager closes the file handle held by the NpzFile.
            # Each data[...] access reads its array fully into memory.
            with np.load(cache_path, allow_pickle=False) as data:
                cached_concepts = data["concepts"].tolist()
                cached_model = str(data["model"]) if "model" in data.files else None
                cached_embeds = data["embeddings"]
            if cached_concepts == concepts and cached_model == model:
                print(f"✓ Loaded cached concept embeddings: {cache_path}")
                return cached_embeds
            print("⚠️  Concept list or embedding model changed — rebuilding embeddings cache.")
        except Exception as err:  # best-effort cache: any read problem -> rebuild
            print(f"⚠️  Failed to load cache ({type(err).__name__}: {err}) — rebuilding embeddings cache.")

    print(f"🔎 Building embeddings for {len(concepts)} concepts (this can take a bit)...")
    embeds = _l2_normalize(embed_texts(concepts, model=model, batch_size=EMBED_BATCH_SIZE))

    np.savez_compressed(
        cache_path,
        concepts=np.array(concepts, dtype=str),
        model=np.array(model),
        embeddings=embeds,
    )
    print(f"✓ Saved embeddings cache to: {cache_path}")
    return embeds

def select_candidate_concepts(
    question_text: str,
    concepts: List[str],
    concept_embeds: Optional[np.ndarray],
    k: int = CANDIDATE_POOL_SIZE
) -> List[str]:
    """
    Return at most `k` concept names to show the LLM for one question.

    The embedding ranking is skipped in three cases: prefiltering is switched
    off, no embeddings were supplied, or the list already has at most `k`
    entries. The result is then the first `k` concepts in their given order
    (the list itself when it is short enough). Otherwise the question is
    embedded and the `k` concepts with the highest cosine similarity are
    returned, best match first.
    """
    prefilter_usable = (
        USE_EMBEDDING_PREFILTER
        and concept_embeds is not None
        and len(concepts) > k
    )
    if not prefilter_usable:
        return concepts if len(concepts) <= k else concepts[:k]

    question_vec = _l2_normalize(
        embed_texts([question_text], model=EMBED_MODEL, batch_size=1)
    )[0]
    # Rows of concept_embeds are unit length, so the dot product is cosine similarity.
    similarity = concept_embeds @ question_vec
    ranked = np.argsort(-similarity)
    return [concepts[idx] for idx in ranked[:k]]

# ----------------------------------------------------------------------------
# STEP 7: LLM tagging (Structured Outputs JSON schema) — FIXED ROOT OBJECT
# ----------------------------------------------------------------------------
def tag_question(
    question_text: str,
    candidate_concepts: List[str],
    num_concepts: int = NUM_CONCEPTS,
    min_confidence: float = MIN_CONFIDENCE
) -> List[Dict[str, Any]]:
    """
    Tag a question with concepts using the Responses API + a strict JSON schema.

    Args:
        question_text: The SQL problem statement to tag.
        candidate_concepts: Concept names the model may choose from (already
            shortlisted by the caller). Non-string and blank entries are ignored.
        num_concepts: Maximum number of concepts to return (values < 1 act as 1).
        min_confidence: Tags whose model-reported confidence is below this
            are dropped.

    Returns:
        A list of 0..num_concepts dicts with keys "concept_name", "confidence"
        and "explanation", in the order the model ranked them. Every
        "concept_name" is verbatim one of `candidate_concepts`, with no
        duplicates. Returns [] if there are no usable candidates or the model
        output cannot be parsed.
    """
    # Keep usable names only, deduplicated in first-seen order, so the item
    # counts below reflect how many *distinct* concepts the model can pick.
    safe_concepts = list(dict.fromkeys(
        c for c in candidate_concepts if isinstance(c, str) and c.strip()
    ))
    if not safe_concepts:
        return []
    allowed = set(safe_concepts)

    # Single-concept mode asks for exactly 1 item, otherwise at least 2. Never
    # ask for more than there are distinct candidates, or the model would be
    # forced to repeat or invent names just to satisfy minItems.
    max_items = min(max(1, num_concepts), len(safe_concepts))
    min_items = min(1 if num_concepts <= 1 else 2, max_items)

    # Root MUST be an object, so we use { "concepts": [ ... ] }
    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "concepts": {
                "type": "array",
                "minItems": min_items,
                "maxItems": max_items,
                "items": {
                    "type": "object",
                    "additionalProperties": False,
                    "properties": {
                        "concept_name": {"type": "string"},
                        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                        "explanation": {"type": "string"}
                    },
                    "required": ["concept_name", "confidence", "explanation"]
                }
            }
        },
        "required": ["concepts"]
    }

    # Tell the model the concrete count; PROMPT itself is count-agnostic.
    if min_items == max_items:
        count_hint = f"Return exactly {max_items} concept(s)."
    else:
        count_hint = f"Return between {min_items} and {max_items} concepts."

    full_prompt = (
        f"{PROMPT}\n\n"
        f"{count_hint}\n\n"
        f"SQL Question/Problem:\n{question_text}\n\n"
        "Available Concepts in Knowledge Graph:\n"
        + "\n".join(f"- {c}" for c in safe_concepts)
    )

    resp = client.responses.create(
        model=TAGGER_MODEL,
        input=[
            {"role": "system", "content": "You are a precise educational concept tagging assistant."},
            {"role": "user", "content": full_prompt},
        ],
        text={
            "format": {
                "type": "json_schema",
                "name": "concept_tagging",
                "schema": schema,
                "strict": True
            }
        }
    )

    raw = (resp.output_text or "").strip()
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Empty or invalid text, e.g. presumably when the model refuses.
        print("❌ Failed to parse model output as JSON.")
        return []

    tags = parsed.get("concepts", []) if isinstance(parsed, dict) else []
    if not isinstance(tags, list):
        return []

    out: List[Dict[str, Any]] = []
    seen = set()
    for t in tags:
        if not isinstance(t, dict):
            continue
        name = t.get("concept_name")
        # The schema fixes the shape, not membership. Drop names that are not
        # verbatim candidates (invented, re-cased, ...) and duplicates.
        if not isinstance(name, str) or name not in allowed or name in seen:
            continue
        try:
            confidence = float(t.get("confidence", 0))
        except (TypeError, ValueError):
            continue
        if confidence < min_confidence:
            continue
        seen.add(name)
        out.append(t)

    return out[:max(1, num_concepts)]

# ----------------------------------------------------------------------------
# STEP 8: Batch processing
# ----------------------------------------------------------------------------
def _one_line_preview(text: str, limit: int = 90) -> str:
    """Flatten newlines to spaces and cut to `limit` chars, adding '...' only when cut."""
    flat = text.replace("\n", " ")
    return flat if len(flat) <= limit else flat[:limit] + "..."


def process_questions_batch(
    questions_data: List[Dict[str, Any]],
    kg_jsonl_path: str,
    output_path: str,
    num_concepts: int = NUM_CONCEPTS,
    min_confidence: float = MIN_CONFIDENCE
) -> Dict[str, Any]:
    """
    Tag every question with KG concepts and write the results to a JSON file.

    Concept names are extracted from the KG. Concept embeddings are built or
    loaded only when prefiltering is enabled and the KG has more than
    CANDIDATE_POOL_SIZE concepts. Each question then gets a candidate
    shortlist and one tagging call. Questions with empty text are recorded
    untagged without calling the model.

    Args:
        questions_data: Dicts with "question_text" and optionally "question_id".
            A missing or empty id falls back to "Q<1-based position>".
        kg_jsonl_path: KG JSONL file passed to extract_concepts_from_kg.
        output_path: Where the JSON result is written (UTF-8, overwritten).
        num_concepts: Max concepts per question. If <= 1, each record gets a
            "primary_concept" (tag dict or None); otherwise a "concepts" list.
        min_confidence: Confidence threshold forwarded to tag_question.

    Returns:
        The dict written to `output_path`: totals, the settings used, and the
        per-question records.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("BATCH QUESTION → CONCEPT TAGGING")
    print(banner + "\n")

    print(f"📚 Extracting concepts from: {kg_jsonl_path}")
    concepts = extract_concepts_from_kg(kg_jsonl_path)

    # Embeddings are only needed when the list must actually be shortlisted.
    concept_embeds = None
    if USE_EMBEDDING_PREFILTER and len(concepts) > CANDIDATE_POOL_SIZE:
        concept_embeds = load_or_build_concept_embeddings(concepts, cache_path=EMBED_CACHE_PATH)

    print(f"\nFound {len(concepts)} total KG concepts")
    if USE_EMBEDDING_PREFILTER and concept_embeds is not None:
        print(f"Using embeddings prefilter → LLM sees top {CANDIDATE_POOL_SIZE} candidates per question")
    else:
        print("No embeddings prefilter → LLM sees full (or truncated) concept list")

    tagged_questions: List[Dict[str, Any]] = []
    total = len(questions_data)

    for i, q in enumerate(questions_data, 1):
        qid = q.get("question_id") or f"Q{i}"
        qtext = (q.get("question_text", "") or "").strip()

        print(f"\n[{i}/{total}] Processing {qid}...")
        # The preview is built outside the f-string: a backslash ('\n') inside
        # an f-string expression is a SyntaxError before Python 3.12.
        print(f"   Question: {_one_line_preview(qtext)}")

        if not qtext:
            tags = []
            print("   ⚠️  No question text → skipped tagging")
        else:
            candidates = select_candidate_concepts(
                question_text=qtext,
                concepts=concepts,
                concept_embeds=concept_embeds,
                k=CANDIDATE_POOL_SIZE
            )
            tags = tag_question(
                question_text=qtext,
                candidate_concepts=candidates,
                num_concepts=num_concepts,
                min_confidence=min_confidence
            )

            if tags:
                print("   ✅ Tagged with:")
                for t in tags:
                    print(f"      - {t['concept_name']} (conf: {t['confidence']:.2f})")
            else:
                print("   ⚠️  No concepts assigned (none passed validation / confidence threshold)")

        record: Dict[str, Any] = {
            "question_id": qid,
            "question_text": qtext,
        }
        # Record layout depends on the mode; the results view relies on this.
        if num_concepts <= 1:
            record["primary_concept"] = tags[0] if tags else None
        else:
            record["concepts"] = tags

        tagged_questions.append(record)

    output_data = {
        "total_questions": len(tagged_questions),
        "num_concepts_per_question": num_concepts,
        "min_confidence": min_confidence,
        "questions": tagged_questions
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print("\n" + banner)
    print("✅ SUMMARY")
    print(banner)
    tagged_count = sum(
        1 for r in tagged_questions
        if (r.get("primary_concept") is not None) or (len(r.get("concepts", [])) > 0)
    )
    print(f"Total questions processed: {len(tagged_questions)}")
    print(f"Successfully tagged: {tagged_count}/{len(tagged_questions)}")
    print(f"💾 Results saved to: {output_path}")
    print(banner + "\n")

    return output_data

# ----------------------------------------------------------------------------
# STEP 9: Define your questions
# ----------------------------------------------------------------------------
# Questions to tag. Each entry needs a "question_id" and the full
# "question_text" (triple-quoted strings make pasting multi-line SQL easy).
questions = [
    dict(
        question_id="Q1",
        question_text=""".""",
    ),
    dict(
        question_id="Q2",
        question_text="""""",
    ),
    dict(
        question_id="Q3",
        question_text="""""",
    ),
]

# ----------------------------------------------------------------------------
# STEP 10: Run tagging
# ----------------------------------------------------------------------------
# Tag every question against the KG concepts and persist the results as JSON.
result = process_questions_batch(
    questions,
    output_path="tagged_questions.json",
    kg_jsonl_path="MAIN_sql_qwen14b.jsonl",
    min_confidence=MIN_CONFIDENCE,
    num_concepts=NUM_CONCEPTS,
)

# ----------------------------------------------------------------------------
# STEP 11: View results
# ----------------------------------------------------------------------------
print("\n📊 DETAILED RESULTS:\n")
for q in result["questions"]:
    print("=" * 80)
    print(f"Question: {q['question_id']}")
    print("=" * 80)

    # Branch on the record's own layout, not the global NUM_CONCEPTS, which
    # may differ from the num_concepts the batch was actually run with.
    if "primary_concept" in q:
        pc = q["primary_concept"]
        if pc:
            print(f"✅ Concept: {pc['concept_name']}")
            print(f"   Confidence: {pc['confidence']:.2f}")
            print(f"   Reason: {pc['explanation']}")
        else:
            print("⚠️  No concept assigned")
    else:
        tag_list = q.get("concepts", [])
        if tag_list:
            print(f"✅ Tagged with {len(tag_list)} concept(s):")
            for i, c in enumerate(tag_list, 1):
                print(f"  {i}. {c['concept_name']} (confidence: {c['confidence']:.2f})")
                print(f"     → {c['explanation']}")
        else:
            print("⚠️  No concepts assigned")

    print()