## Original dataset preparation and processing with PMD

In [None]:
from __future__ import annotations
import io
import json
import os
import re
import shutil
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Set
import javalang
from unidiff import PatchSet
import xml.etree.ElementTree as ET

In [None]:
# Root of the CodeReviewer project; a relative path, so it is resolved
# against the current working directory.
BASE_DIR = Path(r"CodeReviewer")

RAW_DIR      = BASE_DIR / "Dataset" / "OriginalDataset"
TEMP_DIR     = BASE_DIR / "DataPreparation" / "Temp"
TRAIN_RAW    = RAW_DIR / "cls-train-chunk-0.jsonl"
TEST_RAW     = RAW_DIR / "cls-test.jsonl"
VALID_RAW    = RAW_DIR / "cls-valid.jsonl"

MAX_CODE_LEN = 15_000    # skip entries whose patched code exceeds this many characters
# PMD ships a Windows launcher (pmd.bat) and a POSIX shell launcher (pmd);
# choose the one for this platform. It is invoked by bare name, so the
# launcher must be on PATH.
PMD_EXE      = "pmd.bat" if os.name == "nt" else "pmd"
PMD_RULESETS = ",".join([
    "category/java/errorprone.xml",
    "category/java/multithreading.xml",
    "category/java/performance.xml",
    "category/java/codestyle.xml",
    "category/java/bestpractices.xml",
    "category/java/design.xml",
    "category/java/security.xml",
])

### Helpers: I/O utilities

In [None]:
def ensure_dir(p: Path | str) -> Path:
    """Make sure directory *p* exists (creating missing parents) and return it as a Path."""
    directory = Path(p)
    directory.mkdir(exist_ok=True, parents=True)
    return directory

def write_jsonl(out_path: Path, record: dict) -> None:
    """Append *record* to *out_path* as a single UTF-8 JSON line (file created if absent)."""
    serialized = json.dumps(record, ensure_ascii=False)
    with out_path.open(mode="a", encoding="utf-8") as handle:
        handle.write(f"{serialized}\n")

### Stage A – Filter valid Java entries

In [None]:
def filter_java_entries(
    input_path: Path,
    output_path: Path,
) -> None:
    """Copy the usable Java entries of a JSONL file to *output_path*.

    A line is kept when it parses as JSON, has a non-empty ``oldf`` and
    either is tagged ``lang == "java"`` or its ``oldf`` parses with
    ``javalang``. Blank and malformed lines are skipped. *output_path* is
    overwritten, and its parent directory is created if missing.
    """
    # First pass only counts non-blank lines so progress can be shown as
    # checked/total; the `with` closes this handle before the main pass.
    with input_path.open(encoding="utf-8") as counter:
        total = sum(1 for line in counter if line.strip())
    output_path.parent.mkdir(parents=True, exist_ok=True)
    checked = valid = 0

    with input_path.open(encoding="utf-8") as inp, \
         output_path.open("w", encoding="utf-8") as out:

        for raw in inp:
            if not raw.strip():
                continue
            checked += 1
            if checked % 500 == 0 or checked == total:
                print(f"Filter progress: {checked}/{total}", end="\r")

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue

            lang, code = data.get("lang"), data.get("oldf")
            if not code:
                continue

            if lang != "java":
                # NOTE(review): entries accepted here only because their code
                # parses keep their original `lang`; add_patched_code() later
                # drops every entry whose lang is not "java". Confirm intent.
                try:
                    javalang.parse.parse(code)
                except (javalang.parser.JavaSyntaxError,
                        javalang.tokenizer.LexerError,
                        IndexError,
                        StopIteration):
                    continue

            out.write(json.dumps(data, ensure_ascii=False) + "\n")
            valid += 1

    print(f"\n✅ Saved {valid} valid Java entries → {output_path.name}")

### Stage B – Apply unified-diff patches

