In [1]:
# llm_obfuscator_safe.py
# Defensive research pipeline: fine-tune CodeT5 to generate obfuscated *SAFE* JavaScript.
# NOTE: This script intentionally neutralizes dangerous patterns and blocks risky tokens.

import os
import re
import json
import math
import hashlib
import random
import subprocess
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    set_seed,
)

# -----------------------------
# 1) Safety / Neutralization
# -----------------------------
DANGEROUS_PATTERNS = [
    r"\beval\s*\(",
    r"\bFunction\s*\(",
    r"\bdocument\b",
    r"\bwindow\b",
    r"\bcookie\b",
    r"\blocation\b",
    r"<\s*script\b",
    r"\bon\w+\s*=",
    r"javascript\s*:",
    r"\batob\s*\(",
    r"\bdecodeURIComponent\s*\(",
]

BLOCKED_SUBSTRINGS = [
    "eval", "Function", "document", "cookie", "<", ">", "onerror", "onload", "javascript:"
]

def neutralize_js(js: str) -> str:
    """
    Convert input into SAFE JS-like text.
    - Removes/rewrites known dangerous tokens
    - Ensures we're training an obfuscator, not generating exploit payloads
    """
    s = js.strip()
    s = s.replace("\r", " ").replace("\n", " ")
    s = re.sub(r"\s+", " ", s)

    # Rewrite common sink-like identifiers into safe placeholders
    s = re.sub(r"\balert\s*\(", "SAFE_CALL(", s, flags=re.IGNORECASE)
    s = re.sub(r"\bprompt\s*\(", "SAFE_CALL(", s, flags=re.IGNORECASE)
    s = re.sub(r"\bconfirm\s*\(", "SAFE_CALL(", s, flags=re.IGNORECASE)

    # Strip angle brackets to avoid HTML/JS injection forms
    s = s.replace("<", " ").replace(">", " ")

    # Remove dangerous patterns
    for pat in DANGEROUS_PATTERNS:
        s = re.sub(pat, "SAFE_BLOCKED(", s, flags=re.IGNORECASE)

    # Final cleanup
    s = re.sub(r"\s+", " ", s).strip()
    return s

# -----------------------------
# 2) Obfuscation transforms (safe)
#    These mirror the paper's idea but remain non-executable / neutralized.
# -----------------------------
def obf_string_split(s: str) -> str:
    # Split into chunks and concatenate (classic string-splitting idea)
    if len(s) < 8:
        return f'"{s}"'
    k = random.randint(3, 6)
    parts = []
    step = max(1, len(s) // k)
    for i in range(0, len(s), step):
        parts.append(s[i:i+step])
    parts = [p.replace('"', '\\"') for p in parts]
    return " + ".join([f'"{p}"' for p in parts])

def obf_whitespace_noise(js: str) -> str:
    # Insert benign whitespace/no-op comments
    tokens = js.split(" ")
    out = []
    for t in tokens:
        out.append(t)
        if random.random() < 0.25:
            out.append("/*noop*/")
        if random.random() < 0.15:
            out.append(" ")
    return " ".join(out)

def obf_keyword_rename_like(js: str) -> str:
    # Very light, safe "rename-like" obfuscation (doesn't change semantics, just adds aliases)
    # We avoid true renaming to keep it simple & safe.
    alias = "v" + str(random.randint(1000, 9999))
    return f"(function({alias}){{ return {alias}; }})({json.dumps(js)})"

def make_obfuscated_variant(js_safe: str) -> str:
    # Randomly apply 1-3 transforms
    transforms = [obf_whitespace_noise, obf_keyword_rename_like]
    out = js_safe
    random.shuffle(transforms)
    for fn in transforms[: random.randint(1, 2)]:
        out = fn(out)

    # Optional string-splitting wrapper to obscure plain text
    if random.random() < 0.6:
        out = f"const PAYLOAD = {obf_string_split(out)}; /*SAFE*/"
    return out

# -----------------------------
# 3) Optional: JS parse validation (AST)
# -----------------------------
def js_parses_ok(js: str) -> bool:
    """
    Optional validation using node + esprima.
    Install:
      npm i esprima
    Then ensure `node` is available.
    """
    try:
        # Parse as a script; we wrap in a harmless context.
        payload = f"const x = 1; {js}"
        cmd = ["node", "-e", "const esprima=require('esprima'); esprima.parseScript(process.argv[1]);", payload]
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return True
    except Exception:
        return False

# -----------------------------
# 4) Dataset builder
# -----------------------------
def build_pairs(benign_js_list: List[str], n_per_sample: int = 2, require_parse: bool = False) -> Dataset:
    rows = []
    for js in benign_js_list:
        src = neutralize_js(js)
        if not src:
            continue
        for _ in range(n_per_sample):
            tgt = make_obfuscated_variant(src)
            if require_parse and (not js_parses_ok(tgt)):
                continue
            rows.append({"src": src, "tgt": tgt})

    # Dedupe
    seen = set()
    deduped = []
    for r in rows:
        h = hashlib.sha256((r["src"] + "||" + r["tgt"]).encode("utf-8")).hexdigest()
        if h in seen:
            continue
        seen.add(h)
        deduped.append(r)

    return Dataset.from_list(deduped)

# -----------------------------
# 5) Tokenization
# -----------------------------
def tokenize_function(examples, tokenizer, max_src=256, max_tgt=256):
    model_inputs = tokenizer(
        examples["src"],
        max_length=max_src,
        truncation=True,
    )
    labels = tokenizer(
        text_target=examples["tgt"],
        max_length=max_tgt,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# -----------------------------
# 6) Training
# -----------------------------
def train_obfuscator(
    benign_js_list: List[str],
    model_name: str = "Salesforce/codet5-small",
    out_dir: str = "./codet5_safe_obfuscator",
    seed: int = 42,
    require_parse: bool = False,
):
    set_seed(seed)

    ds = build_pairs(benign_js_list, n_per_sample=3, require_parse=require_parse)
    ds = ds.train_test_split(test_size=0.1, seed=seed)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    tokenized_train = ds["train"].map(lambda x: tokenize_function(x, tokenizer), batched=True, remove_columns=ds["train"].column_names)
    tokenized_eval  = ds["test"].map(lambda x: tokenize_function(x, tokenizer), batched=True, remove_columns=ds["test"].column_names)

    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),
    save_total_limit=2,
    report_to="none",
    )
    
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        data_collator=collator,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"Saved to: {out_dir}")

