In [1]:
import sys
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import scipy
from scipy.spatial.distance import euclidean
from scipy.signal import savgol_filter
from dataclasses import dataclass
from enum import Enum

from transformers import pipeline
from sklearn.decomposition import PCA
import javalang
import tree_sitter_java
from tree_sitter import Language, Parser

sys.path.append(os.path.abspath(".."))

from data.dataset_factory import get_dataset_generator
from data.data_generators.sourcecodeplag_dataset_gen import original_plag_triplet_generator
from preprocessing.embedding_chunks import get_ready_to_embed_chunks
from preprocessing.context_chunker import safe_get_ready_to_embed_context_chunks
from preprocessing.mean_pool_chunks import mean_pool_chunks
from preprocessing.block_splitter import deverbose_ast
from visualizer.smoothing import smooth_embeddings, smooth_multiple_embeddings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Anchor sample: Java program that reads an integer and prints its digits in reverse order.
anchor = """
public class T5 {
	public static void main(String[] args) {
		System.out.print("Enter an integer: ");
		java.util.Scanner input = new java.util.Scanner(System.in);
		int number = input.nextInt();
		reverse(number);
	}

	public static void reverse(int number) {
		while (number != 0) {
			int remainder = number % 10;
			System.out.print(remainder);
			number = number / 10;
		}

		System.out.println();
	}

}
"""
# Clone sample: the same digit-reversal program as `anchor`, with renamed
# identifiers, extra comments and different brace/indent formatting.
clone = """import java.util.*;

class method{
	//prog utama
	public static void main(String[] args) 
	{
			System.out.print("Enter an integer: ");
			java.util.Scanner input = new java.util.Scanner(System.in);
		int n = input.nextInt();
		//pamggil method
			beautyReverse(n);
	}
	
	//method reverse
	public static void beautyReverse(int num) 
	{
		while (num != 0)
			{
			int r = num % 10;
				System.out.print(r);
			num = num / 10;
		}
		System.out.println();
	}
}"""
# Non-clone sample: an unrelated program that reads a 4x4 matrix and prints the
# sum of its major diagonal.
non_clone = """import java.util.Scanner;

public class Main {
    //function for Summary
    public static double sumMajorDiagonal(double[][] mtx) {
        double sum = 0;

        for (int i = 0; i < mtx.length; i++)
            sum += mtx[i][i];
        return sum;
    }

    public static void main(String[] args) {

        double[][] mtx = new double[4][4];
        Scanner s = new Scanner(System.in);
        //input 4*4 matrix data
        System.out.print("Enter a 4 by 4 matrix row by row: ");


        for (int i = 0; i < 4; i++)
            for (int j = 0; j < 4; j++)
                mtx[i][j] = s.nextDouble();


        System.out.print("Sum of the elements in the major diagonal is "+ sumMajorDiagonal(mtx));
    }

}

"""

In [3]:
import tree_sitter_java
from tree_sitter import Language, Parser

# Wrap the Java grammar shipped by the tree_sitter_java package in a Language object.
JAVA_LANGUAGE = Language(tree_sitter_java.language())
# Module-level parser; TreeSitterJavaChunker.extract() reads this global.
parser = Parser(JAVA_LANGUAGE)

In [4]:
# Checkpoint used to embed code chunks.
model_name = "microsoft/graphcodebert-base"
# Module-level feature-extraction pipeline; embed_chunk_text() calls this global.
pipe = pipeline("feature-extraction", model=model_name)

Loading weights: 100%|██████████| 197/197 [00:00<00:00, 2180.75it/s, Materializing param=encoder.layer.11.output.dense.weight]              
RobertaModel LOAD REPORT from: microsoft/graphcodebert-base
Key                             | Status     | 
--------------------------------+------------+-
lm_head.layer_norm.weight       | UNEXPECTED | 
roberta.embeddings.position_ids | UNEXPECTED | 
lm_head.decoder.bias            | UNEXPECTED | 
lm_head.dense.bias              | UNEXPECTED | 
lm_head.decoder.weight          | UNEXPECTED | 
lm_head.dense.weight            | UNEXPECTED | 
lm_head.layer_norm.bias         | UNEXPECTED | 
lm_head.bias                    | UNEXPECTED | 
pooler.dense.bias               | MISSING    | 
pooler.dense.weight             | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training 

In [5]:
class ChunkKind(Enum):
    """Category of a code chunk, decided by the statement type it was cut from."""

    STRAIGHT = "straight"  # node type listed in STRAIGHT_NODES
    CONTROL = "control"    # node type listed in CONTROL_NODES


