## Install dependencies and import libraries

In [None]:
!pip install genlm-control

Collecting genlm-control
  Downloading genlm_control-0.2.12-py3-none-any.whl.metadata (8.8 kB)
Collecting genlm-grammar>=0.2.0 (from genlm-control)
  Downloading genlm_grammar-0.2.0-py3-none-any.whl.metadata (3.1 kB)
Collecting genlm-backend>=0.1.1 (from genlm-control)
  Downloading genlm_backend-0.1.7-py3-none-any.whl.metadata (6.1 kB)
Collecting llamppl (from genlm-control)
  Downloading llamppl-0.2.2-py3-none-any.whl.metadata (6.7 kB)
Collecting arsenal>=3.1.3 (from genlm-control)
  Downloading arsenal-3.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting json-stream (from genlm-control)
  Downloading json_stream-2.4.1-py3-none-any.whl.metadata (26 kB)
Collecting gprof2dot (from arsenal>=3.1.3->genlm-control)
  Downloading gprof2dot-2025.4.14-py3-none-any.whl.metadata (19 kB)
Collecting path.py (from arsenal>=3.1.3->genlm-control)
  Downloading path.py-12.5.0-py3-none-any.whl.metadata (1.3 kB)
Collecting bitsandbytes (from genlm-backend>=0.1.1

In [None]:
!pip install pyright

Collecting pyright
  Downloading pyright-1.1.407-py3-none-any.whl.metadata (6.6 kB)
Downloading pyright-1.1.407-py3-none-any.whl (6.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyright
Successfully installed pyright-1.1.407


In [None]:
pip install -U dspy

Collecting dspy
  Downloading dspy-3.0.4-py3-none-any.whl.metadata (8.4 kB)
Collecting backoff>=2.2 (from dspy)
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Collecting optuna>=3.4.0 (from dspy)
  Downloading optuna-4.6.0-py3-none-any.whl.metadata (17 kB)
Collecting magicattr>=0.1.6 (from dspy)
  Downloading magicattr-0.1.6-py2.py3-none-any.whl.metadata (3.2 kB)
Collecting litellm>=1.64.0 (from dspy)
  Downloading litellm-1.80.11-py3-none-any.whl.metadata (29 kB)
Collecting json-repair>=0.30.0 (from dspy)
  Downloading json_repair-0.54.3-py3-none-any.whl.metadata (12 kB)
Collecting asyncer==0.0.8 (from dspy)
  Downloading asyncer-0.0.8-py3-none-any.whl.metadata (6.7 kB)
Collecting gepa==0.0.17 (from gepa[dspy]==0.0.17->dspy)
  Downloading gepa-0.0.17-py3-none-any.whl.metadata (26 kB)
Collecting fastuuid>=0.13.0 (from litellm>=1.64.0->dspy)
  Downloading fastuuid-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Collecting grpcio<1.68.0,

In [None]:
pip install transformers accelerate

In [None]:
import ast
import asyncio
import json
import logging
import os
import re
from typing import List, Dict

import dspy
import torch
from genlm.control.potential.base import Potential
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Module-level logger shared by the LSP potential defined below.
logger = logging.getLogger(__name__)

## LSP potential with a diagnostic-messages helper

In [None]:
class LSP(Potential):
    """Potential over bytes that scores Python source with a language server.

    The server named by ``lsp_command`` is started lazily as a subprocess and
    spoken to over stdio using ``Content-Length``-framed JSON-RPC messages.
    A context scores ``-inf`` when it is not valid UTF-8 or when the server
    reports a diagnostic with severity 1 (Error); otherwise it scores ``0.0``.

    Args:
        lsp_command: Command line used to launch the language server,
            e.g. ``["pyright-langserver", "--stdio"]``.
        diagnostics_timeout: Seconds to wait for each server message while
            polling for diagnostics before giving up and returning ``[]``.
    """

    def __init__(
        self,
        lsp_command: list[str],
        diagnostics_timeout: float = 2.0,
    ):
        super().__init__(vocabulary=list(range(256)))
        self.lsp_command = lsp_command
        self.diagnostics_timeout = diagnostics_timeout
        self.lsp_process = None
        self.request_id = 1
        self.stderr_task = None
        # Counter used to give every linted snippet its own document URI.
        self._document_counter = 0
        # Body length of a frame whose header was consumed but whose body has
        # not been read yet (set when a read is cancelled by a timeout).
        self._pending_body_length = None
        logger.info(
            f"LSPPotential initialized with command: {self.lsp_command}"
        )

    # ------------------------------------------------------------------
    # Process management
    # ------------------------------------------------------------------

    def _server_is_running(self) -> bool:
        """Return True when the subprocess exists and has not exited."""
        return (
            self.lsp_process is not None
            and self.lsp_process.returncode is None
        )

    async def _ensure_server(self):
        """Start and initialize the server unless it is already running."""
        if not self._server_is_running():
            await self.initialize_server()

    async def _read_stderr(self):
        """Reads from the LSP server's stderr and logs it."""
        while (
            self.lsp_process
            and self.lsp_process.stderr
            and not self.lsp_process.stderr.at_eof()
        ):
            try:
                line = await self.lsp_process.stderr.readline()
                if line:
                    logger.error(f"[LSP stderr] {line.decode().strip()}")
            except Exception as e:
                logger.error(f"Error reading LSP stderr: {e}")
                break

    async def start_server(
        self,
    ):
        """Starts the LSP server as a subprocess.

        Raises:
            FileNotFoundError: the executable in ``lsp_command`` is missing.
            Exception: any other failure while spawning the process
                (logged and re-raised).
        """
        if self._server_is_running():
            logger.info("LSP server already running.")
            return

        try:
            logger.info(
                f"Starting LSP server with command: {' '.join(self.lsp_command)}"
            )
            self.lsp_process = await asyncio.create_subprocess_exec(
                *self.lsp_command,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            # A new process means a new stdout stream: drop any half-read frame.
            self._pending_body_length = None
            logger.info(f"LSP server started with PID: {self.lsp_process.pid}")
            self.stderr_task = asyncio.create_task(self._read_stderr())
        except FileNotFoundError:
            logger.error(
                f"LSP command not found: {self.lsp_command[0]}. Please ensure it's installed and in your PATH."
            )
            raise
        except Exception as e:
            logger.error(f"Failed to start LSP server: {e}")
            raise

    # ------------------------------------------------------------------
    # JSON-RPC transport
    # ------------------------------------------------------------------

    def _next_request_id(self) -> int:
        """Return the current request id and advance the counter."""
        request_id = self.request_id
        self.request_id += 1
        return request_id

    @staticmethod
    def _parse_content_length(header_bytes: bytes) -> int:
        """Extract ``Content-Length`` from a raw header block.

        The header block may contain more than one field (the protocol also
        defines ``Content-Type``), so every line is inspected.

        Raises:
            ValueError: no ``Content-Length`` field is present.
        """
        header_text = header_bytes.decode("ascii", errors="replace")
        for line in header_text.split("\r\n"):
            name, separator, value = line.partition(":")
            if separator and name.strip().lower() == "content-length":
                return int(value.strip())
        raise ValueError(
            f"LSP header block has no Content-Length field: {header_bytes!r}"
        )

    async def _send_message(
        self,
        message: dict,
    ):
        """Serialize ``message`` as JSON and write it as one framed message.

        Raises:
            ConnectionError: the server process or its stdin is unavailable.
        """
        if not self.lsp_process or not self.lsp_process.stdin:
            raise ConnectionError("LSP server is not running.")

        json_message = json.dumps(message)
        body = json_message.encode("utf-8")
        header = f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8")
        self.lsp_process.stdin.write(header + body)
        await self.lsp_process.stdin.drain()
        logger.debug(f"Sent: {json_message}")

    async def _read_message(
        self,
    ) -> dict:
        """Read one framed message from the server and return it decoded.

        If a previous call was cancelled after consuming a header but before
        the body arrived, the remembered body length is reused so the stream
        stays in sync.

        Raises:
            ConnectionError: the server process or its stdout is unavailable.
        """
        if not self.lsp_process or not self.lsp_process.stdout:
            raise ConnectionError("LSP server is not running.")
        stdout = self.lsp_process.stdout

        if self._pending_body_length is None:
            header_bytes = await stdout.readuntil(b"\r\n\r\n")
            self._pending_body_length = self._parse_content_length(header_bytes)

        body_bytes = await stdout.readexactly(self._pending_body_length)
        self._pending_body_length = None

        body = body_bytes.decode("utf-8")
        logger.debug(f"Received: {body}")
        return json.loads(body)

    async def _await_response(self, request_id: int) -> dict:
        """Read messages until the response to ``request_id`` arrives.

        Notifications and server-originated requests (anything carrying a
        ``method``) that arrive first are logged at debug level and skipped.
        """
        while True:
            message = await self._read_message()
            if message.get("id") == request_id and "method" not in message:
                return message
            logger.debug(f"Ignoring LSP message: {message}")

    async def initialize_server(
        self,
    ):
        """Start the server and perform the initialize/initialized handshake.

        Raises:
            RuntimeError: the server answered ``initialize`` with an error.
        """
        await self.start_server()
        request_id = self._next_request_id()
        await self._send_message(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "method": "initialize",
                "params": {
                    "processId": None,
                    "rootUri": None,
                    "capabilities": {},
                },
            }
        )
        response = await self._await_response(request_id)
        if "error" in response:
            raise RuntimeError(
                f"LSP server failed to initialize: {response['error']}"
            )
        logger.info("Server initialized successfully.")
        await self._send_message(
            {
                "jsonrpc": "2.0",
                "method": "initialized",
                "params": {},
            }
        )
        logger.info("Sent 'initialized' notification.")

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------

    async def _close_document(self, doc_uri: str):
        """Best-effort ``textDocument/didClose`` for ``doc_uri``."""
        try:
            await self._send_message(
                {
                    "jsonrpc": "2.0",
                    "method": "textDocument/didClose",
                    "params": {"textDocument": {"uri": doc_uri}},
                }
            )
        except OSError as e:
            # The server is already gone; there is nothing left to close.
            logger.debug(f"Could not close {doc_uri}: {e}")

    async def get_diagnostics(
        self,
        code: str,
    ) -> list:
        """Open ``code`` as a fresh document and return its raw diagnostics.

        Each call uses its own URI and closes the document afterwards: the
        protocol forbids a second ``didOpen`` for a URI that is still open,
        and a unique URI keeps late notifications for earlier snippets from
        being mistaken for this one's result.

        Returns:
            The ``diagnostics`` list from the first matching
            ``textDocument/publishDiagnostics`` notification, or ``[]`` if no
            message arrives within ``diagnostics_timeout`` seconds.
        """
        self._document_counter += 1
        doc_uri = f"file:///temp_file_{self._document_counter}.py"

        await self._send_message(
            {
                "jsonrpc": "2.0",
                "method": "textDocument/didOpen",
                "params": {
                    "textDocument": {
                        "uri": doc_uri,
                        "languageId": "python",
                        "version": 1,
                        "text": code,
                    }
                },
            }
        )

        try:
            while True:
                message = await asyncio.wait_for(
                    self._read_message(), timeout=self.diagnostics_timeout
                )
                if message.get("method") != "textDocument/publishDiagnostics":
                    logger.debug(f"Ignoring LSP message: {message}")
                    continue

                params = message["params"]
                if params["uri"] != doc_uri:
                    logger.debug(
                        f"Ignoring diagnostics for other document: {params['uri']}"
                    )
                    continue

                logger.info(
                    f"Received diagnostics for {doc_uri}: {params['diagnostics']}"
                )
                return params["diagnostics"]
        except asyncio.TimeoutError:
            logger.warning(f"Timed out waiting for diagnostics for {doc_uri}")
            return []
        finally:
            await self._close_document(doc_uri)

    async def _lint(self, code: str) -> float:
        """Helper method to initialize server if needed and get diagnostics.

        Returns ``-inf`` as soon as a severity-1 (Error) diagnostic is seen,
        ``0.0`` when none is reported.
        """
        await self._ensure_server()

        diagnostics = await self.get_diagnostics(code)
        for d in diagnostics:
            if d.get("severity") == 1:  # 1 is for Error
                logger.warning(f"LSP Error found: {d['message']}")
                return float("-inf")  # Penalize heavily

        return 0.0  # No errors found

    @staticmethod
    def _normalize_diagnostic(diagnostic: dict) -> dict:
        """Reduce one raw diagnostic to message, severity and range."""
        start = diagnostic["range"]["start"]
        end = diagnostic["range"]["end"]
        return {
            "message": diagnostic.get("message", "No error found"),
            "severity": diagnostic.get("severity", 3),
            "range": {
                "start": {
                    "line": start["line"],
                    "character": start["character"],
                },
                "end": {
                    "line": end["line"],
                    "character": end["character"],
                },
            },
        }

    async def diagnostic_messages(self, code: str):
        """
        Return normalized diagnostics with message, severity score,
        and code range.

        ``code`` may be ``str`` or ``bytes``; bytes are decoded as UTF-8 with
        undecodable sequences replaced. A diagnostic without a severity is
        reported with severity 3.
        """
        if isinstance(code, bytes):
            code = code.decode("utf-8", errors="replace")

        await self._ensure_server()
        raw_diagnostics = await self.get_diagnostics(code)
        return [self._normalize_diagnostic(d) for d in raw_diagnostics]

    # ------------------------------------------------------------------
    # Potential interface
    # ------------------------------------------------------------------

    async def prefix(
        self,
        context,
    ):
        """Score a partial program.

        Invalid UTF-8 scores ``-inf``. Text that does not parse yet scores
        ``0.0`` without consulting the server, so unfinished code is not
        penalized; text that parses is linted.
        """
        try:
            code = context.decode("utf-8")
        except UnicodeDecodeError:
            return float("-inf")
        try:
            ast.parse(code)
        except SyntaxError:  # IndentationError is a subclass of SyntaxError
            return 0.0
        return await self._lint(code)

    async def complete(
        self,
        context,
    ):
        """Score a finished program: always linted, no parse shortcut."""
        try:
            code = context.decode("utf-8")
        except UnicodeDecodeError:
            return float("-inf")
        return await self._lint(code)

    async def close(
        self,
    ):
        """Shut the server down with ``shutdown`` + ``exit`` and reap it.

        Does nothing when no server process is running.
        """
        if not self._server_is_running():
            return

        request_id = self._next_request_id()
        await self._send_message(
            {
                "jsonrpc": "2.0",
                "id": request_id,
                "method": "shutdown",
            }
        )
        await self._await_response(request_id)  # Wait for shutdown response

        await self._send_message(
            {
                "jsonrpc": "2.0",
                "method": "exit",
            }
        )

        # Only await the reader task when one exists; awaiting None raises.
        if self.stderr_task:
            self.stderr_task.cancel()
            try:
                await self.stderr_task
            except asyncio.CancelledError:
                pass
            self.stderr_task = None

        await self.lsp_process.wait()
        logger.info("LSP server shut down gracefully.")
        self.lsp_process = None
        self._pending_body_length = None

    def __repr__(
        self,
    ):
        return f"LSP(lsp_command={self.lsp_command})"

    def spawn(
        self,
    ):
        """Return a new, not-yet-started LSP with the same configuration."""
        return LSP(
            self.lsp_command,
            diagnostics_timeout=self.diagnostics_timeout,
        )


## Diagnostic message function test

In [None]:
lsp = LSP(["pyright-langserver", "--stdio"])

code = b"""def add(a: int, b: int) -> int:\n    return 'hello' """

normalized_diagnostics = await lsp.diagnostic_messages(code.decode("utf-8"))
print(normalized_diagnostics)

## DSPy

In [None]:
# The key is read from the environment so that it is never stored in the
# notebook and so that the cell works on a fresh kernel.
API_KEY = os.environ.get("GEMINI_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "Set the GEMINI_API_KEY environment variable before running this cell."
    )

lm = dspy.LM("gemini/gemini-2.5-flash", api_key=API_KEY)
dspy.configure(lm=lm)

In [None]:
class DSPyRepair(dspy.Signature):
    """Repair faulty Python code using the linter diagnostics reported for it."""

    faulty_code: str = dspy.InputField(default="", desc="Faulty code")
    diagnostics: str = dspy.InputField(default="", desc="Normalized diagnostics")
    answer: str = dspy.OutputField(default=None, desc="Corrected code")

    def repair_code(self, code: str, diagnostics: list[dict], N: int = 6) -> str:
        """
        Use DSPy Refine to repair code with automatic feedback.

        Args:
            code: faulty Python code.
            diagnostics: list of normalized diagnostics for ``code``.
            N: maximum number of refinement attempts.

        Returns:
            The ``answer`` field of the prediction selected by Refine.
        """
        diagnostics_json = json.dumps(diagnostics)

        # Task instructions live on the signature, so the two declared input
        # fields carry the data and DSPy's adapter owns the output format.
        instructions = (
            "You are a Python code repair assistant. "
            "Fix the problems described in `diagnostics` while preserving the "
            "original logic of `faulty_code`; change only what is needed. "
            "If nothing needs fixing, return the code unchanged. "
            "Put only the corrected Python source in `answer`, with no "
            "markdown fences and no commentary."
        )
        repair_signature = dspy.Signature(
            "faulty_code, diagnostics -> answer", instructions
        )

        # Reward function (LSP scoring could be integrated here later).
        def reward_fn(args, pred: dspy.Prediction) -> float:
            # 1.0 if the proposed code parses, else 0.0. The catch is broad on
            # purpose: a missing or non-string answer must score 0, not abort
            # the refinement loop.
            try:
                ast.parse(pred.answer)
            except Exception:
                return 0.0
            return 1.0

        refine_module = dspy.Refine(
            module=dspy.Predict(repair_signature),
            N=N,
            reward_fn=reward_fn,
            threshold=1.0,  # stops early if reward meets threshold
        )

        # Inputs are passed under the names the signature declares.
        result = refine_module(
            faulty_code=code,
            diagnostics=diagnostics_json,
        )

        # Return the best corrected code
        return result.answer


## Test Cases

In [None]:
# Each row is (name, source as bytes, expected score). In these rows the
# expected score is -inf for snippets named as containing an error and 0.0
# for every other snippet (correct, incomplete, empty or whitespace-only).
_TEST_CASE_FIELDS = ("name", "code", "expected_score")

_TEST_CASE_ROWS = [
    (
        "Correct, complete code",
        b"import os\n\ndef my_func():\n    return os.getcwd()",
        0.0,
    ),
    (
        "Code with an undefined variable error",
        b"def my_func():\n    print(undeclared_variable)",
        float("-inf"),
    ),
    (
        "Incomplete code (open function signature)",
        b"def my_func(",
        0.0,
    ),
    (
        "Incomplete code (if statement with no body)",
        b"def my_func(x):\n    if x > 10:",
        0.0,
    ),
    (
        "Syntactically valid code with a type error",
        b"def add(a: int, b: int) -> int:\n    return 'hello'",
        float("-inf"),
    ),
    (
        "Empty code",
        b"",
        0.0,
    ),
    (
        "Code with only whitespace",
        b"    \n  ",
        0.0,
    ),
    (
        "Code importing a non-existent module",
        b"import a_module_that_does_not_exist",
        float("-inf"),
    ),
]

TEST_CASES = [dict(zip(_TEST_CASE_FIELDS, row)) for row in _TEST_CASE_ROWS]

## DSPy test

In [None]:
results = []
lsp = LSP(["pyright-langserver", "--stdio"])

for test in TEST_CASES:
  code_bytes = test["code"]
  code_str = code_bytes.decode() if isinstance(code_bytes, bytes) else code_bytes

  diagnostics = await lsp.diagnostic_messages(code_str)

  dspy_repair = DSPyRepair()
  corrected = dspy_repair.repair_code(code=code_str, diagnostics=diagnostics)

  results.append({
            "name": test["name"],
            "original_code": code_str,
            "dspy_code": corrected,
            "diagnostics": diagnostics
        })

  print(f"Test '{test['name']}' done.")

with open("dspy_test_results.json", "w") as f:
    json.dump(results, f, indent=4)

## Evaluation

In [None]:
class Evaluator:
    """Scores repaired code with linter diagnostics and a local LLM judge.

    Args:
        lsp: object exposing ``async diagnostic_messages(code)``; used to lint
            code.
        model_name: Hugging Face model id of the judge model. The tokenizer
            and model are loaded in the constructor.
    """

    def __init__(self, lsp, model_name: str = "Qwen/Qwen3-14B"):
        self.lsp = lsp
        # Load the judge model
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto"
        )

    def deterministic_score(self, diagnostics: List[Dict]) -> float:
        """
        Compute a deterministic score based on severity:
        - Error (1) = -10
        - Warning (2) = -5
        - Info (3) = -1
        - Hint (4) = 0
        Higher score is better (less severe issues).

        A diagnostic without a severity is treated as Info (3); an unknown
        severity value costs -1. An empty list scores 0.0.
        """
        severity_penalties = {1: -10, 2: -5, 3: -1, 4: 0}
        score = 0.0
        for d in diagnostics:
            score += severity_penalties.get(d.get("severity", 3), -1)
        return score

    async def get_normalized_diagnostics(self, code: bytes) -> List[Dict]:
        """Lint ``code`` and return the normalized diagnostics.

        ``code`` may be ``bytes`` (decoded as UTF-8, undecodable sequences
        replaced) or an already-decoded ``str``.
        """
        if isinstance(code, bytes):
            code = code.decode("utf-8", errors="replace")
        return await self.lsp.diagnostic_messages(code)

    def llm_judge(
        self,
        old_code: str,
        new_code: str,
        max_score: int = 5,
        max_new_tokens: int = 256,
        enable_thinking: bool = False,
    ) -> str:
        """
        Evaluate the corrected code using the judge model.
        Scores from 0 to max_score based on:
        - Semantic correctness (does new_code fix the issues in old_code?)
        - Preservation of original logic (no unnecessary changes)
        - Readability and code quality

        Args:
            old_code: the original (faulty) code.
            new_code: the corrected code to be judged.
            max_score: upper end of the scoring scale named in the prompt.
            max_new_tokens: generation budget for the judge's reply.
            enable_thinking: forwarded to the chat template. Off by default,
                because thinking output is emitted before the verdict and
                would likely use up the default token budget before any
                scores are produced; raise ``max_new_tokens`` substantially
                when turning it on.

        Returns:
            The judge's reply as text (per-criterion scores plus one line of
            reasoning, as requested by the prompt). It is not parsed into a
            number.
        """
        prompt = f"""
                      You are an expert Python code evaluator.

                      Original code:
                      {old_code}

                      Corrected code:
                      {new_code}

                      Score the corrected code from 0 (worst) to {max_score} (best) based on:
                      1. Semantic correctness: does it fix the errors?
                      2. Preservation of original logic: does it avoid unnecessary changes?
                      3. Readability and code quality: is it clear and Pythonic?

                      Return the numeric score with a short reasoning in one sentence. The example is given below
                      Example:
                      Semantic correctness : 4
                      Preservation of original logic : 5
                      Readability and code quality : 5
                      and one line of reasoning
                  """

        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens
        )
        # Skip the prompt tokens so only the newly generated reply is decoded.
        prompt_length = len(model_inputs.input_ids[0])
        output_ids = generated_ids[0][prompt_length:].tolist()
        return self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()


