In [1]:
# Cell 0: Environment setup for Python 3.11, no ipywidgets, and progress bars disabled
import sys
import os
import subprocess


In [None]:

# Install torch/torchvision from the CUDA 12.1 wheel index into the kernel's environment.
# NOTE(review): versions are not pinned, so a rerun may pull different wheels; a kernel
# restart may be required before a freshly installed torch can be imported.
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
import torch
# Report the installed build and what CUDA hardware (if any) it can see.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Number of CUDA devices:", torch.cuda.device_count())
if torch.cuda.is_available():
    # Only query the device name when a GPU is actually visible.
    print("Current CUDA device:", torch.cuda.get_device_name(torch.cuda.current_device()))

In [None]:

# Cell 1: Install the Python packages used by this notebook.
# The JDK install at the bottom is commented out, so nothing in this cell installs Java;
# `java`/`javac` must already be on PATH for the validation cells that shell out to them.
%pip install --upgrade pip
%pip install pandas numpy transformers accelerate datasets sentencepiece bitsandbytes gdown
# %apt-get update && apt-get install -y openjdk-17-jdk


Note: you may need to restart the kernel to use updated packages.
Collecting accelerate
  Using cached accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Using cached datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting sentencepiece
  Using cached sentencepiece-0.2.1-cp311-cp311-win_amd64.whl.metadata (10 kB)
Collecting bitsandbytes
  Using cached bitsandbytes-0.49.0-py3-none-win_amd64.whl.metadata (10 kB)
Collecting torch>=2.0.0 (from accelerate)
  Using cached torch-2.9.1-cp311-cp311-win_amd64.whl.metadata (30 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Using cached pyarrow-22.0.0-cp311-cp311-win_amd64.whl.metadata (3.3 kB)
Collecting dill<0.4.1,>=0.3.0 (from datasets)
  Using cached dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting httpx<1.0.0 (from datasets)
  Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.6.0-cp311-cp311-win_amd64.whl.metadata (13 kB)
Collecting multip



^C
Note: you may need to restart the kernel to use updated packages.


In [None]:

# Cell: Download Project_CodeNet_Java.parquet if missing
import os
from pathlib import Path
import sys
import subprocess

# Location of the dataset file, relative to the notebook's working directory.
PARQUET_PATH = Path("Project_CodeNet_Java.parquet")
if not PARQUET_PATH.exists():
    print("Project_CodeNet_Java.parquet not found. Downloading from Google Drive...")
    # gdown is imported lazily and installed on demand if the import fails.
    try:
        import gdown
    except ImportError:
        print("gdown not found, installing via pip...")
        # Install with the interpreter that runs this kernel.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
        import gdown
    gdown.download(
        url="https://drive.google.com/uc?id=1rOGafVqqzl2JM7bvoS-2kkBK6PCbNFsZ",
        output=str(PARQUET_PATH),
        quiet=False
    )
    # Success is judged only by the file existing afterwards; its contents are not validated here.
    if PARQUET_PATH.exists():
        print("✓ Download complete.")
    else:
        raise RuntimeError("Failed to download Project_CodeNet_Java.parquet.")
else:
    print("✓ Project_CodeNet_Java.parquet already exists.")


In [None]:
# Cell 2: GPU and environment verification
import torch
import subprocess

def check_env():
    """Print a short report of the PyTorch, CUDA and Java setup visible to this kernel.

    Nothing is returned; every finding is written to stdout. A missing or
    failing `java` executable is reported rather than raised.
    """
    print("\n--- Environment Verification ---")

    # PyTorch build
    try:
        print(f"PyTorch version: {torch.__version__}")
    except Exception as exc:
        print(f"PyTorch not installed: {exc}")

    # GPU visibility
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")
    if not has_cuda:
        print("No CUDA GPU detected.")
    else:
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Java runtime: `java -version` is printed from stderr when non-empty, else stdout.
    try:
        proc = subprocess.run(["java", "-version"], capture_output=True, text=True)
        if proc.returncode != 0:
            print("Java not found or failed to run.")
        else:
            print("Java version:")
            print(proc.stderr.strip() or proc.stdout.strip())
    except Exception as exc:
        print(f"Java not available: {exc}")

    print("--- End of Environment Verification ---\n")

check_env()


## 1. Setup and Configuration
Import all required libraries and set up the configuration for the hybrid generation process.

In [None]:
"""
Import all required libraries
"""
%pip install colorama jsonlines hf-transfer

import uuid
import time
import tempfile
import requests
from pathlib import Path
import jsonlines
import pandas as pd
from colorama import Fore, Style, init
from tqdm import tqdm
import re
import os
# Turn off Hugging Face Hub progress bars (this notebook runs without ipywidgets).
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Disable tokenizers' internal parallelism.
# NOTE(review): presumably set to silence the fork/parallelism warning - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Initialize colorama for Windows compatibility; autoreset=True means each print
# starts from the default colour, so callers do not have to append a reset code.
init(autoreset=True)

print(f"{Fore.GREEN}✓ All libraries imported successfully")

In [None]:
"""
Configuration - Edit these paths and parameters
"""

# Path to the CodeNet Java parquet file (despite the name, this is a file, not a directory)
CODENET_ROOT = Path(r"Project_CodeNet_Java.parquet")


# Output paths (clone datasets)
EASY_OUTPUT_PATH = Path("dataset/java_clones_easy_types.jsonl")
HARD_OUTPUT_PATH = Path("dataset/java_clones_hard_types.jsonl")
COMBINED_OUTPUT_PATH = Path("dataset/java_clones_10k.jsonl")
# Non-clone output paths
EASY_NONCLONES_OUTPUT_PATH = Path("dataset/java_nonclones_easy_types.jsonl")
HARD_NONCLONES_OUTPUT_PATH = Path("dataset/java_nonclones_hard_types.jsonl")
FINAL_DATASET_PATH = Path("dataset/java_complete_dataset.jsonl")
GENERATED_DIR = Path("generated")
SEEDS_DIR = Path("seeds")

# Execution settings
TIMEOUT_SECONDS = 30  # Passed as the per-testcase run timeout in validate_java
MAX_PROBLEMS = None  # None for all problems, or int to limit

# Clone generation settings
TARGET_CLONES_PER_TYPE = 1  # Target per type - reasonable for testing, can increase later
# Non-clone generation settings
TARGET_NONCLONES_EASY = 1 # Target easy non-clones (simple algorithmic differences)
TARGET_NONCLONES_HARD = 1  # Target hard non-clones (different problem domains)
MAX_CLONES_PER_PROBLEM = 5
MAX_CYCLES = 50

# Model settings - HYBRID APPROACH
# NOTE(review): these look like Ollama-style model tags; the generation code in this
# notebook loads HF_MODEL_NAME through transformers instead, so these values appear to
# be used only for display and by get_model_for_clone_type - confirm before relying on them.
EASY_MODEL = "qwen2.5-coder:7b"       # Model tag for Type-1, Type-2
HARD_MODEL = "qwen2.5-coder:7b"       # Model tag for Type-3, Type-4 (currently identical to EASY_MODEL)
# EASY_MODEL = "codegemma:7b"       # Fast model for Type-1, Type-2
# HARD_MODEL = "codegemma:7b"  # Capable model for Type-3, Type-4

# Logging
VERBOSE = True

# Create directories (all dataset outputs share EASY_OUTPUT_PATH's parent, "dataset/")
for directory in [GENERATED_DIR, SEEDS_DIR, EASY_OUTPUT_PATH.parent]:
    directory.mkdir(exist_ok=True, parents=True)

print(f"{Fore.GREEN}✓ Configuration loaded")
print(f"{Fore.CYAN}Easy Model (Type-1/2): {EASY_MODEL}")
print(f"{Fore.CYAN}Hard Model (Type-3/4): {HARD_MODEL}")
print(f"{Fore.CYAN}Target per type: {TARGET_CLONES_PER_TYPE}")
print(f"{Fore.CYAN}CodeNet Root: {CODENET_ROOT}")

In [None]:
def normalize_unicode_to_ascii(text):
    """Return `text` with typographic Unicode punctuation mapped to ASCII.

    Curly quotes, dashes, special spaces, the ellipsis and a few apostrophe
    look-alikes are replaced by their ASCII spelling. Every other character
    outside the ASCII range becomes a single space.
    """
    punctuation_map = {
        '\u201c': '"', '\u201d': '"', '\u2018': "'", '\u2019': "'",
        '\u201b': "'", '\u2013': '-', '\u2014': '-', '\u2015': '-',
        '\u00a0': ' ', '\u2009': ' ', '\u200a': ' ', '\u2026': '...',
        '\u00b4': "'", '\u02bb': "'", '\u02bc': "'"
    }

    def to_ascii(ch):
        # Known punctuation first, then plain ASCII passes through, the rest is blanked.
        if ch in punctuation_map:
            return punctuation_map[ch]
        return ch if ord(ch) < 128 else ' '

    return ''.join(to_ascii(ch) for ch in text)

def sanitize_code_from_model(raw_text):
    """Extract compilable-looking Java source from raw model output.

    Steps, in order:
      1. Strip surrounding whitespace and normalise Unicode punctuation to ASCII.
      2. If the text contains Markdown fences, keep the contents of the first
         fenced block tagged ``java``; otherwise the largest fenced block;
         otherwise (no usable fenced block) the largest segment overall.
      3. Repair ``System.out<token>println`` leaks, then strip known
         special-token artifacts.
      4. Require ``class Main``; a bare ``public static void main`` is wrapped
         in ``public class Main { ... }``.

    Args:
        raw_text: Text returned by the model, or None.

    Returns:
        The cleaned source as a string, or None when `raw_text` is None or the
        text does not look like a Java program with a main method.
    """
    if raw_text is None:
        return None

    text = normalize_unicode_to_ascii(raw_text.strip())

    # Handle fenced code blocks
    if "```" in text:
        parts = text.split("```")
        # Splitting on the fence marker alternates prose (even indexes) and fenced
        # content (odd indexes). Only fenced content may be matched by its language
        # tag, otherwise a reply that merely starts with "Java ..." is mistaken for
        # a ```java block and the real code is discarded.
        fenced = parts[1::2]
        java_block = next((p for p in fenced if p.lower().startswith("java")), None)
        if java_block is not None:
            text = java_block[4:].lstrip()
        else:
            # Prefer the largest fenced block; if none is big enough, fall back to
            # the largest segment of any kind (code written outside the fences).
            candidates = [p for p in fenced if len(p.strip()) > 20]
            if not candidates:
                candidates = [p for p in parts if len(p.strip()) > 20]
            if candidates:
                text = max(candidates, key=len)

    # Repair special tokens that leaked into the middle of a System.out call,
    # e.g. "System.out< begin of sentence >println". This must run BEFORE the
    # generic artifact removal below: once the token is stripped, the call has
    # already collapsed to "System.outprintln" and can no longer be recognised.
    text = re.sub(r'System\.out\s*<[^>]+>\s*\.?', 'System.out.', text)

    # Remove common LLM artifacts
    llm_artifacts = [
        r'< begin of sentence >', r'<begin of sentence>', r'< end of sentence >',
        r'<end of sentence>', r'<\|begin_of_text\|>', r'<\|end_of_text\|>',
        r'<s>', r'</s>', r'<\|startoftext\|>', r'<\|endoftext\|>',
        r'<\|file_separator\|>', r'<\|code_start\|>', r'<\|code_end\|>'
    ]
    for artifact in llm_artifacts:
        text = re.sub(artifact, '', text, flags=re.IGNORECASE)

    # Basic validation
    if "class Main" not in text:
        # Try to wrap it if it looks like code but missing class
        if "public static void main" in text:
            text = "public class Main {\n" + text + "\n}"
        else:
            return None

    return text.strip()

print(f"{Fore.GREEN}✓ Helper functions defined (Updated with improved sanitization)")

In [None]:
"""
Prompt templates for clone generation
"""
# Type-1 clone prompt: formatting-only changes (whitespace, indentation, line breaks, comments).
# The seed program is substituted for the <<<CODE_PLACEHOLDER>>> marker via str.replace.
TYPE1_PROMPT_TEMPLATE = """You are a Java code formatter. Transform this Java code by ONLY changing formatting while preserving all semantics.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

RULES:
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. ONLY change formatting: whitespace, indentation, line breaks, comments
4. MUST preserve all identifiers, literals, and code structure
5. DO NOT rename variables, methods, or classes
6. DO NOT change any literals or expressions
7. DO NOT add, remove, or modify any statements
8. DO NOT change control flow structure
9. Output raw Java code ONLY (no markdown, no explanation)

Original Code:
<<<CODE_PLACEHOLDER>>>

Formatted Code:"""

# Type-2 clone prompt: identifiers and literals may change, structure and statement order may not.
# The seed program is substituted for the <<<CODE_PLACEHOLDER>>> marker via str.replace.
TYPE2_PROMPT_TEMPLATE = """You are a Java refactoring assistant. Transform this code by renaming identifiers and changing literals while preserving exact behavior.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

RULES:
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. CAN rename variables, parameters, and method names (EXCEPT main method)
4. CAN change literals (e.g., 10→0xA, true→(1==1), "test"→"TEST".toLowerCase())
5. MUST preserve exact control flow and structure
6. DO NOT add, remove, or reorder any statements
7. DO NOT change the algorithmic logic or approach
8. DO NOT modify control flow patterns (if/else, loops, etc.)
9. Structure and statement order MUST remain identical
10. Output raw Java code ONLY (no markdown, no explanation)

Original Code:
<<<CODE_PLACEHOLDER>>>

Refactored Code:"""

# Type-3 clone prompt: statement-level edits (added/reordered/rewritten statements) with
# identical behaviour. The seed program replaces the <<<CODE_PLACEHOLDER>>> marker.
# The dead-code guidance deliberately avoids statements placed after return/break:
# javac rejects unreachable statements, so such output could never pass validation.
TYPE3_PROMPT_TEMPLATE = """You are a Java code mutator. Transform this code with SIGNIFICANT statement-level modifications while preserving exact program behavior.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

**TYPE-3 CLONE REQUIREMENTS - YOU MUST DO AT LEAST 3 OF THESE:**
1. Replace for loops with while loops (or vice versa)
2. Add temporary variables to break up complex expressions: `result = a + b + c` → `temp = a + b; result = temp + c`
3. Add dead code that still compiles: unused variables, or branches guarded by an always-false condition (NEVER put statements after return/break - Java rejects unreachable statements)
4. Reorder independent statements (declarations, assignments that don't depend on each other)
5. Replace if-else with ternary operators (or vice versa): `if(x>0) y=1; else y=0;` → `y = (x>0) ? 1 : 0;`
6. Add redundant calculations: `x = 5` → `x = 3 + 2` or `x = 10/2`
7. Extract inline calculations into separate statements
8. Add extra variable assignments that don't change behavior
9. Change loop increment styles: `i++` → `i = i + 1` → `i += 1`
10. Add null checks or bounds checks that are always true/false

**CONCRETE EXAMPLES OF REQUIRED CHANGES:**

Example 1 - Loop Conversion:
BEFORE: `for(int i=0; i<n; i++) { sum += arr[i]; }`
AFTER: `int i = 0; while(i < n) { sum += arr[i]; i = i + 1; }`

Example 2 - Expression Breakdown:
BEFORE: `int result = (a + b) * (c - d);`
AFTER: `int temp1 = a + b; int temp2 = c - d; int result = temp1 * temp2;`

Example 3 - Dead Code Addition:
BEFORE: `return result;`
AFTER: `int unused = 42; if (unused < 0) { System.out.println("never reached"); } return result;`

**FORBIDDEN (these make Type-1 clones, NOT Type-3):**
- Only changing whitespace/formatting
- Only renaming variables
- Only changing comments
- Only changing literal values without structural impact

**REQUIRED STRUCTURE:**
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. MUST preserve exact input/output behavior
4. MUST have noticeable structural differences from original
5. Output raw Java code ONLY (no markdown, no explanation)

Original Code:
<<<CODE_PLACEHOLDER>>>

Structurally Modified Code:"""

# Type-4 clone prompt: a different algorithm with identical observable input/output behaviour.
# The seed program is substituted for the <<<CODE_PLACEHOLDER>>> marker via str.replace.
TYPE4_PROMPT_TEMPLATE = """You are an expert Java programmer. Rewrite this code using a completely different algorithm while maintaining identical observable behavior.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

RULES:
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. MUST preserve exact input format and parsing
4. MUST preserve exact output format and content
5. MUST have identical behavior for ALL possible inputs
6. CAN use completely different algorithms, data structures, approaches
7. CAN restructure the entire program logic
8. CAN use different computational strategies
9. Structure and implementation MAY be completely different
10. Observable input/output behavior MUST be identical
11. Output raw Java code ONLY (no markdown, no explanation)

Original Code:
<<<CODE_PLACEHOLDER>>>

Rewritten Code:"""

# Easy non-clone prompt: asks for an unrelated program (different problem, names, structure).
# The reference program is inserted at the <<<CODE_PLACEHOLDER>>> marker.
EASY_NONCLONE_PROMPT_TEMPLATE = """You are a Java programmer. Create a simple, different Java program that solves a basic algorithmic problem.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

RULES:
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. Create a program for a COMPLETELY DIFFERENT problem domain
4. DO NOT reuse any variable names from the reference code
5. DO NOT use similar control-flow patterns
6. DO NOT use similar data structures
7. Must solve a clearly different algorithmic problem
8. Use basic concepts: simple loops, arrays, basic arithmetic
9. Must be functionally complete and compilable
10. Different problem goal and output meaning required
11. Output raw Java code ONLY (no markdown, no explanation)

Reference Code (CREATE SOMETHING COMPLETELY DIFFERENT):
<<<CODE_PLACEHOLDER>>>

New Different Program:"""

# Hard non-clone prompt: asks for a structurally similar program with different semantics.
# The reference program is inserted at the <<<CODE_PLACEHOLDER>>> marker.
HARD_NONCLONE_PROMPT_TEMPLATE = """You are an expert Java programmer. Create a sophisticated Java program that has similar structure but different semantics from the reference code.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

RULES:
1. MUST have class name as "Main" (CRITICAL for compilation)
2. MUST have public static void main(String[] args) method
3. MUST have similar control flow patterns (similar if/else, loop structures)
4. MUST have similar program skeleton and structure
5. MUST solve a DIFFERENT semantic problem with DIFFERENT output meaning
6. MUST NOT have behavioral equivalence with the reference code
7. Use advanced concepts: collections, recursion, object-oriented design
8. High structural similarity but different algorithmic goal required
9. Must be functionally complete and compilable
10. Different problem domain but similar complexity
11. Output raw Java code ONLY (no markdown, no explanation)

Reference Code (CREATE SIMILAR STRUCTURE, DIFFERENT SEMANTICS):
<<<CODE_PLACEHOLDER>>>

New Structurally Similar Program:"""


print(f"{Fore.GREEN}✓ Prompt templates defined (including non-clone templates)")
print(f"{Fore.CYAN}  Clone templates: Type-1, Type-2, Type-3, Type-4")
print(f"{Fore.CYAN}  Non-clone templates: Easy, Hard")

In [None]:
"""
CodeNet data loading functions
"""

# Load the Project_CodeNet_Java.parquet as a DataFrame.
# PARQUET_PATH is re-declared here so this cell does not depend on the download cell
# having run in the same kernel session; it names the same file as CODENET_ROOT.
PARQUET_PATH = Path("Project_CodeNet_Java.parquet")
print(f"{Fore.CYAN}Loading Project_CodeNet_Java.parquet ...")
# codenet_df is the module-level frame read by list_problems, choose_seed and load_testcases.
codenet_df = pd.read_parquet(PARQUET_PATH)
print(f"{Fore.GREEN}✓ Loaded {len(codenet_df)} records from Project_CodeNet_Java.parquet")
print("Columns in Project_CodeNet_Java.parquet:", codenet_df.columns.tolist())  # DEBUG: show the schema; the functions below index columns by name

# Refactored data access functions

def list_problems():
    """Return the distinct `problem_id` values of `codenet_df` in ascending order."""
    distinct_ids = codenet_df['problem_id'].unique()
    return sorted(distinct_ids)

def choose_seed(problem_id):
    """Pick the shortest usable Java submission for `problem_id` as the seed program.

    Rows are filtered to Java / Accepted when the `language` / `status` columns
    exist, then to rows whose `source_code` is a non-empty string. The shortest
    remaining submission is chosen.

    Args:
        problem_id: Value matched against the `problem_id` column of `codenet_df`.

    Returns:
        (code, submission_id) for the chosen row, where `submission_id` is None
        if that column is absent; (None, None) when no usable submission exists
        or the shortest one is longer than 10240 characters.
    """
    max_seed_chars = 10240

    df = codenet_df[codenet_df['problem_id'] == problem_id]
    # The code column is required; check it before indexing instead of after,
    # so a schema without it yields "no seed" rather than a KeyError.
    if df.empty or 'source_code' not in df.columns:
        return None, None

    # Filter for Java and Accepted submissions if columns exist
    if 'language' in df.columns:
        df = df[df['language'].str.lower() == 'java']
    if 'status' in df.columns:
        df = df[df['status'].str.lower() == 'accepted']

    # Keep only rows with real source text. A missing value (NaN) would otherwise
    # reach len() and raise, and an empty string would sort first and hide every
    # valid submission behind it.
    has_code = df['source_code'].apply(lambda src: isinstance(src, str) and len(src) > 0)
    df = df[has_code]
    if df.empty:
        return None, None

    # Shortest submission first
    df = df.assign(_code_len=df['source_code'].str.len()).sort_values('_code_len')
    row = df.iloc[0]
    code = row['source_code']
    submission_id = row['submission_id'] if 'submission_id' in row else None

    if len(code) <= max_seed_chars:
        return code, submission_id
    return None, None

def load_testcases(problem_id):
    """Load input/output testcases for a problem from the parquet DataFrame.

    The first Java / Accepted row (when those columns exist) with non-null
    `inputs` and `outputs` is used.

    Args:
        problem_id: Value matched against the `problem_id` column of `codenet_df`.

    Returns:
        A list of (input, expected_output) pairs; an empty list when the frame
        has no `inputs`/`outputs` columns or no row qualifies.
    """
    df = codenet_df[codenet_df['problem_id'] == problem_id]
    # Treat a schema without testcase columns as "no tests", the same way the
    # optional language/status columns are tolerated below, instead of raising KeyError.
    if 'inputs' not in df.columns or 'outputs' not in df.columns:
        return []
    if 'language' in df.columns:
        df = df[df['language'].str.lower() == 'java']
    if 'status' in df.columns:
        df = df[df['status'].str.lower() == 'accepted']
    # Only keep rows with non-null inputs and outputs
    df = df[df['inputs'].notnull() & df['outputs'].notnull()]
    if df.empty:
        return []
    # Use the first available testcase
    row = df.iloc[0]
    input_text = row['inputs']
    output_text = row['outputs']
    # Multi-line strings with matching line counts are split into one testcase per line.
    # NOTE(review): this assumes each line is an independent testcase; a single test whose
    # input and output both span the same number of lines would be split too - confirm
    # against the dataset's actual layout.
    if isinstance(input_text, str) and '\n' in input_text and isinstance(output_text, str) and '\n' in output_text:
        input_lines = input_text.strip().split('\n')
        output_lines = output_text.strip().split('\n')
        # Pair up lines (if counts match)
        if len(input_lines) == len(output_lines):
            return list(zip(input_lines, output_lines))
    # Otherwise, just return as a single testcase
    return [(input_text, output_text)]

print(f"{Fore.GREEN}✓ CodeNet loading functions defined")

In [None]:
def log(message, color=Fore.WHITE):
    """Print `message` wrapped in `color` and a style reset; fall back to a plain print on any error."""
    try:
        colored_line = f"{color}{message}{Style.RESET_ALL}"
        print(colored_line)
    except Exception:
        print(message)

# Cell: Hugging Face Transformers model loading and prompt formatting
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Use Qwen2.5-Coder-7B from Hugging Face
# NOTE(review): this repo id has no "-Instruct" suffix, so it is presumably the base
# checkpoint, while the prompt templates are written as instructions - confirm that
# this is the intended variant.
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-7B"

# bitsandbytes settings: 4-bit NF4 weights with double quantization, fp16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Module-level tokenizer/model pair used by ask_model_transformers.
# trust_remote_code=True permits code shipped in the model repository to run locally.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_NAME,
    device_map="auto",  # let accelerate decide device placement
    dtype=torch.float16,
    quantization_config=bnb_config,
    trust_remote_code=True
)

def get_system_prompt():
    """Return the fixed system preamble that `format_prompt` puts in front of every instruction."""
    sentences = [
        "You are a highly accurate, concise, and reliable Java code assistant.",
        "You never hallucinate, always follow instructions, and only output valid Java code.",
        "If you are unsure, say so.",
    ]
    # Sentences are separated by exactly one space.
    return " ".join(sentences)

def format_prompt(user_instruction, code=None):
    """Assemble a prompt: system preamble, the stripped instruction and, optionally, the input code.

    Sections are separated by a blank line; when `code` is given it is stripped
    and appended under an "Input Code:" header.
    """
    sections = [get_system_prompt(), user_instruction.strip()]
    if code is not None:
        sections.append(f"Input Code:\n{code.strip()}")
    return "\n\n".join(sections)

def ask_model_transformers(prompt, max_tokens=1500, temperature=0.1):
    """Generate a completion for `prompt` with the module-level tokenizer/model.

    Args:
        prompt: Full prompt text.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature. Values that are falsy or <= 0 switch
            to greedy decoding, because sampling with a zero temperature is not valid.

    Returns:
        The generated continuation only (prompt excluded), stripped of
        surrounding whitespace, as a plain string.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output = model.generate(**inputs, **generation_kwargs)

    # The returned sequence starts with the prompt tokens. Slice them off by length
    # instead of comparing decoded text: decoding does not always reproduce the prompt
    # string exactly, in which case a startswith() check silently leaves the whole
    # prompt inside the "generated" code.
    new_tokens = output[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def quick_check_code_quality(code_str):
    """Run cheap textual sanity checks on generated Java source, without compiling it.

    Returns:
        (True, "OK") when every check passes, otherwise (False, reason) for the
        first failing check. Checks run in this order: minimum length, presence
        of ``class Main`` and ``main(``, placeholder/chatter markers, brace and
        parenthesis balance, and a plausible final character.
    """
    if not code_str or len(code_str) < 50:
        return False, "Code too short"

    required_snippets = (
        ("class Main", "Missing 'class Main'"),
        ("main(", "Missing main method"),
    )
    for snippet, complaint in required_snippets:
        if snippet not in code_str:
            return False, complaint

    # Markers of placeholders, explanations or leaked special tokens
    red_flags = (
        "TODO:", "FIXME:", "[Your code here]", "// ... rest of",
        "// Original code", "// Explanation:", "Note that",
        "< begin of sentence >", "<begin of sentence>",
        "< end of sentence >", "<end of sentence>",
    )
    first_hit = next((flag for flag in red_flags if flag in code_str), None)
    if first_hit is not None:
        return False, f"Contains suspicious pattern: {first_hit}"

    # Plain character counts; angle brackets are intentionally not checked
    # (generics and bitwise operators would cause false positives).
    delimiter_pairs = (
        ('{', '}', "Unbalanced braces"),
        ('(', ')', "Unbalanced parentheses"),
    )
    for opener, closer, complaint in delimiter_pairs:
        if code_str.count(opener) != code_str.count(closer):
            return False, complaint

    # A complete file should end on a closing brace, a semicolon or a comment terminator.
    tail = code_str.strip()
    if tail and not tail.endswith(('}', ';', '*', '/')):
        return False, "Code appears incomplete"

    return True, "OK"

# [REMAINING VALIDATION FUNCTIONS KEPT AS IS]
def compile_java(temp_dir):
    """Compile ``Main.java`` located in `temp_dir` with ``javac``.

    Returns:
        (True, None) on success, otherwise (False, message). The message names
        the failure: missing source file, compiler output (first 500
        characters of stderr), a 30-second timeout, a missing ``javac``
        executable, or any other exception.
    """
    source_path = Path(temp_dir) / "Main.java"
    if not source_path.exists():
        return False, "Main.java not found"

    try:
        proc = subprocess.run(
            ["javac", str(source_path)],
            cwd=temp_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=30,
            check=False
        )
    except subprocess.TimeoutExpired:
        return False, "Compilation timeout"
    except FileNotFoundError:
        return False, "javac not found. Please install JDK."
    except Exception as exc:
        return False, f"Compilation exception: {str(exc)}"

    if proc.returncode == 0:
        return True, None

    compiler_output = proc.stderr.decode('utf-8', errors='ignore')
    return False, f"Compilation error: {compiler_output[:500]}"

def run_java_with_input(temp_dir, input_str, timeout=3):
    """Run the compiled ``Main`` class in `temp_dir`, feeding `input_str` on stdin.

    Returns:
        (stdout_text, None) when the process exits with code 0, otherwise
        (None, message) describing a non-zero exit (first 500 characters of
        stderr), a timeout, or any other exception.
    """
    try:
        proc = subprocess.run(
            ["java", "Main"],
            cwd=temp_dir,
            input=input_str.encode('utf-8'),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            check=False
        )
    except subprocess.TimeoutExpired:
        return None, "Execution timeout"
    except Exception as exc:
        return None, f"Execution exception: {str(exc)}"

    if proc.returncode != 0:
        stderr_text = proc.stderr.decode('utf-8', errors='ignore')
        return None, f"Runtime error: {stderr_text[:500]}"

    return proc.stdout.decode('utf-8', errors='ignore'), None

def normalize_output(text):
    """Strip the text as a whole, then drop trailing whitespace from every line."""
    trimmed_lines = map(str.rstrip, text.strip().split('\n'))
    return '\n'.join(trimmed_lines)

def validate_java(code_str, problem_id):
    """Compile `code_str` as ``Main.java`` and run it against the problem's testcases.

    Returns one of these status strings:
        "no_tests"                  - no testcases were found for `problem_id`
        "compile_error"             - the source file could not be written
        "compile_error: <message>"  - javac failed
        "timeout"                   - a run reported an error mentioning "timeout"
        "runtime_error: <message>"  - a run failed for any other reason
        "wrong_answer"              - normalised output differs from the expected output
        "passed"                    - every testcase matched
    """
    testcases = load_testcases(problem_id)
    if not testcases:
        return "no_tests"

    with tempfile.TemporaryDirectory() as workdir:
        try:
            (Path(workdir) / "Main.java").write_text(code_str, encoding='utf-8')
        except Exception:
            return "compile_error"

        compiled, compile_message = compile_java(workdir)
        if not compiled:
            return f"compile_error: {compile_message}"

        # Stop at the first failing testcase.
        for stdin_text, expected in testcases:
            actual, run_error = run_java_with_input(workdir, stdin_text, timeout=TIMEOUT_SECONDS)
            if run_error:
                if "timeout" in run_error.lower():
                    return "timeout"
                return f"runtime_error: {run_error}"
            if normalize_output(actual) != normalize_output(expected):
                return "wrong_answer"

        return "passed"

print(f"{Fore.GREEN}✓ Java validation functions defined (Updated checks and error reporting)")


## 2. Model Selection Logic
Define the hybrid model selection strategy based on clone type complexity.

In [None]:
"""
Pre-flight checks
"""

# Generation runs in-process through Hugging Face transformers (see
# ask_model_transformers), so there is no Ollama server to probe. The former
# connectivity check was dead, commented-out code behind a banner claiming a
# check was in progress; both were removed so the output no longer suggests
# that an Ollama connection is being verified.
print(f"\n{Fore.GREEN}✓ Pre-flight checks complete!")


In [None]:
# Repair prompt used by generate_with_repair: <<<CODE_PLACEHOLDER>>> receives the failing
# candidate and <<<ERROR_PLACEHOLDER>>> the validation/syntax error text, both via str.replace.
REPAIR_PROMPT_TEMPLATE = """You are a Java code repair assistant. The following Java code has errors. Fix the validation errors and output the corrected code.

**CRITICAL:** Your output must be raw Java code ONLY. Do not include any markdown, explanations, or special tokens.

Rules:
1. MUST have class name as "Main"
2. Fix SPECIFICALLY the error reported below
3. Preserve the original logic as much as possible
4. Output raw Java code ONLY

Original Code:
<<<CODE_PLACEHOLDER>>>

Validation Error:
<<<ERROR_PLACEHOLDER>>>

Fixed Code:"""

def get_model_for_clone_type(clone_type):
    """Map a clone type to the configured model name.

    'type1' and 'type2' use EASY_MODEL; 'type3', 'type4' and any unrecognised
    value use HARD_MODEL.
    """
    easy_types = ('type1', 'type2')
    return EASY_MODEL if clone_type in easy_types else HARD_MODEL

def _query_model(prompt, temperature, max_tokens=1500):
    """Call ask_model_transformers and normalise the outcome to a (text, error) pair.

    ask_model_transformers returns a plain string; a (text, error) tuple is also
    accepted so this keeps working if that helper starts reporting errors itself.
    Exceptions raised during generation are returned as the error text.
    """
    try:
        result = ask_model_transformers(prompt, max_tokens=max_tokens, temperature=temperature)
    except Exception as exc:
        return None, f"{type(exc).__name__}: {exc}"
    if isinstance(result, tuple) and len(result) == 2:
        return result
    return result, None


def _attempt_repairs(code, error_msg, problem_id, attempts=2):
    """Ask the model to fix `code` up to `attempts` times; return accepted code or None."""
    current_code = code
    for _ in range(attempts):
        repair_prompt = (
            REPAIR_PROMPT_TEMPLATE
            .replace("<<<CODE_PLACEHOLDER>>>", current_code)
            .replace("<<<ERROR_PLACEHOLDER>>>", str(error_msg))
        )
        raw_repair, err = _query_model(repair_prompt, 0.1)
        if err:
            # A failing model call will not improve by retrying the same repair.
            log(f"[MODEL ERROR] {err}", Fore.RED)
            break
        repaired_code = sanitize_code_from_model(raw_repair)
        if not repaired_code:
            continue
        valid_syntax, reason = quick_check_code_quality(repaired_code)
        if not valid_syntax:
            # Feed the newest candidate and its problem into the next repair round.
            error_msg = f"Syntax error after repair: {reason}"
            current_code = repaired_code
            continue
        if not problem_id:
            return repaired_code
        result = validate_java(repaired_code, problem_id)
        if result == "passed":
            print(f"  {Fore.GREEN}Repair successful!")
            return repaired_code
        if result == "no_tests":
            # Nothing to repair against; no candidate can ever pass.
            break
        error_msg = result
        current_code = repaired_code
    return None


def generate_with_repair(prompt, model_name, problem_id=None, max_retries=3):
    """Generate Java code for `prompt`, repairing it when validation fails.

    Up to two initial generations are tried (temperatures 0.1, then 0.3). Each
    candidate is sanitised, syntax-checked and, when `problem_id` is given,
    compiled and run against that problem's testcases. A failing candidate
    gets up to two repair rounds in which the error text is sent back to the model.

    Args:
        prompt: Fully substituted generation prompt.
        model_name: Kept for interface compatibility; generation always uses the
            model behind ask_model_transformers.
        problem_id: Problem whose testcases validate the code; when falsy, the
            first syntactically plausible candidate is returned unvalidated.
        max_retries: Kept for interface compatibility; currently unused.

    Returns:
        The accepted source code, or None if no candidate was accepted.
    """
    temperatures = [0.1, 0.3, 0.5]
    for attempt in range(2):  # 2 attempts at initial generation
        temp = temperatures[min(attempt, len(temperatures) - 1)]
        # _query_model always yields a pair; unpacking the bare string returned by
        # ask_model_transformers directly would raise ValueError.
        raw, error = _query_model(prompt, temp)
        if error:
            log(f"[MODEL ERROR] {error}", Fore.RED)
            continue
        code = sanitize_code_from_model(raw)
        if not code:
            continue

        valid_syntax, reason = quick_check_code_quality(code)
        if not valid_syntax:
            print(f"  {Fore.YELLOW}Syntax check failed: {reason}")
            error_msg = f"Syntax error: {reason}"
        elif not problem_id:
            return code  # No problem ID to validate against, so return code
        else:
            result = validate_java(code, problem_id)
            if result == "passed":
                return code
            if result == "no_tests":
                # The outcome depends only on the problem, so no candidate or
                # repair can pass; stop instead of spending model calls on it.
                log(f"[VALIDATION] No testcases for problem {problem_id}; skipping.", Fore.YELLOW)
                return None
            error_msg = result

        print(f"  {Fore.YELLOW}Attempting repair for error: {error_msg}")
        repaired = _attempt_repairs(code, error_msg, problem_id)
        if repaired is not None:
            return repaired
    return None

def generate_clone_v2(code, clone_type, problem_id):
    if clone_type == 'type1':
        prompt = TYPE1_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    elif clone_type == 'type2':
        prompt = TYPE2_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    elif clone_type == 'type3':
        prompt = TYPE3_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    elif clone_type == 'type4':
        prompt = TYPE4_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    else:
        return None
        
    model = get_model_for_clone_type(clone_type)

    # For Type-3 clones, we need to validate that it's actually Type-3 and not Type-1
    if clone_type == 'type3':
        # Try multiple times to get a proper Type-3 clone
        for attempt in range(5):  # 5 attempts for Type-3
            generated_code = generate_with_repair(prompt, model, problem_id)

            if not generated_code:
                continue

            # Validate that this is actually a Type-3 clone
            is_type3, reason = validate_type3_clone(code, generated_code)

            if is_type3:
                print(f"  {Fore.GREEN}✓ Type-3 validation passed: {reason}")
                return generated_code
            else:
                print(f"  {Fore.YELLOW}⚠ Type-3 validation failed (attempt {attempt + 1}): {reason}")
                # Try again with modified prompt to be more explicit
                if attempt < 4:  # Don't modify on last attempt
                    enhanced_prompt = prompt + f"\n\nIMPORTANT: The previous attempt was rejected because: {reason}\nMake MORE SIGNIFICANT structural changes. Remember: You MUST do at least 3 major modifications from the requirements list."
                    prompt = enhanced_prompt

        # If all attempts failed, return None
        print(f"  {Fore.RED}✗ Failed to generate proper Type-3 clone after 5 attempts")
        return None
    else:
        # For other types, use normal generation
        return generate_with_repair(prompt, model, problem_id)

def get_model_for_nonclone_type(nonclone_type):
    """Return the model identifier used for a non-clone difficulty.

    'easy' maps to ``EASY_MODEL``; 'hard' and every other value map to
    ``HARD_MODEL``.
    """
    return EASY_MODEL if nonclone_type == 'easy' else HARD_MODEL

def generate_nonclone_v2(code, nonclone_type, problem_id=None):
    if nonclone_type == 'easy':
        prompt = EASY_NONCLONE_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    elif nonclone_type == 'hard':
        prompt = HARD_NONCLONE_PROMPT_TEMPLATE.replace("<<<CODE_PLACEHOLDER>>>", code)
    else:
        return None
        
    model = get_model_for_nonclone_type(nonclone_type)
    # For non-clones, we just check compilation/sanity, we can't check 'passed' against original tests
    # because the problem is DIFFERENT. So passing problem_id might be misleading if used for test validation.
    # However, generate_with_repair uses problem_id to run tests.
    # For non-clones, we should probably pass None for problem_id to skip testcase validation inside the repair loop,
    # OR update generate_with_repair to handle 'compile_only' mode.
    
    return generate_with_repair(prompt, model, problem_id=None) 

# Confirmation banner: printed once the definitions in this cell have been executed.
print(f"{Fore.GREEN}✓ Model selection logic & Repair loop defined")

In [None]:
def generate_clones_for_types(clone_types, output_path, target_per_type):
    """Generate validated clone pairs for ``clone_types`` and append them to ``output_path``.

    Problems returned by ``list_problems()`` (capped by ``MAX_PROBLEMS``) are
    visited in order, cycling at most ``MAX_CYCLES`` times, until every clone
    type has ``target_per_type`` records. A record is written only when the
    seed passes ``validate_java`` and the generated code passes both
    ``quick_check_code_quality`` and ``validate_java``.

    The run is resumable: ``(problem_id, clone_type)`` pairs already present in
    ``output_path`` count towards the targets and are not generated again.

    Args:
        clone_types: Clone type names, e.g. ``['type1', 'type2']``.
        output_path: ``pathlib.Path`` of the JSONL file; opened in append mode.
        target_per_type: Number of records wanted for each clone type.

    Returns:
        ``(clone_counters, stats)`` where ``clone_counters`` maps each clone
        type to its record count (resumed records included) and ``stats`` has
        the keys 'no_seed', 'seed_failed' and 'failed'. Returns ``None`` when
        ``list_problems()`` yields nothing; in that case the output file is
        not opened.
    """
    import json

    # Resume logic: collect the (problem_id, clone_type) pairs already on disk.
    existing_pairs = set()
    if output_path.exists():
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        record = json.loads(line)
                        pid = record.get('problem_id')
                        ctype = record.get('clone_type')
                        if pid and ctype:
                            existing_pairs.add((pid, ctype))
                    except Exception:
                        # Best effort: one malformed line must not block resuming.
                        continue
        except Exception as e:
            # Still best effort, but report it: without the resume set the run
            # may append duplicates of records that are already in the file.
            log(f"Could not read existing records from {output_path}: {repr(e)}", Fore.YELLOW)

    # Load problems BEFORE opening the writer so the early return below cannot
    # leave an open file handle behind.
    problems = list_problems()
    if not problems:
        print(f"{Fore.RED}✗ No problems found in CodeNet directory")
        return None

    print(f"Found {len(problems)} problems in CodeNet")

    problems_to_process = problems[:MAX_PROBLEMS] if MAX_PROBLEMS else problems

    # Start each counter at the number of records already generated for that type.
    clone_counters = {ct: 0 for ct in clone_types}
    for (pid, ctype) in existing_pairs:
        if ctype in clone_counters:
            clone_counters[ctype] += 1

    stats = {
        'no_seed': 0,
        'seed_failed': 0,
        'failed': 0
    }

    current_problem_idx = 0
    problems_processed = 0

    # Main generation loop. The writer is a context manager so it is closed on
    # normal exit, on errors, and when the cell is interrupted.
    total_target = target_per_type * len(clone_types)
    with jsonlines.open(output_path, mode='a', flush=True) as dataset_writer, \
            tqdm(total=total_target, desc=f"Generating {', '.join(clone_types)}") as pbar:
        # Resumed records count as already-completed progress.
        completed = sum(clone_counters.values())
        if completed > 0:
            pbar.update(completed)

        while any(clone_counters[ct] < target_per_type for ct in clone_types):
            # Wrap around once every problem has been visited.
            if current_problem_idx >= len(problems_to_process):
                current_problem_idx = 0
                problems_processed += 1
                print(f"\n{Fore.YELLOW}Completed cycle {problems_processed}, cycling through problems again...")

                # Safety check to prevent an infinite loop when targets are unreachable.
                if problems_processed >= MAX_CYCLES:
                    print(f"\n{Fore.RED}Reached maximum cycles limit ({MAX_CYCLES}). Stopping generation.")
                    break

            problem_id = problems_to_process[current_problem_idx]
            current_problem_idx += 1

            try:
                # Load seed code
                seed_code, sub_id = choose_seed(problem_id)

                if not seed_code:
                    stats['no_seed'] += 1
                    continue

                # A problem without test cases cannot be validated; it is
                # tallied under 'no_seed' as well.
                test_cases = load_testcases(problem_id)
                if not test_cases:
                    stats['no_seed'] += 1
                    continue

                # Validate seed
                seed_result = validate_java(seed_code, problem_id)
                if isinstance(seed_result, str) and seed_result.startswith("compile_error"):
                    seed_result = "compile_error"  # Normalize for check

                if seed_result != "passed":
                    stats['seed_failed'] += 1
                    continue

                clones_generated_this_problem = 0

                for clone_type in clone_types:
                    if clone_counters[clone_type] >= target_per_type:
                        continue

                    if clones_generated_this_problem >= MAX_CLONES_PER_PROBLEM:
                        break

                    # Skip Type-4 for complex files (75 lines or more).
                    if clone_type == 'type4' and len(seed_code.split('\n')) >= 75:
                        continue

                    # Skip pairs that already exist (from a previous run or this one).
                    if (problem_id, clone_type) in existing_pairs:
                        continue

                    try:
                        generated_code = generate_clone_v2(seed_code, clone_type, problem_id)

                        if not generated_code:
                            stats['failed'] += 1
                            continue

                        # generate_clone_v2 validates internally; re-check here
                        # before anything is written to the dataset.
                        is_valid, reason = quick_check_code_quality(generated_code)
                        if not is_valid:
                            stats['failed'] += 1
                            continue

                        result = validate_java(generated_code, problem_id)

                        if result == "passed":
                            pair_id = f"{problem_id}_{clone_type}_{uuid.uuid4().hex[:8]}"

                            record = {
                                'id': pair_id,
                                'code_1': seed_code,
                                'code_2': generated_code,
                                'label': "clone",
                                'clone_type': clone_type,
                                'language': 'Java',
                                'problem_id': problem_id,
                                'generator': get_model_for_clone_type(clone_type),
                                'timestamp': time.time()
                            }

                            dataset_writer.write(record)
                            clone_counters[clone_type] += 1
                            clones_generated_this_problem += 1
                            pbar.update(1)
                            # Prevent a duplicate when the problem list is cycled again.
                            existing_pairs.add((problem_id, clone_type))
                        else:
                            stats['failed'] += 1

                    except Exception as e:
                        stats['failed'] += 1
                        log(f"[{problem_id}] {clone_type} error: {repr(e)}", Fore.RED)

                # Progress update every 50 problems
                if current_problem_idx % 50 == 0:
                    print(f"\n{Fore.CYAN}Progress Update - Problem {current_problem_idx}/{len(problems_to_process)} (Cycle {problems_processed + 1}):")
                    for ct, count in clone_counters.items():
                        status = "✓" if count >= target_per_type else f"{count}/{target_per_type}"
                        print(f"  {ct}: {status}")
                    print()

            except Exception as e:
                stats['failed'] += 1
                log(f"[{problem_id}] Unexpected error: {e}", Fore.RED)
                continue

    return clone_counters, stats

# Confirmation banner: printed once generate_clones_for_types has been defined in the kernel.
print(f"{Fore.GREEN}✓ Core generation function defined (Updated with repair loop)")

## 3. Generate Easy Clones (Type-1 and Type-2)
Use the fast model configured as `EASY_MODEL` to generate Type-1 and Type-2 clones efficiently.

In [None]:
"""
Step 1: Generate Easy Clones (Type-1 and Type-2) using qwen2.5-coder:7b
"""

print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 1: Generating Easy Clones (Type-1, Type-2)")
print(f"{Fore.CYAN}Model: {EASY_MODEL}")
print(f"{Fore.CYAN}Output: {EASY_OUTPUT_PATH}")
print(f"{Fore.CYAN}{'='*60}\n")

start_time = time.time()

# generate_clones_for_types returns None (not a tuple) when no problems are
# found. Fail with an explicit message instead of an opaque unpacking TypeError.
easy_result = generate_clones_for_types(
    clone_types=['type1', 'type2'],
    output_path=EASY_OUTPUT_PATH,
    target_per_type=TARGET_CLONES_PER_TYPE
)
if easy_result is None:
    raise RuntimeError(
        "Easy clone generation did not run: generate_clones_for_types found no problems. "
        "Check the CodeNet data before re-running this cell."
    )
easy_counters, easy_stats = easy_result

elapsed_time = time.time() - start_time

print(f"\n{Fore.GREEN}{'='*60}")
print(f"{Fore.GREEN}✓ EASY CLONES GENERATION COMPLETE!")
print(f"{Fore.GREEN}{'='*60}")

# Per-type totals; a type is COMPLETE once it reached TARGET_CLONES_PER_TYPE.
print(f"\n{Fore.CYAN}Clone Counts:")
total_easy = 0
for clone_type, count in easy_counters.items():
    status = "✓ COMPLETE" if count >= TARGET_CLONES_PER_TYPE else "⚠ INCOMPLETE"
    print(f"  {clone_type}: {count}/{TARGET_CLONES_PER_TYPE} {status}")
    total_easy += count

print(f"\n{Fore.GREEN}Total easy clones: {total_easy}")
print(f"Time taken: {elapsed_time/60:.2f} minutes")
print(f"Dataset saved to: {EASY_OUTPUT_PATH}")

# Failure tallies collected by the generation loop.
print(f"\n{Fore.YELLOW}Statistics:")
print(f"  No seed found: {easy_stats['no_seed']}")
print(f"  Seed validation failed: {easy_stats['seed_failed']}")
print(f"  Generation/validation failed: {easy_stats['failed']}")

## 4. Generate Hard Clones (Type-3 and Type-4)
Use the more capable `qwen2.5-coder:7b` model to generate complex Type-3 and Type-4 clones.

**Note:** This step will be significantly slower due to the larger model, but it ensures correct semantic transformations.

In [None]:
"""
Step 2: Generate Hard Clones (Type-3 and Type-4) using qwen2.5-coder:7b
"""

def validate_type3_clone(original_code, generated_code, min_similarity=0.6):
    """
    Validates if generated_code is a type-3 clone of original_code.

    Type-3 clones are syntactically similar with some modifications (e.g.
    added/removed statements, reordered code), but not identical (type-1) and
    not merely renamed (type-2). The checks run in this order:

    1. Comments, blank lines and leading/trailing whitespace are removed; if
       the two texts are then equal the pair is type-1.
    2. Both texts are split into tokens; equal token streams mean only layout
       differs (still type-1).
    3. Identifiers and literals are replaced by placeholders; equal streams
       mean only names/literal values differ (type-2).
    4. Otherwise the character-level ``SequenceMatcher`` ratio of the texts
       from step 1 must exceed ``min_similarity``.

    The comment-stripping regex does not understand string literals, so a
    ``//`` inside a string is treated as the start of a comment.

    Args:
        original_code: Seed Java source.
        generated_code: Candidate Java source.
        min_similarity: Exclusive lower bound for the similarity ratio
            (default 0.6).

    Returns (is_type3: bool, reason: str)
    """
    import re
    from difflib import SequenceMatcher

    # Reserved words stay as-is when tokens are abstracted; every other word
    # (variable, method and class names) becomes the placeholder "ID".
    java_keywords = frozenset(
        """
        abstract assert boolean break byte case catch char class const continue
        default do double else enum extends final finally float for goto if
        implements import instanceof int interface long native new package
        private protected public return short static strictfp super switch
        synchronized this throw throws transient try void volatile while
        true false null var
        """.split()
    )

    token_pattern = re.compile(
        r'(?P<lit>"(?:\\.|[^"\\])*"'      # string literal
        r"|'(?:\\.|[^'\\])*'"             # char literal
        r"|\d[\w.]*)"                     # numeric literal
        r"|(?P<word>[A-Za-z_$][\w$]*)"    # identifier or keyword
        r"|(?P<sym>\S)"                   # any other single non-space character
    )

    def normalize_code(code):
        # Remove comments
        code = re.sub(r"//.*?$|/\*.*?\*/", "", code, flags=re.DOTALL | re.MULTILINE)
        # Remove surrounding whitespace and blank lines
        return '\n'.join(line.strip() for line in code.splitlines() if line.strip())

    def tokenize(code):
        """Return (tokens, shapes); shapes has identifiers/literals abstracted."""
        tokens, shapes = [], []
        for match in token_pattern.finditer(code):
            text = match.group()
            tokens.append(text)
            if match.lastgroup == 'lit':
                shapes.append('LIT')
            elif match.lastgroup == 'word' and text not in java_keywords:
                shapes.append('ID')
            else:
                shapes.append(text)
        return tokens, shapes

    norm_orig = normalize_code(original_code)
    norm_gen = normalize_code(generated_code)

    # If identical, not type-3
    if norm_orig == norm_gen:
        return False, "Identical code (type-1)"

    tokens_orig, shapes_orig = tokenize(norm_orig)
    tokens_gen, shapes_gen = tokenize(norm_gen)

    # Same tokens in the same order: only spacing/line breaks changed.
    if tokens_orig == tokens_gen:
        return False, "Identical apart from formatting (type-1)"

    # Same structure once names and literal values are abstracted away.
    if shapes_orig == shapes_gen:
        return False, "Only identifiers/literals differ (type-2)"

    # Compute similarity ratio
    ratio = SequenceMatcher(None, norm_orig, norm_gen).ratio()
    if ratio > min_similarity:
        return True, f"Syntactically similar (ratio={ratio:.2f})"
    else:
        return False, f"Not similar enough (ratio={ratio:.2f})"


print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 2: Generating Hard Clones (Type-3, Type-4)")
print(f"{Fore.CYAN}Model: {HARD_MODEL}")
print(f"{Fore.CYAN}Output: {HARD_OUTPUT_PATH}")
print(f"{Fore.CYAN}{'='*60}\n")

start_time = time.time()

# generate_clones_for_types returns None (not a tuple) when no problems are
# found. Fail with an explicit message instead of an opaque unpacking TypeError.
hard_result = generate_clones_for_types(
    clone_types=['type3', 'type4'],
    output_path=HARD_OUTPUT_PATH,
    target_per_type=TARGET_CLONES_PER_TYPE
)
if hard_result is None:
    raise RuntimeError(
        "Hard clone generation did not run: generate_clones_for_types found no problems. "
        "Check the CodeNet data before re-running this cell."
    )
hard_counters, hard_stats = hard_result

elapsed_time = time.time() - start_time

print(f"\n{Fore.GREEN}{'='*60}")
print(f"{Fore.GREEN}✓ HARD CLONES GENERATION COMPLETE!")
print(f"{Fore.GREEN}{'='*60}")

# Per-type totals; a type is COMPLETE once it reached TARGET_CLONES_PER_TYPE.
print(f"\n{Fore.CYAN}Clone Counts:")
total_hard = 0
for clone_type, count in hard_counters.items():
    status = "✓ COMPLETE" if count >= TARGET_CLONES_PER_TYPE else "⚠ INCOMPLETE"
    print(f"  {clone_type}: {count}/{TARGET_CLONES_PER_TYPE} {status}")
    total_hard += count

print(f"\n{Fore.GREEN}Total hard clones: {total_hard}")
print(f"Time taken: {elapsed_time/60:.2f} minutes")
print(f"Dataset saved to: {HARD_OUTPUT_PATH}")

# Failure tallies collected by the generation loop.
print(f"\n{Fore.YELLOW}Statistics:")
print(f"  No seed found: {hard_stats['no_seed']}")
print(f"  Seed validation failed: {hard_stats['seed_failed']}")
print(f"  Generation/validation failed: {hard_stats['failed']}")

## 5. Combine Datasets
Merge the easy and hard clone datasets into a single comprehensive dataset.

In [None]:
"""
Step 3: Combine Easy and Hard Clone Datasets
"""

print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 3: Combining Datasets")
print(f"{Fore.CYAN}{'='*60}\n")

# Each input is optional: a missing file leaves its list empty and only warns.
easy_records, hard_records = [], []

if not EASY_OUTPUT_PATH.exists():
    print(f"{Fore.YELLOW}⚠ Easy clones file not found: {EASY_OUTPUT_PATH}")
else:
    with jsonlines.open(EASY_OUTPUT_PATH) as reader:
        easy_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(easy_records)} records from {EASY_OUTPUT_PATH}")

if not HARD_OUTPUT_PATH.exists():
    print(f"{Fore.YELLOW}⚠ Hard clones file not found: {HARD_OUTPUT_PATH}")
else:
    with jsonlines.open(HARD_OUTPUT_PATH) as reader:
        hard_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(hard_records)} records from {HARD_OUTPUT_PATH}")

# Easy records first, then hard records.
all_records = [*easy_records, *hard_records]

print(f"\n{Fore.CYAN}Total records to write: {len(all_records)}")

# The combined file is rewritten from scratch on every run of this cell.
with jsonlines.open(COMBINED_OUTPUT_PATH, mode='w') as writer:
    for record in all_records:
        writer.write(record)

print(f"{Fore.GREEN}✓ Combined dataset written to: {COMBINED_OUTPUT_PATH}")

# Tally records per clone type; records without the key fall under 'unknown'.
type_counts = {}
for record in all_records:
    clone_type = record.get('clone_type', 'unknown')
    type_counts.setdefault(clone_type, 0)
    type_counts[clone_type] += 1

print(f"\n{Fore.CYAN}Distribution by Clone Type:")
for clone_type in ('type1', 'type2', 'type3', 'type4'):
    count = type_counts.get(clone_type, 0)
    print(f"  {clone_type}: {count}")

print(f"\n{Fore.GREEN}{'='*60}")
print(f"{Fore.GREEN}✓ DATASET COMBINATION COMPLETE!")
print(f"{Fore.GREEN}{'='*60}")

## 6. Validate Combined Dataset
Perform validation and quality checks on the final combined dataset.

In [None]:
"""
Step 4: Validate the Combined Dataset
"""

print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 4: Dataset Validation")
print(f"{Fore.CYAN}{'='*60}\n")

# Read the combined clone dataset back from disk; a missing file leaves the list empty.
combined_records = []
if not COMBINED_OUTPUT_PATH.exists():
    print(f"{Fore.RED}✗ Combined dataset not found: {COMBINED_OUTPUT_PATH}")
else:
    with jsonlines.open(COMBINED_OUTPUT_PATH) as reader:
        combined_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(combined_records)} records from {COMBINED_OUTPUT_PATH}")

if not combined_records:
    print(f"{Fore.RED}✗ No records found for validation")
else:
    print(f"\n{Fore.CYAN}Validation Checks:")

    # Tally records per clone type and per generator model in a single pass.
    type_distribution = {}
    model_distribution = {}

    for record in combined_records:
        clone_type = record.get('clone_type', 'unknown')
        model = record.get('generator', 'unknown')

        type_distribution.setdefault(clone_type, 0)
        type_distribution[clone_type] += 1
        model_distribution.setdefault(model, 0)
        model_distribution[model] += 1

    # 1. Every expected clone type should appear at least once.
    print(f"\n{Fore.CYAN}1. Clone Type Distribution:")
    for clone_type in ('type1', 'type2', 'type3', 'type4'):
        count = type_distribution.get(clone_type, 0)
        if combined_records:
            percentage = count / len(combined_records) * 100
        else:
            percentage = 0
        status = "✗" if count <= 0 else "✓"
        print(f"  {status} {clone_type}: {count} ({percentage:.1f}%)")

    # 2. Share of records attributed to each generator model.
    print(f"\n{Fore.CYAN}2. Model Distribution:")
    for model, count in model_distribution.items():
        if combined_records:
            percentage = count / len(combined_records) * 100
        else:
            percentage = 0
        print(f"  {model}: {count} ({percentage:.1f}%)")

    # 3. The stored generator must match what get_model_for_clone_type returns now.
    print(f"\n{Fore.CYAN}3. Model Assignment Verification:")
    correct_assignments = 0
    incorrect_assignments = 0

    for record in combined_records:
        clone_type = record.get('clone_type', '')
        model = record.get('generator', '')
        expected_model = get_model_for_clone_type(clone_type)

        if model != expected_model:
            incorrect_assignments += 1
        else:
            correct_assignments += 1

    if incorrect_assignments != 0:
        print(f"  ⚠ {correct_assignments} correct, {incorrect_assignments} incorrect")
    else:
        print(f"  ✓ All {correct_assignments} records have correct model assignments")

    # 4. A field counts as missing when the key is absent or its value is falsy.
    print(f"\n{Fore.CYAN}4. Required Fields Check:")
    required_fields = ['id', 'code_1', 'code_2', 'label', 'clone_type', 'language', 'problem_id', 'generator']

    missing_fields = {}
    for record in combined_records:
        for field in required_fields:
            if not record.get(field):
                missing_fields.setdefault(field, 0)
                missing_fields[field] += 1

    if missing_fields:
        for field, count in missing_fields.items():
            print(f"  ⚠ {field}: missing in {count} records")
    else:
        print(f"  ✓ All records have all required fields")

    # 5. Summary statistics
    print(f"\n{Fore.GREEN}{'='*60}")
    print(f"{Fore.GREEN}FINAL DATASET SUMMARY")
    print(f"{Fore.GREEN}{'='*60}")
    print(f"{Fore.CYAN}Total Records: {len(combined_records)}")
    print(f"{Fore.CYAN}Dataset Path: {COMBINED_OUTPUT_PATH}")
    print(f"{Fore.CYAN}File Size: {COMBINED_OUTPUT_PATH.stat().st_size / (1024*1024):.2f} MB")

    print(f"\n{Fore.CYAN}Clone Type Breakdown:")
    for clone_type in ('type1', 'type2', 'type3', 'type4'):
        count = type_distribution.get(clone_type, 0)
        print(f"  {clone_type}: {count}/{TARGET_CLONES_PER_TYPE}")

    print(f"\n{Fore.GREEN}✓ Validation Complete!")

## 7. Generate Easy Non-Clones
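This cell defines `generate_nonclones_for_types`, the generation loop shared by the easy and hard non-clone steps. Easy non-clones (simple algorithmic differences) are meant to use the fast model configured as `EASY_MODEL`. **Note:** no cell in this part of the notebook actually runs the easy step, although the final merge reads `EASY_NONCLONES_OUTPUT_PATH`.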

In [None]:
def _nonclone_compiles_and_runs(generated_code):
    """Compile ``generated_code`` as Main.java in a scratch directory and smoke-run it with empty stdin.

    Returns True when compilation succeeds and the run did not time out.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        java_file = Path(temp_dir) / "Main.java"
        java_file.write_text(generated_code, encoding='utf-8')

        compile_success, compile_error = compile_java(temp_dir)
        if not compile_success:
            return False

        # Try to run it with empty input to see if it executes.
        output, error = run_java_with_input(temp_dir, "", timeout=5)

        # Only a timeout disqualifies the candidate; any other run error is
        # accepted (presumably because programs that read stdin fail on empty
        # input — TODO confirm this is the intended tolerance).
        return error is None or "timeout" not in error.lower()


def generate_nonclones_for_types(nonclone_types, output_path, target_per_type):
    """Generate non-clone pairs for ``nonclone_types`` and write them to ``output_path``.

    Problems returned by ``list_problems()`` (capped by ``MAX_PROBLEMS``) are
    visited in order, cycling at most ``MAX_CYCLES`` times, until every
    non-clone type has ``target_per_type`` records. The seed must pass
    ``validate_java``; the generated code only has to pass
    ``quick_check_code_quality``, compile, and run without timing out, because
    a non-clone is not expected to pass the seed's test cases.

    Args:
        nonclone_types: Non-clone type names, e.g. ``['hard']``.
        output_path: JSONL file to write. It is opened in ``'w'`` mode, so an
            existing file is overwritten (there is no resume support).
        target_per_type: Number of records wanted for each non-clone type.

    Returns:
        ``(nonclone_counters, stats)`` where ``stats`` has the keys 'failed',
        'seed_failed', 'no_seed' and 'identical_skipped'. Returns ``None`` when
        ``list_problems()`` yields nothing; in that case ``output_path`` is
        left untouched.
    """
    # Load problems BEFORE opening the writer: opening in 'w' mode truncates
    # the file, so the early return must not happen after it.
    problems = list_problems()
    if not problems:
        print(f"{Fore.RED}✗ No problems found in CodeNet directory")
        return None

    print(f"Found {len(problems)} problems in CodeNet")

    problems_to_process = problems[:MAX_PROBLEMS] if MAX_PROBLEMS else problems

    # Non-clone counters
    nonclone_counters = {nt: 0 for nt in nonclone_types}

    stats = {
        'failed': 0,
        'seed_failed': 0,
        'no_seed': 0,
        'identical_skipped': 0
    }

    current_problem_idx = 0
    problems_processed = 0

    # Main generation loop. The writer is a context manager so it is closed on
    # normal exit, on errors, and when the cell is interrupted.
    total_target = target_per_type * len(nonclone_types)
    with jsonlines.open(output_path, mode='w', flush=True) as dataset_writer, \
            tqdm(total=total_target, desc=f"Generating {', '.join(nonclone_types)} non-clones") as pbar:
        while any(nonclone_counters[nt] < target_per_type for nt in nonclone_types):
            # Wrap around once every problem has been visited.
            if current_problem_idx >= len(problems_to_process):
                current_problem_idx = 0
                problems_processed += 1
                print(f"\n{Fore.YELLOW}Completed cycle {problems_processed}, cycling through problems again...")

                # Safety check to prevent an infinite loop when targets are unreachable.
                if problems_processed >= MAX_CYCLES:
                    print(f"\n{Fore.RED}Reached maximum cycles limit ({MAX_CYCLES}). Stopping generation.")
                    break

            problem_id = problems_to_process[current_problem_idx]
            current_problem_idx += 1

            try:
                # Load seed code
                seed_code, sub_id = choose_seed(problem_id)

                if not seed_code:
                    stats['no_seed'] += 1
                    continue

                # Test cases are needed to validate the seed; a problem without
                # them is tallied under 'no_seed' as well.
                test_cases = load_testcases(problem_id)
                if not test_cases:
                    stats['no_seed'] += 1
                    continue

                # Validate seed
                seed_result = validate_java(seed_code, problem_id)
                if isinstance(seed_result, str) and seed_result.startswith("compile_error"):
                    seed_result = "compile_error"

                if seed_result != "passed":
                    stats['seed_failed'] += 1
                    continue

                nonclones_generated_this_problem = 0

                for nonclone_type in nonclone_types:
                    if nonclone_counters[nonclone_type] >= target_per_type:
                        continue

                    if nonclones_generated_this_problem >= MAX_CLONES_PER_PROBLEM:
                        break

                    try:
                        # problem_id=None: syntax repair only, no test-case validation.
                        generated_code = generate_nonclone_v2(seed_code, nonclone_type, problem_id=None)

                        if not generated_code:
                            stats['failed'] += 1
                            continue

                        # Quick quality check
                        is_valid, reason = quick_check_code_quality(generated_code)
                        if not is_valid:
                            stats['failed'] += 1
                            continue

                        if not _nonclone_compiles_and_runs(generated_code):
                            stats['failed'] += 1
                            continue

                        pair_id = f"{problem_id}_nonclone_{nonclone_type}_{uuid.uuid4().hex[:8]}"

                        record = {
                            'id': pair_id,
                            'code_1': seed_code,
                            'code_2': generated_code,
                            'label': "non-clone",
                            'clone_type': f"nonclone_{nonclone_type}",
                            'language': 'Java',
                            'problem_id': problem_id,
                            'generator': get_model_for_nonclone_type(nonclone_type),
                            'timestamp': time.time()
                        }

                        dataset_writer.write(record)
                        nonclone_counters[nonclone_type] += 1
                        nonclones_generated_this_problem += 1
                        pbar.update(1)

                    except Exception as e:
                        # Count and report the failure. A silently swallowed
                        # error (e.g. a missing import) would otherwise look
                        # like an ordinary generation failure.
                        stats['failed'] += 1
                        log(f"[{problem_id}] {nonclone_type} non-clone error: {repr(e)}", Fore.RED)

                # Progress update every 50 problems
                if current_problem_idx % 50 == 0:
                    print(f"\n{Fore.CYAN}Progress Update - Problem {current_problem_idx}/{len(problems_to_process)} (Cycle {problems_processed + 1}):")
                    for nt, count in nonclone_counters.items():
                        status = "✓" if count >= target_per_type else f"{count}/{target_per_type}"
                        print(f"  {nt} non-clones: {status}")
                    print()

            except Exception as e:
                stats['failed'] += 1
                log(f"[{problem_id}] Unexpected error: {e}", Fore.RED)
                continue

    return nonclone_counters, stats

# Confirmation banner: printed once generate_nonclones_for_types has been defined in the kernel.
print(f"{Fore.GREEN}✓ Non-clone generation function defined")

## 8. Generate Hard Non-Clones
Use the more capable model configured as `HARD_MODEL` to generate hard non-clones with different problem domains.

In [None]:
"""
Step 6: Generate Hard Non-Clones using deepseek-coder:6.7b
"""

print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 6: Generating Hard Non-Clones")
print(f"{Fore.CYAN}Model: {HARD_MODEL}")
print(f"{Fore.CYAN}Output: {HARD_NONCLONES_OUTPUT_PATH}")
print(f"{Fore.CYAN}{'='*60}\n")

start_time = time.time()

# generate_nonclones_for_types returns None (not a tuple) when no problems are
# found. Fail with an explicit message instead of an opaque unpacking TypeError.
hard_nonclone_result = generate_nonclones_for_types(
    nonclone_types=['hard'],
    output_path=HARD_NONCLONES_OUTPUT_PATH,
    target_per_type=TARGET_NONCLONES_HARD
)
if hard_nonclone_result is None:
    raise RuntimeError(
        "Hard non-clone generation did not run: generate_nonclones_for_types found no problems. "
        "Check the CodeNet data before re-running this cell."
    )
hard_nonclone_counters, hard_nonclone_stats = hard_nonclone_result

elapsed_time = time.time() - start_time

print(f"\n{Fore.GREEN}{'='*60}")
print(f"{Fore.GREEN}✓ HARD NON-CLONES GENERATION COMPLETE!")
print(f"{Fore.GREEN}{'='*60}")

# Per-type totals; a type is COMPLETE once it reached TARGET_NONCLONES_HARD.
print(f"\n{Fore.CYAN}Non-Clone Counts:")
total_hard_nonclones = 0
for nonclone_type, count in hard_nonclone_counters.items():
    status = "✓ COMPLETE" if count >= TARGET_NONCLONES_HARD else "⚠ INCOMPLETE"
    print(f"  {nonclone_type}: {count}/{TARGET_NONCLONES_HARD} {status}")
    total_hard_nonclones += count

print(f"\n{Fore.GREEN}Total hard non-clones: {total_hard_nonclones}")
print(f"Time taken: {elapsed_time/60:.2f} minutes")
print(f"Dataset saved to: {HARD_NONCLONES_OUTPUT_PATH}")

# Failure tallies collected by the generation loop.
print(f"\n{Fore.YELLOW}Statistics:")
print(f"  No seed found: {hard_nonclone_stats['no_seed']}")
print(f"  Seed validation failed: {hard_nonclone_stats['seed_failed']}")
print(f"  Generation/validation failed: {hard_nonclone_stats['failed']}")

## 9. Create Complete Dataset
Combine all clones and non-clones into a single comprehensive dataset for training.

In [None]:
"""
Step 7: Create Complete Dataset (Clones + Non-Clones)
"""

print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}STEP 7: Creating Complete Dataset")
print(f"{Fore.CYAN}{'='*60}\n")

# Each input is optional: a missing file leaves its list empty and only warns.
clone_records, easy_nonclone_records, hard_nonclone_records = [], [], []

if not COMBINED_OUTPUT_PATH.exists():
    print(f"{Fore.YELLOW}⚠ Clone dataset file not found: {COMBINED_OUTPUT_PATH}")
else:
    with jsonlines.open(COMBINED_OUTPUT_PATH) as reader:
        clone_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(clone_records)} clone records from {COMBINED_OUTPUT_PATH}")

if not EASY_NONCLONES_OUTPUT_PATH.exists():
    print(f"{Fore.YELLOW}⚠ Easy non-clones file not found: {EASY_NONCLONES_OUTPUT_PATH}")
else:
    with jsonlines.open(EASY_NONCLONES_OUTPUT_PATH) as reader:
        easy_nonclone_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(easy_nonclone_records)} easy non-clone records from {EASY_NONCLONES_OUTPUT_PATH}")

if not HARD_NONCLONES_OUTPUT_PATH.exists():
    print(f"{Fore.YELLOW}⚠ Hard non-clones file not found: {HARD_NONCLONES_OUTPUT_PATH}")
else:
    with jsonlines.open(HARD_NONCLONES_OUTPUT_PATH) as reader:
        hard_nonclone_records = [*reader]
    print(f"{Fore.GREEN}✓ Loaded {len(hard_nonclone_records)} hard non-clone records from {HARD_NONCLONES_OUTPUT_PATH}")

# Clones first, then easy non-clones, then hard non-clones.
all_complete_records = [*clone_records, *easy_nonclone_records, *hard_nonclone_records]

print(f"\n{Fore.CYAN}Total records to write: {len(all_complete_records)}")
print(f"  Clones: {len(clone_records)}")
print(f"  Easy Non-clones: {len(easy_nonclone_records)}")
print(f"  Hard Non-clones: {len(hard_nonclone_records)}")

# The final file is rewritten from scratch on every run of this cell.
with jsonlines.open(FINAL_DATASET_PATH, mode='w') as writer:
    for record in all_complete_records:
        writer.write(record)

print(f"{Fore.GREEN}✓ Complete dataset written to: {FINAL_DATASET_PATH}")

# Tally by label and by clone type; missing keys fall under 'unknown'.
label_counts = {}
type_counts = {}
for record in all_complete_records:
    label = record.get('label', 'unknown')
    clone_type = record.get('clone_type', 'unknown')

    label_counts.setdefault(label, 0)
    label_counts[label] += 1
    type_counts.setdefault(clone_type, 0)
    type_counts[clone_type] += 1

print(f"\n{Fore.CYAN}Distribution by Label:")
for label in ('clone', 'non-clone'):
    count = label_counts.get(label, 0)
    if all_complete_records:
        percentage = count / len(all_complete_records) * 100
    else:
        percentage = 0
    print(f"  {label}: {count} ({percentage:.1f}%)")

print(f"\n{Fore.CYAN}Distribution by Type:")
for type_name in ('type1', 'type2', 'type3', 'type4', 'nonclone_easy', 'nonclone_hard'):
    count = type_counts.get(type_name, 0)
    if all_complete_records:
        percentage = count / len(all_complete_records) * 100
    else:
        percentage = 0
    print(f"  {type_name}: {count} ({percentage:.1f}%)")

print(f"\n{Fore.GREEN}{'='*60}")
print(f"{Fore.GREEN}✓ COMPLETE DATASET CREATION FINISHED!")
print(f"{Fore.GREEN}{'='*60}")
print(f"{Fore.CYAN}Final Dataset: {FINAL_DATASET_PATH}")
print(f"{Fore.CYAN}Total Records: {len(all_complete_records)}")
print(f"{Fore.CYAN}File Size: {FINAL_DATASET_PATH.stat().st_size / (1024*1024):.2f} MB")