In [None]:
def apply_patch(original: str, patch_text: str) -> str:
    """Apply the unified diff *patch_text* to *original* and return the result.

    A patch that does not start with ``---`` gets a synthetic
    ``--- a/file.java`` / ``+++ b/file.java`` header so ``unidiff`` can parse
    it. Hunks are applied bottom-up, so splicing a later hunk never shifts
    the line numbers of an earlier one. The result is the patched lines
    joined by newlines, without a trailing newline.
    """
    if not patch_text.startswith("---"):
        patch_text = "--- a/file.java\n+++ b/file.java\n" + patch_text

    patched = original.splitlines()

    for pf in PatchSet(io.StringIO(patch_text)):
        for hunk in reversed(pf):
            # Hunk start lines are 1-based. For an empty source range
            # ("-N,0") N is the line AFTER which the new lines go, so the
            # 0-based insertion index is N, not N - 1 (this also stops
            # "-0,0" from turning into index -1).
            if hunk.source_length:
                start = hunk.source_start - 1
            else:
                start = hunk.source_start
            end = start + hunk.source_length

            # Target side of the hunk = context + added lines; strip the
            # line terminator (including a CR from CRLF patches).
            patched[start:end] = [
                l.value.rstrip("\r\n") for l in hunk
                if l.is_added or l.is_context
            ]

    return "\n".join(patched)

def add_patched_code(
    input_path: Path,
    output_path: Path,
    max_len: int = MAX_CODE_LEN,
) -> None:
    """Store each Java entry's patched source under a new ``code`` field.

    Reads JSONL from *input_path*. Entries whose ``lang`` is not "java" are
    dropped silently. Entries whose patched source is longer than *max_len*
    characters are counted as skipped. Everything else is written to
    *output_path*, which is overwritten; its parent directory is created.
    """
    ensure_dir(output_path.parent)
    kept = 0
    skipped = 0

    with input_path.open(encoding="utf-8") as src, \
         output_path.open("w", encoding="utf-8") as dst:
        for line in src:
            entry = json.loads(line)
            if entry.get("lang") == "java":
                new_code = apply_patch(entry.get("oldf", ""), entry.get("patch", ""))
                if len(new_code) <= max_len:
                    entry["code"] = new_code
                    dst.write(json.dumps(entry, ensure_ascii=False) + "\n")
                    kept += 1
                else:
                    skipped += 1

    print(f"✅ Patched {kept} entries  |  Skipped (too long): {skipped}")

### Stage C – Basic dataset stats

In [None]:
def show_longest_entry(jsonl_path: Path) -> None:
    """Print the entry in *jsonl_path* whose ``code`` field is longest.

    Prints its ``id``, review ``msg`` and the first 2000 characters of its
    code. Ties keep the first entry seen, entries without ``code`` count as
    length 0, and blank lines are skipped. Nothing is printed when no entry
    has non-empty code.
    """
    longest, max_len = None, 0
    # `with` closes the file; iterating a bare open() leaked the handle.
    with jsonl_path.open(encoding="utf-8") as inp:
        for raw in inp:
            if not raw.strip():
                continue
            data = json.loads(raw)
            code_len = len(data.get("code", ""))
            if code_len > max_len:
                max_len, longest = code_len, data
    if longest:
        print(f"▶ Longest ID: {longest['id']}  |  {max_len} chars")
        print(f"Message: {longest.get('msg')}\n")
        print(longest["code"][:2000], "..." if max_len > 2000 else "")

### Stage D – Run PMD in parallel

In [None]:
def run_pmd(java_path: Path, report_path: Path, minimum_priority: int = 3) -> int:
    """Run PMD on one Java file and write an XML report to *report_path*.

    Uses the rulesets in ``PMD_RULESETS`` and reports only rules whose
    priority is at least *minimum_priority* (default 3, the previous
    hard-coded value). Returns 0 when PMD finished normally, whether or not
    it found violations; otherwise returns PMD's exit code.
    """
    cmd = [
        PMD_EXE, "check",
        "-d", str(java_path),
        "-R", PMD_RULESETS,
        "-f", "xml",
        "-r", str(report_path),
        "--minimum-priority", str(minimum_priority),
    ]
    proc = subprocess.run(cmd,
                          stdout=subprocess.DEVNULL,
                          stderr=subprocess.PIPE,
                          shell=False)
    # PMD exits with 4 when the analysis ran but found at least one
    # violation. That is the expected outcome here, so it counts as success.
    # Any other non-zero code is returned unchanged: 1 = PMD crashed,
    # 2 = usage error, 5 = recoverable errors such as a parse failure
    # (PMD 7).
    return 0 if proc.returncode in (0, 4) else proc.returncode