@dataclass
class Chunk:
    """One statement-level slice of source code plus its metadata.

    Attributes:
        text: Source text of the statement.
        embedding: ``None`` when the chunk is created; replaced by the
            embedding vector once ``embed_chunks`` has processed it.
        ast_depth: Depth of the originating node in the syntax tree walk.
        kind: Whether the chunk is a straight-line or control-flow statement.
    """

    text: str
    # Previously annotated as plain `None`, which declared the field's type to
    # be NoneType even though it later holds a vector. A string annotation is
    # used so no additional import is required.
    embedding: "np.ndarray | None"
    ast_depth: int
    kind: ChunkKind

In [6]:
# tree-sitter Java node types treated as control-flow chunks.
# `enhanced_for_statement` (for-each), `switch_expression` (the name current
# tree-sitter-java grammars give to `switch (...) {...}`) and
# `try_with_resources_statement` were missing, so those constructs were never
# emitted as CONTROL chunks. `switch_statement` is kept for older grammars.
CONTROL_NODES = (
    "if_statement",
    "for_statement",
    "enhanced_for_statement",
    "while_statement",
    "do_statement",
    "switch_statement",
    "switch_expression",
    "try_statement",
    "try_with_resources_statement",
    "catch_clause",
)

# tree-sitter Java node types treated as straight-line chunks.
STRAIGHT_NODES = (
    "local_variable_declaration",
    "expression_statement",
    "return_statement",
    "throw_statement",
)

In [7]:
def extract_source_by_position(source_lines, node, max_lines=20):
    """Return a window of source text starting at the node's line.

    Args:
        source_lines: The source file as a list of lines.
        node: Object exposing ``position`` with a 1-based ``line`` attribute.
        max_lines: Maximum number of lines to return. Defaults to 20, the
            window size that used to be hard-coded.

    Returns:
        The selected lines joined with newlines, or ``None`` when the node has
        no position.
    """
    if not node.position:
        return None

    start_line = node.position.line - 1  # convert to a 0-based index

    # Heuristic: a fixed-size window stands in for "until the block ends";
    # the real end of the statement is not looked up.
    window = source_lines[start_line:start_line + max_lines]
    return "\n".join(window)

In [8]:
class JavaChunkExtractor:
    """Collect statement chunks from a javalang AST.

    Fixes three defects of the previous version:
      * ``isinstance(node, CONTROL_NODES)`` was called with tuples of strings,
        which raises ``TypeError``; nodes are now classified by class name.
      * ``_add_chunk`` was called but never defined.
      * ``node.filter(...)`` yields the node itself and every descendant, so
        recursing over it never terminated; only direct children are visited.

    NOTE(review): the class-name tuples below mirror the tree-sitter lists
    using javalang's class names -- confirm against the installed javalang.
    """

    _CONTROL_TYPES = (
        "IfStatement",
        "ForStatement",
        "WhileStatement",
        "DoStatement",
        "SwitchStatement",
        "TryStatement",
        "CatchClause",
    )

    _STRAIGHT_TYPES = (
        "LocalVariableDeclaration",
        "StatementExpression",
        "ReturnStatement",
        "ThrowStatement",
    )

    def __init__(self, source_code):
        self.source = source_code
        self.lines = source_code.splitlines()
        self.chunks = []

    def visit(self, node, depth=0):
        """Record ``node`` if it is a known statement type, then walk its children."""
        type_name = type(node).__name__
        if type_name in self._CONTROL_TYPES:
            self._add_chunk(node, depth, ChunkKind.CONTROL)
        elif type_name in self._STRAIGHT_TYPES:
            self._add_chunk(node, depth, ChunkKind.STRAIGHT)

        for child in self._direct_children(node):
            self.visit(child, depth + 1)

    @staticmethod
    def _direct_children(node):
        """Yield the immediate child nodes of ``node``, flattening nested lists."""
        pending = list(reversed(node.children))
        while pending:
            item = pending.pop()
            if isinstance(item, javalang.tree.Node):
                yield item
            elif isinstance(item, (list, tuple)):
                # Attribute values may be (nested) sequences of nodes.
                pending.extend(reversed(item))

    def _add_chunk(self, node, depth, kind):
        """Append a Chunk for ``node``; nodes without a source position are skipped."""
        text = extract_source_by_position(self.lines, node)
        if text is None:
            return
        self.chunks.append(
            Chunk(text=text, embedding=None, ast_depth=depth, kind=kind)
        )

