In [None]:
import os
import re
import ast
import sys
import uuid
import json
import textwrap
import subprocess
from pathlib import Path
from dataclasses import dataclass
from typing import List, Protocol, Tuple, Dict, Optional

from dotenv import load_dotenv
from openai import OpenAI
from openai import BadRequestError as _OpenAIBadRequest
import gradio as gr

# Load variables from .env; override=True lets .env values replace ones already in the environment.
load_dotenv(override=True)

# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
GROQ_BASE   = "https://api.groq.com/openai/v1"

# --- API Keys (add these in your .env) ---
openai_api_key = os.getenv("OPENAI_API_KEY")   # OpenAI
google_api_key = os.getenv("GOOGLE_API_KEY")   # Gemini
groq_api_key   = os.getenv("GROQ_API_KEY")     # Groq

# --- Clients ---
# Every client is built only when its key is present; a missing key yields None so the
# provider is simply left out of the registry. Constructing OpenAI() without a key raises,
# which would abort this cell even when only Gemini or Groq is configured.
openai_client = OpenAI(api_key=openai_api_key)                       if openai_api_key else None
gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None
groq_client   = OpenAI(api_key=groq_api_key,   base_url=GROQ_BASE)   if groq_api_key   else None

# --- Model registry: label -> { client, model } ---
MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}

def _register(label: str, client: Optional[OpenAI], model_id: str):
    """Store ``label`` in MODEL_REGISTRY; providers without a client are skipped."""
    if client is None:
        # Provider not configured: leave the registry untouched.
        return
    entry = {"client": client, "model": model_id}
    MODEL_REGISTRY[label] = entry

# OpenAI
_register("OpenAI • GPT-5",        openai_client, "gpt-5")
_register("OpenAI • GPT-5 Nano",   openai_client, "gpt-5-nano")
_register("OpenAI • GPT-4o-mini",  openai_client, "gpt-4o-mini")

# Gemini (Google)
_register("Gemini • 2.5 Pro",      gemini_client, "gemini-2.5-pro")
_register("Gemini • 2.5 Flash",    gemini_client, "gemini-2.5-flash")

# Groq
# Model ids must match Groq's catalogue exactly: the GPT-OSS models live under the
# "openai/" namespace, and the 70B Llama entry is the 3.3 release (the 3.1 70B id was retired).
_register("Groq • Llama 3.1 8B",   groq_client,   "llama-3.1-8b-instant")
_register("Groq • Llama 3.3 70B",  groq_client,   "llama-3.3-70b-versatile")
_register("Groq • GPT-OSS 20B",    groq_client,   "openai/gpt-oss-20b")
_register("Groq • GPT-OSS 120B",   groq_client,   "openai/gpt-oss-120b")

# First registered label (insertion order), or None when no provider is configured.
DEFAULT_MODEL = next(iter(MODEL_REGISTRY), None)

print(f"Providers configured → OpenAI:{bool(openai_api_key)}  Gemini:{bool(google_api_key)}  Groq:{bool(groq_api_key)}")
print("Models available     →", ", ".join(MODEL_REGISTRY.keys()) or "None (add API keys in .env)")


In [None]:
class CompletionClient(Protocol):
    """Structural type for LLM clients: anything exposing ``complete`` keyed by a registry label."""

    def complete(self, *, model_label: str, system: str, user: str) -> str:
        """Return the reply for ``system``/``user`` from the model registered as ``model_label``."""
        ...


def _extract_code_or_text(s: str) -> str:
    """Prefer fenced python if present; otherwise return raw text."""
    m = re.search(r"```(?:python)?\s*(.*?)```", s, flags=re.S | re.I)
    return m.group(1).strip() if m else s.strip()


