In [37]:
import sys
import json
import os
import torch
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
import scipy
from scipy.spatial.distance import euclidean
from scipy.signal import savgol_filter
from dataclasses import dataclass
from enum import Enum

from transformers import pipeline
from sklearn.decomposition import PCA
import javalang
import tree_sitter_java
from tree_sitter import Language, Parser

sys.path.append(os.path.abspath(".."))

from data.dataset_factory import get_dataset_generator
from data.data_generators.sourcecodeplag_dataset_gen import original_plag_triplet_generator
from preprocessing.embedding_chunks import get_ready_to_embed_chunks
from preprocessing.context_chunker import safe_get_ready_to_embed_context_chunks
from preprocessing.mean_pool_chunks import mean_pool_chunks
from preprocessing.block_splitter import deverbose_ast
from visualizer.smoothing import smooth_embeddings, smooth_multiple_embeddings

# Tree Sitter for parsing

In [38]:
import tree_sitter_java
from tree_sitter import Language, Parser

JAVA_LANGUAGE = Language(tree_sitter_java.language())
parser = Parser(JAVA_LANGUAGE)

In [39]:
model_name = "microsoft/graphcodebert-base"
pipe = pipeline("feature-extraction", model=model_name)

Loading weights: 100%|██████████| 197/197 [00:00<00:00, 1349.06it/s, Materializing param=encoder.layer.11.output.dense.weight]              
RobertaModel LOAD REPORT from: microsoft/graphcodebert-base
Key                             | Status     | 
--------------------------------+------------+-
lm_head.decoder.weight          | UNEXPECTED | 
lm_head.dense.weight            | UNEXPECTED | 
lm_head.layer_norm.bias         | UNEXPECTED | 
lm_head.decoder.bias            | UNEXPECTED | 
roberta.embeddings.position_ids | UNEXPECTED | 
lm_head.bias                    | UNEXPECTED | 
lm_head.dense.bias              | UNEXPECTED | 
lm_head.layer_norm.weight       | UNEXPECTED | 
pooler.dense.bias               | MISSING    | 
pooler.dense.weight             | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training 

# Data and representation

In [40]:
class ChunkKind(Enum):
    STRAIGHT = "straight"
    CONTROL = "control"

@dataclass
class Chunk:
    text: str
    embedding: None
    ast_depth: int
    kind: ChunkKind


CONTROL_NODES = (
    "if_statement",
    "for_statement",
    "while_statement",
    "do_statement",
    "switch_statement",
    "try_statement",
    "catch_clause",
)

STRAIGHT_NODES = (
    "local_variable_declaration",
    "expression_statement",
    "return_statement",
    "throw_statement",
)

# 

In [41]:
def extract_control_header(code: bytes, node):
    """
    Extracts only the control header (condition / signature),
    excluding the body.
    """
    # Common Tree-sitter pattern: condition is a child
    for child in node.children:
        if child.type in ("condition", "parenthesized_expression"):
            return code[child.start_byte:child.end_byte].decode("utf8")

    # Fallback: first line only
    text = code[node.start_byte:node.end_byte].decode("utf8")
    return text.split("{")[0].strip()


In [42]:
class JavaChunkExtractor:
    def __init__(self, source_code):
        self.source = source_code
        self.lines = source_code.splitlines()
        self.chunks = []

    def visit(self, node, depth=0):
        if isinstance(node, CONTROL_NODES):
            self._add_chunk(node, depth, ChunkKind.CONTROL)

        elif isinstance(node, STRAIGHT_NODES):
            self._add_chunk(node, depth, ChunkKind.STRAIGHT)

        for _, child in node.filter(javalang.tree.Node):
            self.visit(child, depth + 1)