In [9]:
class TreeSitterJavaChunker:
    """Cut Java source into statement-level chunks using the module-level tree-sitter ``parser``."""

    def __init__(self, source_code: str):
        # tree-sitter works on bytes, and node offsets are byte offsets.
        self.code = source_code.encode("utf8")
        self.chunks = []

    def extract(self):
        """Parse the stored source and return the list of collected chunks."""
        syntax_tree = parser.parse(self.code)
        self._visit(syntax_tree.root_node, depth=0)
        return self.chunks

    def _visit(self, node, depth):
        """Pre-order walk: record ``node`` if it is a statement type, then descend."""
        kind = None
        if node.type in CONTROL_NODES:
            kind = ChunkKind.CONTROL
        elif node.type in STRAIGHT_NODES:
            kind = ChunkKind.STRAIGHT
        if kind is not None:
            self._add_chunk(node, depth, kind)

        punctuation = (';', '{', '}', '(', ')', ',')
        for child in node.children:
            is_punctuation_leaf = len(child.children) == 0 and child.type in punctuation
            if not is_punctuation_leaf:
                self._visit(child, depth + 1)

    def _add_chunk(self, node, depth, kind):
        """Slice the node's byte range out of the source and store it as a Chunk."""
        raw = self.code[node.start_byte : node.end_byte]
        new_chunk = Chunk(
            text=raw.decode("utf8"),
            embedding=None,
            ast_depth=depth,
            kind=kind,
        )
        self.chunks.append(new_chunk)


In [11]:
def parse_and_chunk_java(code: str):
    """Chunk Java source and split the chunks by kind.

    Returns:
        A ``(straight_chunks, control_chunks)`` pair of lists, each in the
        order the chunker produced them.
    """
    straight, control = [], []
    for item in TreeSitterJavaChunker(code).extract():
        if item.kind == ChunkKind.STRAIGHT:
            straight.append(item)
        elif item.kind == ChunkKind.CONTROL:
            control.append(item)
    return straight, control

In [12]:
def embed_chunk_text(text):
    """Embed one code chunk as a single vector by mean-pooling over tokens.

    The module-level ``pipe`` is expected to return per-token features nested
    as [1, seq_len, hidden_dim]; the leading axis is dropped and the token
    axis averaged, giving a vector of length hidden_dim.
    """
    token_features = np.array(pipe(text)[0])
    return token_features.mean(axis=0)


def embed_chunks(S, C):
    """Fill in ``embedding`` on every chunk of ``S`` and ``C``, in place.

    Chunks whose text is empty after stripping whitespace are left untouched,
    so their ``embedding`` keeps whatever value it had.
    """
    for item in S + C:
        stripped = item.text.strip()
        if stripped != "":
            item.embedding = embed_chunk_text(stripped)


In [13]:
import numpy as np

def combine_chunks(chunks, embedding_dim=768):
    """Collapse chunk embeddings into one weighted-average vector.

    Each chunk is weighted by ``base * (ast_depth + 1)``, where ``base`` is 1
    for STRAIGHT chunks and 2 for every other kind.

    Args:
        chunks: Iterable of objects with ``embedding``, ``kind`` and
            ``ast_depth`` attributes.
        embedding_dim: Length of the zero vector returned when there is
            nothing to average. Defaults to 768, the previously hard-coded
            value.

    Returns:
        The weighted mean of the embeddings, or ``np.zeros(embedding_dim)``
        when no chunk carries an embedding.
    """
    # embed_chunks() skips blank chunks and leaves their embedding as None;
    # feeding None into the average used to raise, so drop those chunks here.
    embedded = [chunk for chunk in chunks if chunk.embedding is not None]
    if not embedded:
        return np.zeros(embedding_dim)

    vectors = np.array([chunk.embedding for chunk in embedded])
    weights = np.array([
        (1 if chunk.kind.name == "STRAIGHT" else 2) * (chunk.ast_depth + 1)
        for chunk in embedded
    ])
    # Equivalent to sum(w_i * v_i) / sum(w_i).
    return np.average(vectors, axis=0, weights=weights)

In [14]:
def print_chunks(S, C):
    """Print the text of every straight chunk, then every control chunk, one per bullet."""
    sections = (
        ("STRAIGHT CHUNKS:", S),
        ("\nCONTROL CHUNKS:", C),
    )
    for heading, group in sections:
        print(heading)
        for item in group:
            print(f"- {item.text}")

In [15]:
def _chunk_embed_summarize(code):
    """Chunk, print, embed and summarise one Java program.

    The summary blends the straight-chunk vector (weight 0.3) with the
    control-chunk vector (weight 0.7). An alternative that was tried and left
    disabled is a single ``combine_chunks`` call over both groups together.

    Returns:
        ``(straight_chunks, control_chunks, summary_vector)``.
    """
    straight, control = parse_and_chunk_java(code)
    print_chunks(straight, control)
    embed_chunks(straight, control)
    blended = 0.3 * combine_chunks(straight) + 0.7 * combine_chunks(control)
    return straight, control, blended


S_anchor, C_anchor, anchor_summary = _chunk_embed_summarize(anchor)
S_clone, C_clone, clone_summary = _chunk_embed_summarize(clone)
S_non, C_non, non_clone_summary = _chunk_embed_summarize(non_clone)