class MultiModelChatClient:
    """Routes requests to the right provider/client based on model label.

    ``registry`` maps a label to ``{"client": <OpenAI-compatible client>, "model": <model id>}``.
    """

    def __init__(self, registry: Dict[str, Dict[str, object]]):
        self._registry = registry

    def _call(self, *, client: "OpenAI", model_id: str, system: str, user: str) -> str:
        """Send one system+user exchange and return the reply as code/text.

        Only ``model`` and ``messages`` are sent (no temperature or other sampling
        parameters), so providers that reject optional parameters accept the request.
        Returns "" when the response carries no choices or no message content.
        """
        resp = client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "system", "content": system},
                {"role": "user",   "content": user},
            ],
        )
        choices = getattr(resp, "choices", None) or []
        if not choices:
            # Some providers answer with an empty choice list (e.g. filtered output).
            return ""
        text = (choices[0].message.content or "").strip()
        if not text:
            return ""
        return _extract_code_or_text(text)

    def complete(self, *, model_label: str, system: str, user: str) -> str:
        """Look up ``model_label`` in the registry and run the completion.

        Raises:
            ValueError: if ``model_label`` is not a registry key.

        Provider errors propagate unchanged. A rejected request is not retried:
        the payload has no optional parameters to drop, so resending the identical
        request could only fail the same way.
        """
        if model_label not in self._registry:
            raise ValueError(f"Unknown model label: {model_label}")
        info = self._registry[model_label]
        return self._call(
            client=info["client"],
            model_id=str(info["model"]),
            system=system,
            user=user,
        )


In [None]:
def _extract_code_or_text(s: str) -> str:
    """Prefer fenced python if present; otherwise return raw text."""
    m = re.search(r"```(?:python)?\s*(.*?)```", s, flags=re.S | re.I)
    return m.group(1).strip() if m else s.strip()


def _ensure_header_and_import(code: str, module_name: str) -> str:
    """Make sure tests import the module and pytest; keep output minimal."""
    code = code.strip()
    needs_pytest = "import pytest" not in code
    needs_import = f"import {module_name}" not in code and f"import {module_name} as mod" not in code

    header_lines = []
    if needs_pytest:
        header_lines.append("import pytest")
    if needs_import:
        header_lines.append(f"import {module_name} as mod")

    if header_lines:
        code = "\n".join(header_lines) + "\n\n" + code
    return code


In [None]:
@dataclass(frozen=True)
class SymbolInfo:
    """One public symbol found in a module (immutable)."""
    kind: str       # "function" | "class" | "method"
    name: str       # bare name, or "Class.method" for methods
    signature: str  # rendered "def ...:" line; for classes just the class name
    lineno: int     # line number reported by ast for the definition


class PublicAPIExtractor:
    """Extract a small 'public API' summary from a Python module.

    Public means: top-level functions and classes whose name does not start with
    an underscore, plus the non-underscore methods defined directly in those
    classes. Both ``def`` and ``async def`` are included.
    """

    # Node types treated as callables; async defs are a separate ast class.
    _FUNC_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)

    def extract(self, source: str) -> List[SymbolInfo]:
        """Parse ``source`` and return its public symbols.

        The result is sorted by (kind, lower-cased name, line number).

        Raises:
            SyntaxError: propagated from ``ast.parse`` when ``source`` is not valid Python.
        """
        tree = ast.parse(source)
        out: List[SymbolInfo] = []

        for node in tree.body:
            if isinstance(node, self._FUNC_NODES):
                if self._is_public(node.name):
                    out.append(SymbolInfo("function", node.name, self._sig(node), node.lineno))
            elif isinstance(node, ast.ClassDef) and self._is_public(node.name):
                out.append(SymbolInfo("class", node.name, node.name, node.lineno))
                out.extend(
                    SymbolInfo("method", f"{node.name}.{sub.name}", self._sig(sub), sub.lineno)
                    for sub in node.body
                    if isinstance(sub, self._FUNC_NODES) and self._is_public(sub.name)
                )
        return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))

    @staticmethod
    def _is_public(name: str) -> bool:
        """A name is public when it has no leading underscore."""
        return not name.startswith("_")

    @staticmethod
    def _render(node: ast.AST, fallback: str) -> str:
        """Best-effort ``ast.unparse``; returns ``fallback`` if rendering fails."""
        try:
            return ast.unparse(node)
        except Exception:
            # Rendering is cosmetic (it only feeds the prompt summary), so a failure
            # here must not abort extraction.
            return fallback

    def _sig(self, fn: ast.AST) -> str:
        """Render the ``def``/``async def`` header line of ``fn``.

        The parameter list comes from ``ast.unparse`` so positional-only ``/``,
        the bare ``*`` before keyword-only parameters, defaults and annotations
        are all reproduced as written in the source.
        """
        keyword = "async def" if isinstance(fn, ast.AsyncFunctionDef) else "def"
        # Fallback if unparse fails: bare parameter names in declaration order.
        plain_names = ", ".join(
            a.arg for a in (*fn.args.posonlyargs, *fn.args.args, *fn.args.kwonlyargs)
        )
        params = self._render(fn.args, plain_names)
        ret = ""
        if fn.returns is not None:
            rendered = self._render(fn.returns, "")
            ret = f" -> {rendered}" if rendered else ""
        return f"{keyword} {fn.name}({params}){ret}:"