# -----------------------------
# 7) Safe generation with filtering (temperature/top_p)
# -----------------------------
def safe_generate(
    text: str,
    ckpt_dir: str = "./codet5_safe_obfuscator",
    num_samples: int = 5,
    temperature: float = 1.2,
    top_p: float = 0.95,
    max_new_tokens: int = 180,
) -> List[str]:
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir)
    model.eval()

    src = neutralize_js(text)
    inputs = tokenizer(src, return_tensors="pt", truncation=True, max_length=256)

    outputs = []
    with torch.no_grad():
        for _ in range(num_samples):
            gen_ids = model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                num_beams=1,
            )
            s = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

            # Block risky substrings defensively
            low = s.lower()
            if any(b.lower() in low for b in BLOCKED_SUBSTRINGS):
                continue

            # Remove any accidental brackets again
            s = s.replace("<", " ").replace(">", " ")
            s = re.sub(r"\s+", " ", s).strip()

            outputs.append(s)

    # Dedupe
    uniq = []
    seen = set()
    for o in outputs:
        h = hashlib.md5(o.encode("utf-8")).hexdigest()
        if h not in seen:
            seen.add(h)
            uniq.append(o)
    return uniq

# -----------------------------
# 8) Example usage
# -----------------------------
if __name__ == "__main__":
    # Provide benign JS snippets only (or your "neutralized" dataset)
    benign_examples = [
        "function add(a,b){ return a+b; } console.log(add(2,3));",
        "const msg = 'hello world'; console.log(msg.toUpperCase());",
        "if (x > 10) { console.log('big'); } else { console.log('small'); }",
    ]

    # 1) Train
    train_obfuscator(benign_examples, require_parse=False)

    # 2) Generate safe obfuscations
    outs = safe_generate("console.log('demo');", num_samples=6, temperature=1.3, top_p=0.9)
    print("\nGenerated SAFE obfuscations:")
    for i, o in enumerate(outs, 1):
        print(f"{i}. {o}")







Map:   0%|          | 0/8 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss
1,No log,5.377659
2,No log,4.781147
3,No log,4.781147


Saved to: ./codet5_safe_obfuscator

Generated SAFE obfuscations:
1. console.log('demo',console.colors.green(demo.toString().join('|'))));console.log(demo.name).slice(0, 5).join('')
2. // console.log('Hello world...'); // console.log('Launching...');//console.log('Hello World World?');
3. { t.main.test} ( 'demo:demo'): console.log(template); });console.log('demodemo');
4. console.log('demo'); console.log('demo'); console.log('demo'); console.log('demo'); console.log('demo'); console.log('demo'); console.log('demo');