def pmd_batch(
    mode: str,
    jsonl_path: Path,
    work_dir: Path,
    max_workers: Optional[int] = None,
) -> None:
    """Run PMD concurrently on every entry of *jsonl_path*.

    For each entry the ``code`` is written to
    ``<work_dir>/<Mode>/ExtractedCode/<id>.java`` and the XML report to
    ``<work_dir>/<Mode>/Reports/<id>.xml``. Sources for which ``run_pmd``
    reports failure are also copied to ``Reports/Failed/``. A running
    counter is printed as tasks finish, followed by a final stats dict.

    *max_workers* defaults to four threads per CPU. Each task spawns its own
    PMD process, so lower it if memory is tight.
    """
    mode = mode.capitalize()
    code_dir    = ensure_dir(work_dir / mode / "ExtractedCode")
    reports_dir = ensure_dir(work_dir / mode / "Reports")
    failed_dir  = ensure_dir(reports_dir / "Failed")

    if max_workers is None:
        # os.cpu_count() may return None when the count is undeterminable.
        max_workers = (os.cpu_count() or 1) * 4

    # Only the main thread writes "total"; worker threads update the other
    # counters under `lock`.
    stats = {"total": 0, "processed": 0, "errors": 0}
    lock  = threading.Lock()
    futures = []

    def _task(entry: dict):
        eid, code = entry["id"], entry["code"]
        java_f = code_dir / f"{eid}.java"
        report = reports_dir / f"{eid}.xml"

        java_f.write_text(code, encoding="utf-8")
        rc = run_pmd(java_f, report)
        with lock:
            if rc == 0:
                stats["processed"] += 1
            else:
                stats["errors"] += 1
        if rc:
            shutil.copy(java_f, failed_dir / java_f.name)
        return eid, rc

    with ThreadPoolExecutor(max_workers=max_workers) as ex, \
         jsonl_path.open(encoding="utf-8") as inp:

        for raw in inp:
            entry = json.loads(raw)
            stats["total"] += 1
            futures.append(ex.submit(_task, entry))

        for f in as_completed(futures):
            eid, rc = f.result()
            status = "✅" if rc == 0 else "❌"
            with lock:
                done = stats["processed"] + stats["errors"]
                print(f"{status} {eid} ({done}/{stats['total']})", end="\r")

    print(f"\n{mode} PMD: {stats}")

### Stage E – Attach PMD warnings to diffs

In [None]:
# Matches a unified-diff hunk header; group(1) is the new-file start line,
# group(2) the new-file line count (absent when it is 1).
HUNK_HEADER = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@")

def changed_lines(patch: str) -> Set[int]:
    """Return the new-file line numbers (1-based) of lines added by *patch*.

    The old/new line counts from each hunk header are used to know where a
    hunk body ends. As a result, an added line whose content begins with
    "++" (e.g. ``+++i;``) is still counted, while real ``---``/``+++`` file
    headers between hunks are ignored. ``\\ No newline at end of file``
    markers do not advance the line counter.
    """
    changed: Set[int] = set()
    new_line: Optional[int] = None
    old_left = new_left = 0          # hunk body lines still expected

    for l in patch.splitlines():
        m = HUNK_HEADER.match(l)
        if m:
            new_line = int(m.group(1))
            new_left = int(m.group(2)) if m.group(2) is not None else 1
            # Old range is the second token, "-start[,count]".
            _, sep, old_count = l.split()[1].partition(",")
            old_left = int(old_count) if sep else 1
            continue
        if new_line is None or (old_left <= 0 and new_left <= 0):
            # Before the first hunk or between hunks: file headers etc.
            continue
        if l.startswith("+"):
            changed.add(new_line)
            new_line += 1
            new_left -= 1
        elif l.startswith("-"):
            old_left -= 1
        elif l.startswith("\\"):
            continue
        else:
            # Context line (an empty line counts as context whose leading
            # space was stripped).
            new_line += 1
            old_left -= 1
            new_left -= 1
    return changed