In [None]:
class PromptBuilder:
    """Builds concise, deterministic prompts for pytest generation."""

    # System message: fixes the output contract (plain pytest code, no prose or fences).
    SYSTEM = (
        "You are a senior Python engineer. Produce a single, self-contained pytest file.\n"
        "Rules:\n"
        "- Output only Python test code (no prose, no markdown fences).\n"
        "- Use plain pytest tests (functions), no classes unless unavoidable.\n"
        "- Deterministic: avoid network/IO; seed randomness if used.\n"
        "- Import the target module by module name.\n"
        "- Create a minimal test covering every public function and method.\n"
        "- Prefer straightforward, fast assertions over exhaustive checks.\n"
    )

    def build_user(self, *, module_name: str, source: str, symbols: List["SymbolInfo"]) -> str:
        """Assemble the user message: API summary, constraints, then the full source.

        Args:
            module_name: Name used in the import instruction and the source banner.
            source: Module source, embedded verbatim.
            symbols: Items exposing ``kind`` and ``signature``; one summary line each.

        Returns:
            The prompt text with every template line starting at column 0.
        """
        summary = "\n".join(f"- {s.kind:<6}  {s.signature}" for s in symbols) or "- (no public symbols)"
        # Built line by line rather than with textwrap.dedent on an f-string: once the
        # multi-line summary/source are interpolated their lines sit at column 0, so
        # dedent finds no common margin and leaves the template lines indented.
        lines = [
            f"Create pytest tests for module `{module_name}`.",
            "",
            "Public API Summary:",
            summary,
            "",
            "Constraints:",
            f"- Import as: `import {module_name} as mod`",
            "- Keep tests tiny, fast, and deterministic.",
            "",
            "Full module source (for reference):",
            f"# --- BEGIN SOURCE {module_name}.py ---",
            source,
            "# --- END SOURCE ---",
        ]
        return "\n".join(lines).strip()


In [None]:
class TestGenerator:
    """Pipeline: public-API extraction -> prompt -> model completion -> import header fix-up."""

    def __init__(self, llm: CompletionClient):
        self._prompts = PromptBuilder()
        self._extractor = PublicAPIExtractor()
        self._llm = llm

    def generate_tests(self, model_label: str, module_name: str, source: str) -> str:
        """Return pytest source for ``source``, produced by the model behind ``model_label``."""
        public_symbols = self._extractor.extract(source)
        prompt = self._prompts.build_user(
            module_name=module_name,
            source=source,
            symbols=public_symbols,
        )
        completion = self._llm.complete(
            model_label=model_label,
            system=self._prompts.SYSTEM,
            user=prompt,
        )
        # Guarantee the pytest/module imports regardless of what the model emitted.
        return _ensure_header_and_import(completion, module_name)


In [None]:
# Wire the service to the multi-provider client built from the registry.
# (The names previously used here — OpenAIChatClient, OPENAI_CLIENT, TESTGEN_MODEL — are
# not defined anywhere in this notebook, and TestGenerator takes the client only.)
LLM = MultiModelChatClient(MODEL_REGISTRY)
SERVICE = TestGenerator(LLM)

def build_module_name_from_path(path: str) -> str:
    """Return the final path component of ``path`` without its last suffix."""
    return Path(path).stem

def generate_from_code(module_name: str, code: str, save: bool, out_dir: str,
                       model_label: Optional[str] = None) -> tuple[str, str]:
    """Generate pytest code for pasted module source and optionally save it.

    Args:
        module_name: Import name of the module under test (dotted names allowed).
        code: Source of the module.
        save: When true, write the tests to ``<out_dir>/test_<module_name>.py``.
        out_dir: Target folder; falls back to ``tests`` when empty.
        model_label: Registry label of the model to use; defaults to DEFAULT_MODEL.

    Returns:
        ``(tests_code, status)``. On invalid input ``tests_code`` is "" and
        ``status`` is a message starting with "❌".
    """
    name = (module_name or "").strip()
    if not name:
        return "", "❌ Please provide a module name."
    # The name ends up in an import statement and in a file path, so only accept
    # identifiers (optionally dotted); this also rules out path separators.
    if not all(part.isidentifier() for part in name.split(".")):
        return "", f"❌ `{name}` is not a valid Python module name."
    if not (code or "").strip():
        return "", "❌ Please paste some Python code."

    label = model_label or DEFAULT_MODEL
    if not label:
        return "", "❌ No model is configured (add API keys in .env)."

    try:
        tests_code = SERVICE.generate_tests(model_label=label, module_name=name, source=code)
    except SyntaxError as exc:
        # The service parses the module first; report bad input instead of crashing the handler.
        return "", f"❌ Could not parse the module code: {exc}"

    saved = ""
    if save:
        out = Path(out_dir or "tests")
        out.mkdir(parents=True, exist_ok=True)
        out_path = out / f"test_{name}.py"
        out_path.write_text(tests_code, encoding="utf-8")
        saved = f"✅ Saved to {out_path}"
    return tests_code, saved


