In [1]:
import sys
import json
import os
import torch
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
import scipy
from scipy.spatial.distance import euclidean
from scipy.signal import savgol_filter
from dataclasses import dataclass
from enum import Enum

from transformers import pipeline
from sklearn.decomposition import PCA
import javalang
import tree_sitter_java
from tree_sitter import Language, Parser

sys.path.append(os.path.abspath(".."))

from data.dataset_factory import get_dataset_generator
from data.data_generators.sourcecodeplag_dataset_gen import original_plag_triplet_generator
from preprocessing.embedding_chunks import get_ready_to_embed_chunks
from preprocessing.context_chunker import safe_get_ready_to_embed_context_chunks
from preprocessing.mean_pool_chunks import mean_pool_chunks
from preprocessing.block_splitter import deverbose_ast
from visualizer.smoothing import smooth_embeddings, smooth_multiple_embeddings

  from .autonotebook import tqdm as notebook_tqdm


# Tree Sitter for parsing

In [2]:
import tree_sitter_java
from tree_sitter import Language, Parser

JAVA_LANGUAGE = Language(tree_sitter_java.language())
parser = Parser(JAVA_LANGUAGE)

In [3]:
model_name = "microsoft/graphcodebert-base"
pipe = pipeline("feature-extraction", model=model_name)

Loading weights: 100%|██████████| 197/197 [00:00<00:00, 1837.93it/s, Materializing param=encoder.layer.11.output.dense.weight]              
RobertaModel LOAD REPORT from: microsoft/graphcodebert-base
Key                             | Status     | 
--------------------------------+------------+-
lm_head.decoder.weight          | UNEXPECTED | 
lm_head.dense.weight            | UNEXPECTED | 
lm_head.decoder.bias            | UNEXPECTED | 
lm_head.dense.bias              | UNEXPECTED | 
lm_head.layer_norm.weight       | UNEXPECTED | 
roberta.embeddings.position_ids | UNEXPECTED | 
lm_head.layer_norm.bias         | UNEXPECTED | 
lm_head.bias                    | UNEXPECTED | 
pooler.dense.bias               | MISSING    | 
pooler.dense.weight             | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training 

# Data and representation

In [4]:
class ChunkKind(Enum):
    STRAIGHT = "straight"
    CONTROL = "control"

@dataclass
class Chunk:
    text: str
    embedding: None
    ast_depth: int
    kind: ChunkKind


CONTROL_NODES = (
    "if_statement",
    "for_statement",
    "while_statement",
    "do_statement",
    "switch_statement",
    "try_statement",
    "catch_clause",
)

STRAIGHT_NODES = (
    "local_variable_declaration",
    "expression_statement",
    "return_statement",
    "throw_statement",
)

# 

In [5]:
def extract_control_header(code: bytes, node):
    """
    Extracts only the control header (condition / signature),
    excluding the body.
    """
    # Common Tree-sitter pattern: condition is a child
    for child in node.children:
        if child.type in ("condition", "parenthesized_expression"):
            return code[child.start_byte:child.end_byte].decode("utf8")

    # Fallback: first line only
    text = code[node.start_byte:node.end_byte].decode("utf8")
    return text.split("{")[0].strip()


In [6]:
class JavaChunkExtractor:
    def __init__(self, source_code):
        self.source = source_code
        self.lines = source_code.splitlines()
        self.chunks = []

    def visit(self, node, depth=0):
        if isinstance(node, CONTROL_NODES):
            self._add_chunk(node, depth, ChunkKind.CONTROL)

        elif isinstance(node, STRAIGHT_NODES):
            self._add_chunk(node, depth, ChunkKind.STRAIGHT)

        for _, child in node.filter(javalang.tree.Node):
            self.visit(child, depth + 1)

In [7]:
class TreeSitterJavaChunker:
    def __init__(self, source_code: str):
        self.code = source_code.encode("utf8")
        self.chunks = []

    def extract(self):
        tree = parser.parse(self.code)
        self._visit(tree.root_node, depth=0)
        return self.chunks

    def _visit(self, node, depth):
        if node.type in CONTROL_NODES:
            self._add_control_chunk(node, depth)

        elif node.type in STRAIGHT_NODES:
            self._add_statement_chunk(node, depth)

        for child in node.children:
            if len(child.children) == 0 and child.type in (';', '{', '}', '(', ')', ','):
                continue
            self._visit(child, depth + 1)

    def _add_statement_chunk(self, node, depth):
        text = self.code[node.start_byte:node.end_byte].decode("utf8").strip()
        if not text:
            return

        self.chunks.append(
            Chunk(
                text=text,
                embedding=None,
                ast_depth=depth,
                kind=ChunkKind.STRAIGHT
            )
        )

    def _add_control_chunk(self, node, depth):
        text = extract_control_header(self.code, node)
        if not text:
            return

        self.chunks.append(
            Chunk(
                text=text,
                embedding=None,
                ast_depth=depth,
                kind=ChunkKind.CONTROL
            )
        )


def control_weight(depth, cap=8):
    w = 1.0 + torch.log(torch.tensor(depth + 1.0))

    return w
    #return min(np.log(depth + 1), cap)


In [8]:
def parse_and_chunk_java(code: str):
    chunker = TreeSitterJavaChunker(code)
    chunks = chunker.extract()

    S_chunks = [c for c in chunks if c.kind == ChunkKind.STRAIGHT]
    C_chunks = [c for c in chunks if c.kind == ChunkKind.CONTROL]

    return S_chunks, C_chunks