In [43]:
class TreeSitterJavaChunker:
    def __init__(self, source_code: str):
        self.code = source_code.encode("utf8")
        self.chunks = []

    def extract(self):
        tree = parser.parse(self.code)
        self._visit(tree.root_node, depth=0)
        return self.chunks

    def _visit(self, node, depth):
        if node.type in CONTROL_NODES:
            self._add_control_chunk(node, depth)

        elif node.type in STRAIGHT_NODES:
            self._add_statement_chunk(node, depth)

        for child in node.children:
            if len(child.children) == 0 and child.type in (';', '{', '}', '(', ')', ','):
                continue
            self._visit(child, depth + 1)

    def _add_statement_chunk(self, node, depth):
        text = self.code[node.start_byte:node.end_byte].decode("utf8").strip()
        if not text:
            return

        self.chunks.append(
            Chunk(
                text=text,
                embedding=None,
                ast_depth=depth,
                kind=ChunkKind.STRAIGHT
            )
        )

    def _add_control_chunk(self, node, depth):
        text = extract_control_header(self.code, node)
        if not text:
            return

        self.chunks.append(
            Chunk(
                text=text,
                embedding=None,
                ast_depth=depth,
                kind=ChunkKind.CONTROL
            )
        )


def control_weight(depth, cap=8):
    w = 1.0 + torch.log(torch.tensor(depth + 1.0))

    return w
    #return min(np.log(depth + 1), cap)


In [44]:
def parse_and_chunk_java(code: str):
    chunker = TreeSitterJavaChunker(code)
    chunks = chunker.extract()

    S_chunks = [c for c in chunks if c.kind == ChunkKind.STRAIGHT]
    C_chunks = [c for c in chunks if c.kind == ChunkKind.CONTROL]

    return S_chunks, C_chunks


def embed_chunk_text(text):
    """Embed a single code chunk as a fixed vector"""
    outputs = pipe(text)            # shape: [1, seq_len, hidden_dim]
    arr = np.array(outputs[0])      # shape: [seq_len, hidden_dim]
    return arr.mean(axis=0)         # mean over tokens → shape: (hidden_dim,)


def embed_chunks(S, C):
    for chunk in S + C:
        text = chunk.text.strip()
        if text == "":
            continue
        chunk.embedding = embed_chunk_text(text)


def combine_chunks_torch(chunks, device="cpu"):
    if len(chunks) == 0:
        return torch.zeros(768, device=device)

    vectors, weights = [], []

    for chunk in chunks:
        vec = torch.tensor(chunk.embedding, dtype=torch.float32, device=device)
        vectors.append(vec)


        # BIAS mot kontroll blokker
        if chunk.kind == ChunkKind.CONTROL:
            w = control_weight(chunk.ast_depth)
        else:
            w = 1.0

        weights.append(w)

    V = torch.stack(vectors)
    W = torch.tensor(weights, device=device).unsqueeze(1)

    return (V * W).sum(dim=0) / W.sum()