def generate_from_file(file_obj, save: bool, out_dir: str, filename: Optional[str] = None) -> tuple[str, str]:
    """Generate tests from the raw bytes of an uploaded ``.py`` file.

    Args:
        file_obj: File content as ``bytes`` (despite the name), or None when nothing was uploaded.
        save: Forwarded to ``generate_from_code``.
        out_dir: Forwarded to ``generate_from_code``.
        filename: Optional original file name/path; its stem becomes the module
            name. Without it the module is called ``uploaded_module``.

    Returns:
        ``(tests_code, status)`` as returned by ``generate_from_code``, or
        ``("", "❌ …")`` when no file was given or it is not UTF-8 text.
    """
    if file_obj is None:
        return "", "❌ Please upload a .py file."
    try:
        # utf-8-sig also accepts plain UTF-8 and drops a leading BOM, which would
        # otherwise reach the parser as a stray character.
        code = file_obj.decode("utf-8-sig")
    except UnicodeDecodeError:
        return "", "❌ The uploaded file is not valid UTF-8 text."
    module_name = build_module_name_from_path(filename or "uploaded_module.py")
    return generate_from_code(module_name, code, save, out_dir)


In [None]:
# Two-tab Gradio front end: paste module source, or upload a .py file.
with gr.Blocks(title="Simple PyTest Generator") as ui:
    gr.Markdown("## 🧪 Simple PyTest Generator (Week 4 • Community Contribution)\n"
                "Generate **minimal, deterministic** pytest tests from a Python module using a Frontier model.")

    with gr.Tab("Paste Code"):
        with gr.Row():
            module_name = gr.Textbox(label="Module name (used in `import <name> as mod`)", value="mymodule")
        code_in = gr.Code(label="Python module code", language="python", lines=22)
        with gr.Row():
            save_cb = gr.Checkbox(label="Save to /tests", value=True)
            out_dir = gr.Textbox(label="Output folder", value="tests")
        gen_btn = gr.Button("Generate tests", variant="primary")
        with gr.Row():
            tests_out = gr.Code(label="Generated tests (pytest)", language="python", lines=20)
        status = gr.Markdown()

        def _on_gen(name, code, save, outdir):
            """Click handler: returns (generated tests, status message)."""
            tests, msg = generate_from_code(name, code, save, outdir)
            return tests, (msg or "✅ Done")

        gen_btn.click(_on_gen, inputs=[module_name, code_in, save_cb, out_dir], outputs=[tests_out, status])

    with gr.Tab("Upload .py"):
        # type="binary" makes Gradio pass the uploaded content to the handler as bytes
        # (or None when nothing is uploaded). With the default type the handler receives
        # a file path, which has no .read().
        upload = gr.File(file_types=[".py"], type="binary", label="Upload a Python module")
        with gr.Row():
            save_cb2 = gr.Checkbox(label="Save to /tests", value=True)
            out_dir2 = gr.Textbox(label="Output folder", value="tests")
        gen_btn2 = gr.Button("Generate tests from file")
        tests_out2 = gr.Code(label="Generated tests (pytest)", language="python", lines=20)
        status2 = gr.Markdown()

        def _on_gen_file(f, save, outdir):
            """Click handler: ``f`` is the uploaded file's bytes, or None."""
            tests, msg = generate_from_file(f, save, outdir)
            return tests, (msg or "✅ Done")

        gen_btn2.click(_on_gen_file, inputs=[upload, save_cb2, out_dir2], outputs=[tests_out2, status2])

ui.launch(inbrowser=True)