def embed_chunk_text(text):
    """Embed a single code chunk as a fixed vector"""
    outputs = pipe(text)            # shape: [1, seq_len, hidden_dim]
    arr = np.array(outputs[0])      # shape: [seq_len, hidden_dim]
    return arr.mean(axis=0)         # mean over tokens → shape: (hidden_dim,)


def embed_chunks(S, C):
    for chunk in S + C:
        text = chunk.text.strip()
        if text == "":
            continue
        chunk.embedding = embed_chunk_text(text)


def combine_chunks_torch(chunks, device="cpu"):
    if len(chunks) == 0:
        return torch.zeros(768, device=device)

    vectors, weights = [], []

    for chunk in chunks:
        vec = torch.tensor(chunk.embedding, dtype=torch.float32, device=device)
        vectors.append(vec)


        # BIAS mot kontroll blokker
        if chunk.kind == ChunkKind.CONTROL:
            w = control_weight(chunk.ast_depth)
        else:
            w = 1.0

        weights.append(w)

    V = torch.stack(vectors)
    W = torch.tensor(weights, device=device).unsqueeze(1)

    return (V * W).sum(dim=0) / W.sum()



def cosine_sim(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Differentiable cosine similarity between two vectors.
    Returns a scalar tensor.
    """
    return torch.dot(a, b) / (a.norm() * b.norm() + eps)


def triplet_loss(sim_pos: torch.Tensor,
                 sim_neg: torch.Tensor,
                 margin: float = 0.2) -> torch.Tensor:
    """
    Margin-based triplet loss on cosine similarities.
    Encourages sim_pos >= sim_neg + margin.
    """
    return torch.relu(margin + sim_neg - sim_pos)


In [9]:
def contrastive_loss(sim_pos, sim_neg, margin=0.5):
    """
    Cosine-based contrastive loss for triplets.
    """
    pos_loss = (1.0 - sim_pos).pow(2)
    neg_loss = torch.relu(sim_neg - margin).pow(2)
    return pos_loss + neg_loss

In [10]:
alpha = torch.nn.Parameter(torch.tensor(0.5))  # structure vs content
optimizer = torch.optim.Adam([alpha], lr=0.02)

triplets = list(original_plag_triplet_generator(seed=42))

patience = 5
min_delta = 1e-4
best_loss = float("inf")
epochs_no_improve = 0
best_alpha = None

In [11]:
from tqdm import tqdm

embedding_cache = {}

for triplet in tqdm(triplets):
    for key in ["anchor", "clone", "nonclone"]:
        code = triplet[key]

        if code in embedding_cache:
            continue

        S, C = parse_and_chunk_java(code)
        embed_chunks(S, C)

        embedding_cache[code] = {
            "S": combine_chunks_torch(S).detach(),
            "C": combine_chunks_torch(C).detach(),
        }

100%|██████████| 355/355 [01:00<00:00,  5.86it/s]


In [12]:
for epoch in range(30):
    total_loss = 0.0

    for triplet in triplets:
        A_S = embedding_cache[triplet["anchor"]]["S"]
        A_C = embedding_cache[triplet["anchor"]]["C"]

        P_S = embedding_cache[triplet["clone"]]["S"]
        P_C = embedding_cache[triplet["clone"]]["C"]

        N_S = embedding_cache[triplet["nonclone"]]["S"]
        N_C = embedding_cache[triplet["nonclone"]]["C"]

        g = torch.sigmoid(alpha)  # g ∈ (0,1)

        A = (1 - g) * A_S + g * A_C
        P = (1 - g) * P_S + g * P_C
        N = (1 - g) * N_S + g * N_C

        sim_pos = cosine_sim(A, P)
        sim_neg = cosine_sim(A, N)

        loss = torch.relu(0.2 + sim_neg - sim_pos)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(triplets)

    # ---- Early stopping logic ----
    if avg_loss < best_loss - min_delta:
        best_loss = avg_loss
        best_alpha = alpha.detach().clone()
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    g_val = torch.sigmoid(alpha).item()

    print(
        f"Epoch {epoch}: loss={avg_loss:.4f}, "
        f"wS={1-g_val:.3f}, wC={g_val:.3f}"
    )

    if epochs_no_improve >= patience:
        print(
            f"Early stopping at epoch {epoch}. "
            f"Best loss={best_loss:.4f}"
        )
        break

# Restore best alpha (important!)
if best_alpha is not None:
    alpha.data = best_alpha


Epoch 0: loss=0.0736, wS=0.038, wC=0.962
Epoch 1: loss=0.0659, wS=0.018, wC=0.982
Epoch 2: loss=0.0652, wS=0.011, wC=0.989
Epoch 3: loss=0.0649, wS=0.007, wC=0.993
Epoch 4: loss=0.0648, wS=0.005, wC=0.995
Epoch 5: loss=0.0647, wS=0.004, wC=0.996
Epoch 6: loss=0.0646, wS=0.003, wC=0.997
Epoch 7: loss=0.0646, wS=0.002, wC=0.998
Epoch 8: loss=0.0646, wS=0.002, wC=0.998
Epoch 9: loss=0.0646, wS=0.001, wC=0.999
Epoch 10: loss=0.0646, wS=0.001, wC=0.999
Epoch 11: loss=0.0645, wS=0.001, wC=0.999
Early stopping at epoch 11. Best loss=0.0646