def cosine_sim(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Differentiable cosine similarity between two vectors.
    Returns a scalar tensor.
    """
    return torch.dot(a, b) / (a.norm() * b.norm() + eps)


def triplet_loss(sim_pos: torch.Tensor,
                 sim_neg: torch.Tensor,
                 margin: float = 0.2) -> torch.Tensor:
    """
    Margin-based triplet loss on cosine similarities.
    Encourages sim_pos >= sim_neg + margin.
    """
    return torch.relu(margin + sim_neg - sim_pos)


In [45]:
def contrastive_loss(sim_pos, sim_neg, margin=0.5):
    """
    Cosine-based contrastive loss for triplets.
    """
    pos_loss = (1.0 - sim_pos).pow(2)
    neg_loss = torch.relu(sim_neg - margin).pow(2)
    return pos_loss + neg_loss

In [46]:
alpha = torch.nn.Parameter(torch.tensor(0.5))  # structure vs content
optimizer = torch.optim.Adam([alpha], lr=0.02)

#triplets = list(original_plag_triplet_generator(seed=42))

patience = 5
min_delta = 1e-4
best_loss = float("inf")
epochs_no_improve = 0
best_alpha = None

In [47]:
from tqdm import tqdm

embedding_cache = {}

"""
for triplet in tqdm(triplets):
    for key in ["anchor", "clone", "nonclone"]:
        code = triplet[key]

        if code in embedding_cache:
            continue

        S, C = parse_and_chunk_java(code)
        embed_chunks(S, C)

        embedding_cache[code] = {
            "S": combine_chunks_torch(S).detach(),
            "C": combine_chunks_torch(C).detach(),
        }
"""

'\nfor triplet in tqdm(triplets):\n    for key in ["anchor", "clone", "nonclone"]:\n        code = triplet[key]\n\n        if code in embedding_cache:\n            continue\n\n        S, C = parse_and_chunk_java(code)\n        embed_chunks(S, C)\n\n        embedding_cache[code] = {\n            "S": combine_chunks_torch(S).detach(),\n            "C": combine_chunks_torch(C).detach(),\n        }\n'

In [48]:
import torch
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm


# Cache dictionary: key = chunk text, value = embedding tensor
embedding_cache = {}

def embed_chunk_text_cached(text):
    text = text.strip()
    if text == "":
        return torch.zeros(768, dtype=torch.float32)

    if text in embedding_cache:
        return embedding_cache[text]


    outputs = pipe(text, truncation=True, max_length=512)
    arr = np.array(outputs[0])        # shape: [seq_len, hidden_dim]
    vec = torch.tensor(arr.mean(axis=0), dtype=torch.float32)  # mean pooling
    embedding_cache[text] = vec
    return vec


def embed_chunk_texts_batched(texts, batch_size=16):
    """
    Embed a list of chunk texts, using cache + batching.
    Returns a dict {text: embedding_tensor}
    """
    results = {}

    # Filter out cached and empty
    to_embed = [t for t in texts if t.strip() and t not in embedding_cache]

    for i in range(0, len(to_embed), batch_size):
        batch = to_embed[i:i + batch_size]
        outputs = pipe(batch, truncation=True, max_length=512)

        for text, out in zip(batch, outputs):
            arr = np.array(out)               # [seq_len, hidden_dim]
            vec = torch.tensor(
                arr.mean(axis=0),
                dtype=torch.float32
            )
            embedding_cache[text] = vec
            results[text] = vec

    # Add cached ones
    for t in texts:
        if t.strip():
            results[t] = embedding_cache[t]
        else:
            results[t] = torch.zeros(768, dtype=torch.float32)

    return results


def embed_chunks_cached_batched(S, C, batch_size=16):
    texts = [c.text for c in S + C]
    embeddings = embed_chunk_texts_batched(texts, batch_size=batch_size)

    for chunk in S + C:
        emb = embeddings[chunk.text]

        # SAFETY: force shape [768]
        if emb.ndim == 2:
            emb = emb.mean(dim=0)

        chunk.embedding = emb



def embed_chunks_cached(S, C):
    for chunk in S + C:
        if chunk.text.strip() == "":
            continue
        chunk.embedding = embed_chunk_text_cached(chunk.text)


# --- Helpers ---
def combine_chunks_torch(chunks, device="cpu"):
    if len(chunks) == 0:
        return torch.zeros(768, dtype=torch.float64, device=device)

    vectors, weights = [], []

    for chunk in chunks:
        

        vec = torch.tensor(chunk.embedding, dtype=torch.float64, device=device)  # <- float64

        assert vec.ndim == 1 and vec.shape[0] == 768, f"Bad embedding shape: {vec.shape}"

        vectors.append(vec)

        if chunk.kind == ChunkKind.CONTROL:
            w = np.log(chunk.ast_depth + 2)
        else:
            w = 1.0
        weights.append(w)

    V = torch.stack(vectors)
    W = torch.tensor(weights, dtype=torch.float64, device=device).unsqueeze(1)

    return (V * W).sum(dim=0) / W.sum()


def cosine_sim(a: torch.Tensor, b: torch.Tensor, eps=1e-8):
    # Convert both to float32
    a = a.float()
    b = b.float()
    return torch.dot(a, b) / (a.norm() * b.norm() + eps)


def contrastive_pair_loss(sim, label, margin=0.5):
    """
    Contrastive loss for pairwise data
    sim: cosine similarity ∈ [-1,1]
    label: 1=clone, 0=non-clone
    """
    # want sim high if label=1, low if label=0
    pos_loss = (1 - sim)**2 * label
    neg_loss = F.relu(sim - margin)**2 * (1 - label)
    return pos_loss + neg_loss

# --- Learnable weight ---
alpha = torch.nn.Parameter(torch.tensor(0.5))  # gate between statements vs control
optimizer = torch.optim.Adam([alpha], lr=0.02)

# --- Dataset generator ---
from datasets import load_dataset
from data.data_generators.schema import CodeSample

def bigclonebench_generator(split="train"):
    ds = load_dataset("google/code_x_glue_cc_clone_detection_big_clone_bench", split=split)
    for sample in ds:
        label = 1 if int(sample["label"]) == 0 else 0  # 1=clone, 0=non-clone
        yield CodeSample(code_a=sample["func1"], code_b=sample["func2"], label=label, dataset="bigclonebench")

# --- Training loop ---
num_epochs = 30
best_loss = float("inf")
patience = 5
stop_counter = 0

mini_epoch_size = 2000     # start small (10k–50k)
batch_size = 16
num_epochs = 30


import random
from tqdm import tqdm

# --- Parameters ---
mini_epoch_size = 500  # number of samples per mini-epoch
batch_size = 16          # for embedding batching
num_mini_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

alpha = torch.nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32, device=device))
optimizer = torch.optim.Adam([alpha], lr=0.02)
lambda_reg = 0.1
margin = 0.5

# --- Multi-channel loss function ---
def multi_channel_loss(f_s, f_c, sim_label):
    w_s, w_c = torch.sigmoid(alpha)
    f_combined = w_s * f_s + w_c * f_c
    sim = torch.dot(f_combined, f_combined) / (f_combined.norm()**2 + 1e-8)

    pos_loss = ((1 - sim) ** 2) * sim_label
    neg_loss = F.relu(sim - margin) ** 2 * (1 - sim_label)
    base_loss = pos_loss + neg_loss

    reg_loss = lambda_reg * ((w_s - 0.5)**2 + (w_c - 0.5)**2)
    return base_loss + reg_loss, w_s.item(), w_c.item()


# --- Prepare dataset ---
all_samples = list(bigclonebench_generator(split="train"))

# --- Training loop with mini-epochs ---
for mini_epoch in range(num_mini_epochs):
    mini_samples = random.sample(all_samples, mini_epoch_size)
    total_loss = 0.0

    for sample in tqdm(mini_samples, desc=f"Mini-epoch {mini_epoch+1}"):
        # Parse & chunk
        S_a, C_a = parse_and_chunk_java(sample.code_a)
        S_b, C_b = parse_and_chunk_java(sample.code_b)

        # Embed (use caching and batching for speed)
        embed_chunks_cached(S_a, C_a)
        embed_chunks_cached(S_b, C_b)

        # Aggregate embeddings
        f_s = combine_chunks_torch(S_a, device=device)
        f_c = combine_chunks_torch(C_a, device=device)
        f_s_b = combine_chunks_torch(S_b, device=device)
        f_c_b = combine_chunks_torch(C_b, device=device)

        # Combine statement/control embeddings across the pair
        f_s_pair = f_s + f_s_b
        f_c_pair = f_c + f_c_b

        sim_label = torch.tensor(sample.label, dtype=torch.float32, device=device)

        # Compute loss with multi-channel regularization
        loss, w_s_val, w_c_val = multi_channel_loss(f_s_pair, f_c_pair, sim_label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / mini_epoch_size
    print(f"Mini-epoch {mini_epoch+1} - Avg Loss: {avg_loss:.4f} | wS={w_s_val:.3f} wC={w_c_val:.3f}")


  vec = torch.tensor(chunk.embedding, dtype=torch.float64, device=device)  # <- float64
Mini-epoch 1: 100%|██████████| 500/500 [04:39<00:00,  1.79it/s]


Mini-epoch 1 - Avg Loss: 0.1231 | wS=0.500 wC=0.500


Mini-epoch 2: 100%|██████████| 500/500 [03:15<00:00,  2.56it/s]


Mini-epoch 2 - Avg Loss: 0.1220 | wS=0.500 wC=0.500


Mini-epoch 3: 100%|██████████| 500/500 [03:03<00:00,  2.72it/s]


Mini-epoch 3 - Avg Loss: 0.1240 | wS=0.500 wC=0.500


Mini-epoch 4: 100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


Mini-epoch 4 - Avg Loss: 0.1270 | wS=0.500 wC=0.500


Mini-epoch 5: 100%|██████████| 500/500 [01:44<00:00,  4.79it/s]

Mini-epoch 5 - Avg Loss: 0.1325 | wS=0.500 wC=0.500