STRAIGHT CHUNKS:
- System.out.print("Enter an integer: ");
- java.util.Scanner input = new java.util.Scanner(System.in);
- int number = input.nextInt();
- reverse(number);
- int remainder = number % 10;
- System.out.print(remainder);
- number = number / 10;
- System.out.println();

CONTROL CHUNKS:
- while (number != 0) {
			int remainder = number % 10;
			System.out.print(remainder);
			number = number / 10;
		}
STRAIGHT CHUNKS:
- System.out.print("Enter an integer: ");
- java.util.Scanner input = new java.util.Scanner(System.in);
- int n = input.nextInt();
- beautyReverse(n);
- int r = num % 10;
- System.out.print(r);
- num = num / 10;
- System.out.println();

CONTROL CHUNKS:
- while (num != 0)
			{
			int r = num % 10;
				System.out.print(r);
			num = num / 10;
		}
STRAIGHT CHUNKS:
- double sum = 0;
- int i = 0;
- sum += mtx[i][i];
- return sum;
- double[][] mtx = new double[4][4];
- Scanner s = new Scanner(System.in);
- System.out.print("Enter a 4 by 4 matrix row by row: ");
- int 

In [16]:
from numpy.linalg import norm

def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    Returns 0.0 when either vector has zero norm. ``combine_chunks`` returns
    an all-zero vector for an empty chunk list, and dividing by a zero norm
    used to produce NaN here.
    """
    a_norm = norm(a)
    b_norm = norm(b)
    if a_norm == 0 or b_norm == 0:
        return 0.0
    return np.dot(a, b) / (a_norm * b_norm)


# Similarity of the anchor to its clone and to the unrelated program.
sim_clone = cosine_similarity(anchor_summary, clone_summary)
sim_non_clone = cosine_similarity(anchor_summary, non_clone_summary)

# Ratio of the two similarities; higher is better. The small constant keeps
# the denominator away from zero.
score = sim_clone / (sim_non_clone + 1e-8)

print(score, sim_clone, sim_non_clone, sep="\n")


1.1389012820380315
0.9789242655999484
0.8595338943329472


In [17]:
import numpy as np
from numpy.linalg import norm

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between ``a`` and ``b``; 0.0 if either has zero norm."""
    magnitudes = (norm(a), norm(b))
    if not all(magnitudes):
        # A zero-length vector has no direction, so similarity is defined as 0.
        return 0.0
    return np.dot(a, b) / (magnitudes[0] * magnitudes[1])


In [18]:
import numpy as np
from tqdm import tqdm  # optional, for progress bar

# Grid-search the STRAIGHT/CONTROL mixing weight over the plagiarism triplets,
# keeping the weight with the highest average clone / non-clone similarity ratio.
best_score = -1
best_wS = 0.3

# Use the generator
triplets = list(original_plag_triplet_generator(seed=42))
if not triplets:
    # Fail with a clear message instead of a ZeroDivisionError in the average below.
    raise ValueError("original_plag_triplet_generator produced no triplets")


def _group_vectors(code):
    """Return (straight_vector, control_vector) for one Java source string."""
    straight_chunks, control_chunks = parse_and_chunk_java(code)
    embed_chunks(straight_chunks, control_chunks)
    return combine_chunks(straight_chunks), combine_chunks(control_chunks)


# Parsing and embedding do not depend on wS, so run them once per program
# instead of once per (weight, program) pair. Loop variables are named so they
# do not overwrite the notebook-level `anchor`, `clone` and `non_clone` samples.
triplet_vectors = [
    tuple(_group_vectors(triplet[key]) for key in ("anchor", "clone", "nonclone"))
    for triplet in tqdm(triplets, desc="embedding", leave=False)
]

for wS in np.arange(0.1, 0.9 + 1e-5, 0.05):
    wC = 1.0 - wS
    total_score = 0

    for (a_S, a_C), (c_S, c_C), (n_S, n_C) in triplet_vectors:
        # Weighted summary vector for each program of the triplet
        summary_anchor = wS * a_S + wC * a_C
        summary_clone = wS * c_S + wC * c_C
        summary_non = wS * n_S + wC * n_C

        # Cosine similarity ratio (clone vs. non-clone)
        triplet_sim_clone = cosine_similarity(summary_anchor, summary_clone)
        triplet_sim_non = cosine_similarity(summary_anchor, summary_non)
        total_score += triplet_sim_clone / (triplet_sim_non + 1e-8)

    avg_score = total_score / len(triplet_vectors)
    if avg_score > best_score:
        best_score = avg_score
        best_wS = wS

# NOTE(review): the recorded run selected wS=0.90, the upper edge of the grid,
# so the true optimum may lie outside the searched range.
print(f"Optimal weights: STRAIGHT={best_wS:.2f}, CONTROL={1-best_wS:.2f}, avg_score={best_score:.3f}")


                                                             

Optimal weights: STRAIGHT=0.90, CONTROL=0.10, avg_score=1.063