def relevant_warnings(xml_path: Path, changed: Set[int]) -> List[str]:
    """Return the PMD violations in *xml_path* that overlap the *changed* lines.

    Each warning is formatted as ``"<begin>-<end> | <rule>: <message>"`` and
    warnings are returned in report order. A missing or unparsable report
    yields ``[]``.

    Violations are matched by local tag name (``{*}violation``). Reports in
    the PMD ``report/2.0.0`` namespace, another namespace, or no namespace
    are therefore all handled.
    """
    if not xml_path.is_file():
        return []
    try:
        root = ET.parse(xml_path).getroot()
    except ET.ParseError:
        # e.g. an empty/truncated report left behind by a failed PMD run
        return []

    warnings: List[str] = []
    for v in root.findall(".//{*}violation"):
        beg, end = int(v.attrib["beginline"]), int(v.attrib["endline"])
        # Keep the violation if any changed line falls inside [beg, end].
        if any(beg <= line_no <= end for line_no in changed):
            txt = (v.text or "").strip()
            warnings.append(f"{beg}-{end} | {v.attrib['rule']}: {txt}")
    return warnings

def augment_with_pmd(mode: str, work_dir: Path) -> None:
    """Join each patched entry with the PMD warnings that touch its changed lines.

    Reads ``<work_dir>/<Mode>/<mode>-java-with-code.jsonl`` and the XML reports
    in ``<work_dir>/<Mode>/Reports/``. Writes
    ``<mode>-java-with-code-with-pmd.jsonl`` next to the input; each record has
    ``id``, ``realReview`` (the entry's ``msg``), ``pmdWarnings``, ``patch``
    and ``code``.
    """
    mode = mode.capitalize()
    split_dir = work_dir / mode
    stem = mode.lower()
    dataset = split_dir / f"{stem}-java-with-code.jsonl"
    reports = split_dir / "Reports"
    out_path = split_dir / f"{stem}-java-with-code-with-pmd.jsonl"

    total = 0
    with dataset.open(encoding="utf-8") as src, out_path.open("w", encoding="utf-8") as dst:
        for total, line in enumerate(src, start=1):
            entry = json.loads(line)
            touched = changed_lines(entry.get("patch", ""))
            report_file = reports / f"{entry['id']}.xml"
            found = relevant_warnings(report_file, touched)
            record = dict(
                id=entry["id"],
                realReview=entry.get("msg", ""),
                pmdWarnings=found,
                patch=entry["patch"],
                code=entry["code"],
            )
            dst.write(json.dumps(record, ensure_ascii=False) + "\n")

    print(f"✅ {mode}: PMD-augmented file → {out_path.name}  ({total} entries)")

### Driver functions

In [None]:
def prepare_datasets():
    """Run stages A, B, D and E for the Train, Test and Valid splits.

    All intermediate and final files go under ``TEMP_DIR/<Split>/``. Each
    stage is run for every split before the next stage starts.
    """
    raw_inputs = {"Train": TRAIN_RAW, "Test": TEST_RAW, "Valid": VALID_RAW}

    # Stage A - keep Java entries. Every split directory is created first,
    # because filter_java_entries opens its output file for writing.
    java_files = {}
    for split, raw_path in raw_inputs.items():
        split_dir = ensure_dir(TEMP_DIR / split)
        java_files[split] = split_dir / f"{split.lower()}-java.jsonl"
        filter_java_entries(raw_path, java_files[split])

    # Stage B - apply patches; the "-java-with-code" names are the ones
    # augment_with_pmd() reads in Stage E.
    patched_files = {}
    for split in raw_inputs:
        patched_files[split] = TEMP_DIR / split / f"{split.lower()}-java-with-code.jsonl"
        add_patched_code(java_files[split], patched_files[split])

    # Optional quick stats
    show_longest_entry(patched_files["Train"])

    # Stage D - run PMD on every patched entry
    for split, patched in patched_files.items():
        pmd_batch(split, patched, TEMP_DIR)

    # Stage E - attach PMD warnings that touch changed lines
    for split in raw_inputs:
        augment_with_pmd(split, TEMP_DIR)

In [None]:
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse the helpers) does
    # not launch the whole pipeline. In a notebook __name__ is "__main__",
    # so running the cell still executes it.
    prepare_datasets()