In [None]:
# Constructing the Evaluator loads the judge tokenizer and model, so this cell
# is expensive. It lints through the `lsp` object created in the DSPy test cell.
evaluator = Evaluator(lsp)

## Runner

In [None]:
async def evaluate_test_cases(
    input_json_path: str,
    output_json_path: str,
    evaluator
):
    """Score every repaired test case and write the enriched records to disk.

    For each record in ``input_json_path`` this adds:
    - ``deterministic_score_before``: score of the stored ``diagnostics``;
    - ``normalized_diagnostics`` / ``deterministic_score_after``: fresh
      diagnostics for ``dspy_code`` and their score;
    - ``llm_judge_score``: the judge's reply comparing ``original_code`` with
      ``dspy_code``.

    Args:
        input_json_path: JSON file holding a list of test-case records.
        output_json_path: where the updated list is written.
        evaluator: object providing ``deterministic_score``,
            ``get_normalized_diagnostics`` (async) and ``llm_judge``.
    """
    # Load test cases
    with open(input_json_path, "r", encoding="utf-8") as f:
        test_cases: List[Dict] = json.load(f)

    for case in test_cases:
        print(f"Evaluating: {case.get('name', 'unnamed test')}")

        # `or` also covers keys stored as JSON null (e.g. a repair that
        # produced no answer), which `.get(key, default)` would pass through.
        diagnostics_before = case.get("diagnostics") or []
        dspy_code = case.get("dspy_code") or ""
        original_code = case.get("original_code") or ""

        # ---------------------------
        # Step 1: deterministic score BEFORE
        # ---------------------------
        case["deterministic_score_before"] = evaluator.deterministic_score(
            diagnostics_before
        )

        # ---------------------------
        # Step 2: deterministic score AFTER (normalized diagnostics)
        # ---------------------------
        normalized_diagnostics = await evaluator.get_normalized_diagnostics(
            dspy_code.encode("utf-8")
        )
        case["normalized_diagnostics"] = normalized_diagnostics
        case["deterministic_score_after"] = evaluator.deterministic_score(
            normalized_diagnostics
        )

        # ---------------------------
        # Step 3: LLM as Judge
        # ---------------------------
        # The judge's parameter is named `old_code`.
        case["llm_judge_score"] = evaluator.llm_judge(
            old_code=original_code,
            new_code=dspy_code
        )

    # Save updated JSON
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(test_cases, f, indent=2)

    print(f"Evaluation complete. Results saved to {output_json_path}")


In [None]:
lsp = LSP(["pyright-langserver", "--stdio"])
evaluator = Evaluator(lsp)

await evaluate_test_cases(
    input_json_path="dspy_test_results.json",
    output_json_path="test_cases_with_scores.json",
    evaluator=evaluator
)