### STEP 3: AUGMENTING .JSON FILES


Safety enhancements:
- Core-function protection: never rename core cryptographic functions.
- Dependency-aware reordering: reorder only operations that genuinely commute.

Sampling:
- Weighted strategy selection: prioritize high-impact augmentations.
- Balanced strategy distribution: avoid overusing any single strategy.


In [4]:
import os
import json
import random
from typing import Dict, Any, List, Tuple


# ----------------------------
# Unified PDV Processor
# ----------------------------
class UnifiedPDVProcessor:
    """Convert family-specific PDV dictionaries into one flat, unified schema.

    ``create_unified_pdv`` emits a dict whose keys are exactly the names listed
    in ``self.feature_names``. The other methods are helpers that read the
    family-specific sub-structures (``feistel_structure``, ``arx_structure``,
    ``spn_structure``) or walk an AST dict made of ``nodes`` / ``edges`` lists.
    """

    def __init__(self):
        # Every key produced by ``create_unified_pdv``, grouped by theme.
        self.feature_names = [
            # Core parameters
            'block_size', 'key_size', 'rounds',
            # Cipher family indicators
            'is_feistel', 'is_arx', 'is_spn',
            # Operation counts (static - per round definition)
            'xor_count', 'rotl_count', 'rotr_count', 'add_count', 'sub_count',
            'and_count', 'sbox_count', 'perm_count',
            # Structural complexity
            'round_complexity', 'rotation_diversity', 'max_rotation_amount',
            # Core function presence
            'has_round_function', 'has_key_schedule', 'has_f_function',
            'has_enc_round', 'has_dec_round',
            # Graph statistics
            'ast_node_count', 'ast_edge_count', 'function_count',
            # Cryptographic properties
            'uses_z_sequence', 'uses_shift_params', 'uses_sbox', 'uses_permutation',
            # === ENRICHED FEATURES ===
            'round_function_size', 'operations_per_round',
            'complexity_ratio', 'estimated_total_operations',
            'key_schedule_operations', 'encryption_operations', 'decryption_operations'
        ]
        # Core round functions for each family.
        # NOTE(review): not read by any method of this class; presumably used
        # by external callers -- confirm before removing.
        self.core_round_functions = {
            'Feistel': ['F_function', 'simon_round'],
            'ARX': ['speck_enc_round', 'speck_dec_round'],
            'SPN': ['present_round', 'sbox_layer', 'p_layer']
        }

    def create_unified_pdv(self, extracted_pdv: Dict, ast_data: Dict) -> Dict[str, Any]:
        """Convert a family-specific PDV to the unified schema.

        Args:
            extracted_pdv: Family-specific PDV. ``block_size`` and ``key_size``
                are required (``KeyError`` otherwise). ``cipher_family``,
                ``rounds``, ``ops_summary`` and the ``*_structure`` sub-dicts are
                optional; a missing or ``None`` value falls back to an
                empty/zero default.
            ast_data: AST dict. Only the lengths of its ``nodes``, ``edges`` and
                ``functions`` lists are used.

        Returns:
            A dict with one entry per name in ``self.feature_names``.
        """
        # ``or ""`` also covers an explicit ``None``, which previously crashed
        # on ``.lower()``.
        cipher_family = str(extracted_pdv.get("cipher_family") or "").lower()
        is_feistel = 1 if "feistel" in cipher_family else 0
        is_arx = 1 if "arx" in cipher_family else 0
        is_spn = 1 if "spn" in cipher_family else 0

        ops_summary = extracted_pdv.get("ops_summary") or {}

        # Pick the structure sub-dict that matches the detected family.
        if is_feistel:
            structure = extracted_pdv.get("feistel_structure", {})
        elif is_arx:
            structure = extracted_pdv.get("arx_structure", {})
        elif is_spn:
            structure = extracted_pdv.get("spn_structure", {})
        else:
            structure = {}

        # ``or 0`` guards against an explicit ``rounds: None``, which would
        # break ``max(rounds, 1)`` in the enriched-feature computation.
        rounds = extracted_pdv.get("rounds") or 0
        enriched_features = self._compute_enriched_features(ops_summary, structure, rounds, cipher_family)

        unified = {
            # Core parameters
            "block_size": extracted_pdv["block_size"],
            "key_size": extracted_pdv["key_size"],
            "rounds": rounds,

            # Family detection
            "is_feistel": is_feistel,
            "is_arx": is_arx,
            "is_spn": is_spn,

            # Operation counts (static - per round definition)
            "xor_count": ops_summary.get("xor_count", 0),
            "rotl_count": ops_summary.get("rotl_count", 0),
            "rotr_count": ops_summary.get("rotr_count", 0),
            "add_count": ops_summary.get("add_count", 0),
            "sub_count": ops_summary.get("sub_count", 0),
            "and_count": ops_summary.get("and_count", 0),
            "sbox_count": ops_summary.get("sbox_count", 0),
            "perm_count": ops_summary.get("perm_count", 0),

            # Structural complexity
            "round_complexity": self._extract_round_complexity(extracted_pdv, cipher_family),
            "rotation_diversity": self._extract_rotation_diversity(extracted_pdv, cipher_family),
            "max_rotation_amount": self._extract_max_rotation(extracted_pdv, cipher_family),

            # Core function presence
            "has_round_function": self._extract_has_round_function(extracted_pdv, cipher_family),
            "has_f_function": self._extract_has_f_function(extracted_pdv, cipher_family),
            "has_enc_round": self._extract_has_enc_round(extracted_pdv, cipher_family),
            "has_dec_round": self._extract_has_dec_round(extracted_pdv, cipher_family),
            "has_key_schedule": self._extract_has_key_schedule(extracted_pdv, cipher_family),

            # Graph statistics
            "ast_node_count": len(ast_data.get("nodes", [])),
            "ast_edge_count": len(ast_data.get("edges", [])),
            "function_count": len(ast_data.get("functions", [])),

            # Cryptographic properties
            "uses_z_sequence": ops_summary.get("z_seq_usage", 0),
            "uses_shift_params": self._extract_uses_shift_params(extracted_pdv, cipher_family),
            "uses_sbox": 1 if ops_summary.get("sbox_count", 0) > 0 else 0,
            "uses_permutation": 1 if ops_summary.get("perm_count", 0) > 0 else 0,

            # === ENRICHED FEATURES ===
            **enriched_features
        }

        return unified

    def _compute_enriched_features(self, ops_summary: Dict, structure: Dict, rounds: int, cipher_family: str) -> Dict[str, Any]:
        """Compute the derived ("enriched") features from family heuristics.

        ``structure`` is accepted for interface stability but is not read.
        """
        round_function_size = self._get_round_function_size(ops_summary, cipher_family)

        # Operations that would execute per round.
        operations_per_round = round_function_size

        # Operations per round relative to the number of rounds (rounds < 1 treated as 1).
        complexity_ratio = operations_per_round / max(rounds, 1)

        # Round function size x rounds; with no round count, a single round.
        estimated_total_operations = operations_per_round * rounds if rounds > 0 else operations_per_round

        key_schedule_ops = self._get_key_schedule_operations(cipher_family, round_function_size)
        enc_ops, dec_ops = self._get_enc_dec_operations(cipher_family, round_function_size)

        return {
            "round_function_size": round_function_size,
            "operations_per_round": operations_per_round,
            "complexity_ratio": round(complexity_ratio, 4),
            "estimated_total_operations": estimated_total_operations,
            "key_schedule_operations": key_schedule_ops,
            "encryption_operations": enc_ops,
            "decryption_operations": dec_ops
        }

    def _get_round_function_size(self, ops_summary: Dict, cipher_family: str) -> int:
        """Estimate the round-function size, clamped to a family-specific range."""
        if "feistel" in cipher_family:
            # Simon: F_function (3 ROTL + 1 AND + 1 XOR) + round logic ~ 6-8
            core_ops = ['xor_count', 'rotl_count', 'and_count']
            return min(8, max(6, sum(ops_summary.get(op, 0) for op in core_ops)))

        elif "arx" in cipher_family:
            # Speck: 1 ROTR + 1 ADD + 2 XOR + 1 ROTL ~ 5-6
            core_ops = ['xor_count', 'rotl_count', 'rotr_count', 'add_count']
            return min(6, max(5, sum(ops_summary.get(op, 0) for op in core_ops)))

        elif "spn" in cipher_family:
            # PRESENT: S-box + permutation + key mixing ~ 2-3
            core_ops = ['sbox_count', 'perm_count', 'xor_count']
            return min(3, max(2, sum(ops_summary.get(op, 0) for op in core_ops)))

        else:
            # Unknown family: half of all summary counts (assumes numeric values).
            return sum(ops_summary.values()) // 2

    def _get_key_schedule_operations(self, cipher_family: str, round_function_size: int) -> int:
        """Heuristic key-schedule operation count for the given family."""
        if "feistel" in cipher_family or "arx" in cipher_family:
            # Simon / Speck key schedules: roughly half a round function.
            return max(1, round_function_size // 2)
        elif "spn" in cipher_family:
            return 1  # PRESENT key schedule is simpler
        else:
            return 0

    def _get_enc_dec_operations(self, cipher_family: str, round_function_size: int) -> Tuple[int, int]:
        """Return ``(encryption_ops, decryption_ops)``.

        Every family (Simon, Speck, PRESENT, unknown) is currently modelled as
        symmetric, so ``cipher_family`` does not affect the result.
        """
        return round_function_size, round_function_size

    def _extract_round_function_operations(self, ast_data: Dict) -> Dict[str, int]:
        """Count operation nodes that belong to (or are reachable from) round functions."""
        nodes = ast_data.get("nodes", [])
        edges = ast_data.get("edges", [])

        round_function_nodes = self._find_round_functions(nodes)
        # (A leftover debug ``print`` of the round-function ids was removed here.)
        return self._count_operations_in_round_functions(round_function_nodes, nodes, edges)

    def _find_round_functions(self, nodes: List[Dict]) -> List[int]:
        """Return ids of nodes whose label contains a known round-function name.

        Matching is a case-insensitive substring test on ``node['label']``.
        """
        # All indicators are lowercase because they are compared against a
        # lowercased label; the former mixed-case 'F_function' could never match.
        round_function_indicators = (
            'simon_round', 'speck_enc_round', 'speck_dec_round',
            'present_round', 'encrypt_iterate', 'f_function',
        )
        return [
            node['id']
            for node in nodes
            if any(indicator in str(node.get('label', '')).lower()
                   for indicator in round_function_indicators)
        ]

    def _count_operations_in_round_functions(self, round_function_ids: List[int], all_nodes: List[Dict], all_edges: List[Dict]) -> Dict[str, int]:
        """Count ``op`` nodes contained in, or reachable from, the given round functions.

        Seeds are the targets of ``'contains'`` edges whose source is a round
        function; from there every outgoing edge (of any type) is followed.
        """
        roots = set(round_function_ids)

        # Build the adjacency list and the id -> node index once, instead of
        # rescanning every edge per visited node and every node per lookup.
        successors: Dict[Any, List[Any]] = {}
        for edge in all_edges:
            successors.setdefault(edge.get('source'), []).append(edge.get('target'))

        nodes_by_id: Dict[Any, Dict] = {}
        for node in all_nodes:
            if 'id' in node:
                # setdefault keeps the FIRST node with a given id, as the old linear search did.
                nodes_by_id.setdefault(node['id'], node)

        seeds = {
            edge.get('target')
            for edge in all_edges
            if edge.get('type') == 'contains' and edge.get('source') in roots
        }

        # Reachability closure (traversal order is irrelevant for counting).
        visited = set(seeds)
        frontier = list(seeds)
        while frontier:
            current_id = frontier.pop()
            for target in successors.get(current_id, ()):
                if target not in visited:
                    visited.add(target)
                    frontier.append(target)

        op_counts = {
            'xor_count': 0, 'rotl_count': 0, 'rotr_count': 0,
            'add_count': 0, 'sub_count': 0, 'and_count': 0,
            'sbox_count': 0, 'perm_count': 0
        }

        op_mapping = {
            'XOR': 'xor_count',
            'ROTL': 'rotl_count',
            'ROTR': 'rotr_count',
            'ADD': 'add_count',
            'SUB': 'sub_count',
            'AND': 'and_count',
            # NOTE(review): Simon has no S-box; counting F_FUNCTION as sbox looks
            # deliberate here but should be confirmed.
            'F_FUNCTION': 'sbox_count',
            'LIST_REV': 'perm_count'     # For permutations
        }

        for node_id in visited:
            node = nodes_by_id.get(node_id)
            if node and node.get('type') == 'op':
                op_label = str(node.get('label', ''))
                if op_label in op_mapping:
                    op_counts[op_mapping[op_label]] += 1
                elif 'sbox' in op_label.lower():
                    op_counts['sbox_count'] += 1
                elif 'perm' in op_label.lower() or 'player' in op_label.lower():
                    op_counts['perm_count'] += 1

        return op_counts

    def _estimate_key_schedule_operations(self, ast_data: Dict) -> int:
        """Estimate key-schedule operations: a third of the round-function op count.

        Returns 0 when no node label contains a key-schedule indicator.
        """
        nodes = ast_data.get("nodes", [])

        key_schedule_indicators = ['key_schedule', 'key_expansion', 'key_gen', 'gen_key_schedule']
        has_key_schedule = any(
            any(indicator in str(node.get('label', '')).lower() for indicator in key_schedule_indicators)
            for node in nodes
        )

        if has_key_schedule:
            # Conservative estimate: key schedule ~ 1/3 of round-function complexity.
            round_ops = self._extract_round_function_operations(ast_data)
            round_complexity = sum(round_ops.values())
            return max(1, round_complexity // 3)
        return 0

    def _estimate_enc_dec_operations(self, round_function_ops: Dict, structure: Dict) -> Tuple[int, int]:
        """Return ``(encryption_ops, decryption_ops)``, both the summed round-function ops.

        Decryption is assumed to cost the same as encryption whether or not a
        separate decryption round exists, so ``structure`` does not change the
        result.
        """
        round_complexity = sum(round_function_ops.values())
        return round_complexity, round_complexity

    def _extract_round_complexity(self, pdv: Dict, cipher_family: str) -> int:
        """Read the family-specific round complexity (0 if unknown)."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("f_function_complexity", 0)
        elif "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("enc_round_complexity", 0)
        elif "spn" in cipher_family:
            return pdv.get("spn_structure", {}).get("round_complexity", 0)
        return 0

    def _extract_rotation_diversity(self, pdv: Dict, cipher_family: str) -> int:
        """Read rotation diversity for Feistel/ARX families (0 otherwise)."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("rotation_diversity", 0)
        elif "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("rotation_diversity", 0)
        return 0

    def _extract_max_rotation(self, pdv: Dict, cipher_family: str) -> int:
        """Read the maximum rotation amount for Feistel/ARX families (0 otherwise)."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("max_rotation_amount", 0)
        elif "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("max_rotation_amount", 0)
        return 0

    def _extract_has_round_function(self, pdv: Dict, cipher_family: str) -> int:
        """Read has_round_function; ARX uses ``has_enc_round`` as its round function."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("has_round_function", 0)
        elif "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("has_enc_round", 0)
        elif "spn" in cipher_family:
            return pdv.get("spn_structure", {}).get("has_round_function", 0)
        return 0

    def _extract_has_f_function(self, pdv: Dict, cipher_family: str) -> int:
        """Read has_f_function (Feistel only; 0 otherwise)."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("has_f_function", 0)
        return 0

    def _extract_has_enc_round(self, pdv: Dict, cipher_family: str) -> int:
        """Read has_enc_round for ARX; Feistel/SPN are always reported as 1."""
        if "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("has_enc_round", 0)
        if "feistel" in cipher_family or "spn" in cipher_family:
            return 1
        return 0

    def _extract_has_dec_round(self, pdv: Dict, cipher_family: str) -> int:
        """Read has_dec_round for ARX; Feistel/SPN are always reported as 1."""
        if "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("has_dec_round", 0)
        if "feistel" in cipher_family or "spn" in cipher_family:
            return 1
        return 0

    def _extract_has_key_schedule(self, pdv: Dict, cipher_family: str) -> int:
        """Read has_key_schedule from the family-specific structure (0 if unknown)."""
        if "feistel" in cipher_family:
            return pdv.get("feistel_structure", {}).get("has_key_schedule", 0)
        elif "arx" in cipher_family:
            return pdv.get("arx_structure", {}).get("has_key_schedule", 0)
        elif "spn" in cipher_family:
            return pdv.get("spn_structure", {}).get("has_key_schedule", 0)
        return 0

    def _extract_uses_shift_params(self, pdv: Dict, cipher_family: str) -> int:
        """1 if an ARX PDV declares ``shift_parameters.shift_params_defined``, else 0."""
        if "arx" in cipher_family:
            return 1 if pdv.get("shift_parameters", {}).get("shift_params_defined", False) else 0
        return 0
        

In [5]:
import os
import json
import random
import copy
from typing import Dict, Any, List

# General-purpose utilities: JSON I/O, text truncation, and recomputing AST statistics.
def load_json(file_path: str) -> Dict[str, Any]:
    """Read *file_path* as UTF-8 text and return the decoded JSON document.

    No errors are caught here: ``OSError`` and ``json.JSONDecodeError``
    propagate to the caller.
    """
    with open(file_path, encoding='utf-8') as handle:
        document = json.load(handle)
    return document

def save_json(data: Dict[str, Any], file_path: str):
    """Write *data* to *file_path* as 2-space-indented JSON (UTF-8).

    Missing parent directories are created. A bare filename with no directory
    component is written relative to the current working directory.
    """
    parent_dir = os.path.dirname(file_path)
    # os.path.dirname("out.json") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

def safe_truncate(text: str, max_length: int) -> str:
    """Return *text* shortened to at most *max_length* characters.

    When truncation is needed and there is room, the result ends with "..."
    (counted toward *max_length*). If *max_length* is below 3 there is no room
    for the ellipsis, so the text is simply cut. A negative *max_length* yields
    an empty string.
    """
    if len(text) <= max_length:
        return text
    if max_length < 3:
        # text[:max_length - 3] would use a negative index here and return a
        # string far longer than max_length.
        return text[:max(max_length, 0)]
    return text[:max_length - 3] + "..."

def recompute_op_counts_from_nodes_and_edges(ast: Dict[str, Any]):
    """Refresh ``ast["pdv"]["ops_summary"]`` from the AST's ``op`` nodes.

    Labels are compared upper-cased. XOR/ROTL/ROTR/ADD/SUB/AND must match
    exactly; any other label containing "SBOX" or "PERM" counts as an S-box or
    permutation. The eight counters are merged into ``ops_summary`` only when
    both ``pdv`` and ``pdv.ops_summary`` exist. Works in place; returns None.
    """
    exact_label_keys = {
        "XOR": "xor_count",
        "ROTL": "rotl_count",
        "ROTR": "rotr_count",
        "ADD": "add_count",
        "SUB": "sub_count",
        "AND": "and_count",
    }
    tallies = dict.fromkeys(
        ["xor_count", "rotl_count", "rotr_count", "add_count",
         "sub_count", "and_count", "sbox_count", "perm_count"],
        0,
    )

    op_nodes = (n for n in ast.get("nodes", []) if n.get("type") == "op")
    for op_node in op_nodes:
        upper_label = op_node.get("label", "").upper()
        counter = exact_label_keys.get(upper_label)
        if counter is None:
            if "SBOX" in upper_label:
                counter = "sbox_count"
            elif "PERM" in upper_label:
                counter = "perm_count"
        if counter is not None:
            tallies[counter] += 1

    if "pdv" in ast and "ops_summary" in ast["pdv"]:
        ast["pdv"]["ops_summary"].update(tallies)

def recompute_graph_stats(ast: Dict[str, Any]):
    """Refresh the graph-size fields of ``ast["unified_pdv"]`` in place.

    Sets ``ast_node_count``, ``ast_edge_count`` and ``function_count`` to the
    lengths of the AST's ``nodes``, ``edges`` and ``functions`` lists (missing
    lists count as empty). Does nothing when there is no ``unified_pdv``.
    """
    if "unified_pdv" not in ast:
        return
    graph_counts = {
        "ast_node_count": len(ast.get("nodes", [])),
        "ast_edge_count": len(ast.get("edges", [])),
        "function_count": len(ast.get("functions", [])),
    }
    ast["unified_pdv"].update(graph_counts)

# Metadata-only augmentation: annotates nodes without altering their operations.
def safe_representation_variation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Tag every ``op`` and ``function`` node with a random representation variant.

    Semantics are untouched: only ``features["representation_variant"]`` is
    written (creating ``features`` if absent), on a deep copy of *ast*.
    """
    varied = copy.deepcopy(ast)
    variant_pool = ["direct", "decomposed", "optimized", "canonical"]
    tagged_types = ("op", "function")

    for node in varied.get("nodes", []):
        if node.get("type") not in tagged_types:
            continue
        node.setdefault("features", {})
        node["features"]["representation_variant"] = random.choice(variant_pool)

    return varied

# Cipher-specific augmentations (metadata only).
def safe_present_sbox_representation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Vary S-box representation in PRESENT - METADATA ONLY.

    For a PRESENT AST (``pdv.cipher_family == "PRESENT"``), every ``function``
    node whose lowercased label contains "sbox" gets a random
    ``features["sbox_representation"]``. Other ASTs are returned unchanged.

    Always returns a deep copy, including for non-PRESENT input. Previously
    the input object itself was returned in that case, so in-place edits to
    the "augmented" result silently modified the source AST, unlike every
    other augmentation here.
    """
    if ast.get("pdv", {}).get("cipher_family") != "PRESENT":
        return copy.deepcopy(ast)

    new_ast = copy.deepcopy(ast)

    sbox_representations = ["lookup_table", "computed_form", "bit_sliced", "algebraic_form"]

    for node in new_ast.get("nodes", []):
        if (node.get("type") == "function" and
            any(sbox_keyword in node.get("label", "").lower()
                for sbox_keyword in ["sbox", "present_sbox"])):

            if "features" not in node:
                node["features"] = {}
            node["features"]["sbox_representation"] = random.choice(sbox_representations)

    return new_ast

def safe_present_permutation_variation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Vary permutation layer implementation - METADATA ONLY.

    For a PRESENT AST (``pdv.cipher_family == "PRESENT"``), every ``function``
    node whose lowercased label contains "p_layer" or "permutation" gets a
    random ``features["permutation_implementation"]``. Other ASTs are returned
    unchanged.

    Always returns a deep copy, including for non-PRESENT input, so the result
    never aliases the source AST. Previously the input object itself was
    returned for non-PRESENT input.
    """
    if ast.get("pdv", {}).get("cipher_family") != "PRESENT":
        return copy.deepcopy(ast)

    new_ast = copy.deepcopy(ast)

    permutation_variants = ["bitwise_mapping", "matrix_rotation", "word_operations", "parallel_blocks"]

    for node in new_ast.get("nodes", []):
        if (node.get("type") == "function" and
            any(perm_keyword in node.get("label", "").lower()
                for perm_keyword in ["p_layer", "permutation"])):

            if "features" not in node:
                node["features"] = {}
            node["features"]["permutation_implementation"] = random.choice(permutation_variants)

    return new_ast

def safe_present_round_structure_variation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Vary PRESENT round function implementation - METADATA ONLY.

    For a PRESENT AST (``pdv.cipher_family == "PRESENT"``), every ``function``
    node whose lowercased label contains "present_round" gets a random
    ``features["round_implementation"]``. Other ASTs are returned unchanged.

    Always returns a deep copy, including for non-PRESENT input, so the result
    never aliases the source AST. Previously the input object itself was
    returned for non-PRESENT input.
    """
    if ast.get("pdv", {}).get("cipher_family") != "PRESENT":
        return copy.deepcopy(ast)

    new_ast = copy.deepcopy(ast)

    round_variants = ["sequential", "integrated", "parallel", "pipelined"]

    for node in new_ast.get("nodes", []):
        if (node.get("type") == "function" and
            "present_round" in node.get("label", "").lower()):

            if "features" not in node:
                node["features"] = {}
            node["features"]["round_implementation"] = random.choice(round_variants)

    return new_ast

def safe_key_schedule_representation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Tag key-schedule functions with a random implementation style - METADATA ONLY.

    A ``function`` node qualifies when its lowercased label contains both
    "key" and "schedule". It receives ``features["implementation_variant"]``
    (``features`` is created if missing). Works on and returns a deep copy.
    """
    tagged = copy.deepcopy(ast)
    implementation_styles = ["recursive", "iterative", "unrolled", "table_based"]

    for node in tagged.get("nodes", []):
        if node.get("type") != "function":
            continue
        lowered = node.get("label", "").lower()
        if "key" not in lowered or "schedule" not in lowered:
            continue
        node.setdefault("features", {})
        node["features"]["implementation_variant"] = random.choice(implementation_styles)

    return tagged

In [6]:
"""
Enhanced safe augmentations with improved safety and new strategies
"""

import random
import copy
import re
from typing import Dict, Any, List, Set

# -----------------------
# IMPROVED CORE AUGMENTATIONS
# -----------------------

def safe_commutative_operation_reordering(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
     Reorder commutative operations with dependency checking
    """
    new_ast = copy.deepcopy(ast)
    
    for func_name, op_sequence in new_ast.get("op_sequences", {}).items():
        if len(op_sequence) < 2:
            continue
            
        new_sequence = []
        i = 0
        
        while i < len(op_sequence):
            if op_sequence[i] in {"XOR", "ADD"}:
                # Find consecutive commutative operations
                j = i
                while j < len(op_sequence) and op_sequence[j] in {"XOR", "ADD"}:
                    j += 1
                
                # Only shuffle if no obvious dependencies in sequence
                block = op_sequence[i:j]
                if len(block) > 1 and _is_independent_block(block, func_name, new_ast):
                    random.shuffle(block)
                new_sequence.extend(block)
                i = j
            else:
                new_sequence.append(op_sequence[i])
                i += 1
        
        new_ast["op_sequences"][func_name] = new_sequence
    
    return new_ast

def _is_independent_block(block: List[str], func_name: str, ast: Dict[str, Any]) -> bool:
    """
    Check if operations in block are likely independent
    Conservative approach: assume dependent if complex structure
    """
    # For now, use simple heuristic - allow shuffling for short blocks
    return len(block) <= 4

def safe_function_renaming(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rename helper functions while protecting core cryptographic functions.

    A ``function`` node is renamed to ``<label>_v<1-3>`` when its lowercased
    label contains a helper marker and no core-crypto marker. Each distinct
    label gets ONE new name. That name is applied to every node with that
    label, to ``function_call`` nodes, to ``functions[*]["name"]`` and to the
    ``op_sequences`` keys. Returns a deep copy; the input is not modified.
    """
    new_ast = copy.deepcopy(ast)

    # Core cryptographic function markers: never rename. All lowercase because
    # they are matched against the lowercased label (the former "F_function"
    # entry could never match, leaving F-function helpers unprotected).
    CRYPTO_CORE_FUNCTIONS = {
        "encrypt", "decrypt", "round", "key", "schedule", "sbox", "p_layer",
        "f_function", "simon_round", "speck_enc_round", "speck_dec_round",
        "present_round", "present_sbox", "generate_key_schedule"
    }

    # Safe helper prefixes/suffixes
    SAFE_HELPERS = {"helper", "aux", "temp", "compute", "calculate", "process"}

    rename_map: Dict[str, str] = {}
    nodes = new_ast.get("nodes", [])

    for node in nodes:
        if node.get("type") != "function":
            continue
        old_name = node.get("label", "")
        label = old_name.lower()

        is_core_function = any(core in label for core in CRYPTO_CORE_FUNCTIONS)
        is_safe_helper = any(helper in label for helper in SAFE_HELPERS)

        if is_safe_helper and not is_core_function:
            # Draw a suffix only once per label, so duplicate nodes, call sites,
            # the functions list and op_sequences all agree on one new name.
            if old_name not in rename_map:
                rename_map[old_name] = f"{old_name}_v{random.randint(1, 3)}"
            node["label"] = rename_map[old_name]

    if rename_map:
        # Update call sites.
        for node in nodes:
            if node.get("type") == "function_call" and node.get("label") in rename_map:
                node["label"] = rename_map[node["label"]]

        # Update the functions list.
        for func in new_ast.get("functions", []):
            if func.get("name") in rename_map:
                func["name"] = rename_map[func["name"]]

        # Update op_sequences keys.
        if "op_sequences" in new_ast:
            new_ast["op_sequences"] = {
                rename_map.get(k, k): v for k, v in new_ast["op_sequences"].items()
            }

    return new_ast

def safe_bit_operation_commutativity(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reorder commutative bitwise operations in ``op_sequences`` safely.

    Directional shifts are excluded. Only runs of the *same* operator
    ("bitwise_and", "bitwise_or" or "bitwise_xor") are candidates for
    shuffling. Different bitwise operators do not commute with each other:
    ``(a & b) ^ c`` differs from ``(a ^ c) & b`` in general. The previous
    version shuffled mixed runs and so changed the described computation.

    Because each entry is a bare operator label, permuting a run of identical
    labels leaves the list unchanged, so the returned sequences keep the
    original operation order. Returns a deep copy; the input is not modified.
    """
    new_ast = copy.deepcopy(ast)

    # ONLY truly commutative bit operations
    commutative_bit_ops = {
        "bitwise_and", "bitwise_or", "bitwise_xor"
    }

    for func_name, op_sequence in new_ast.get("op_sequences", {}).items():
        if len(op_sequence) < 2:
            continue

        new_sequence = []
        i = 0

        while i < len(op_sequence):
            op = op_sequence[i]
            if op in commutative_bit_ops:
                # Extend the run only while the operator stays the same.
                j = i
                while j < len(op_sequence) and op_sequence[j] == op:
                    j += 1

                block = op_sequence[i:j]
                if len(block) > 1:
                    random.shuffle(block)
                new_sequence.extend(block)
                i = j
            else:
                new_sequence.append(op)
                i += 1

        new_ast["op_sequences"][func_name] = new_sequence

    return new_ast

# -----------------------
# NEW AUGMENTATION STRATEGIES
# -----------------------

def safe_temporary_variable_renaming(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rename temporary/local variables without changing logic.

    A ``var`` node is treated as temporary when its label (case-insensitively)
    is one of a few common scratch names, is a single lowercase letter, or
    matches ``tmp<digits>``. Each distinct temporary name gets ONE new name,
    ``<name>_<1-9>``. It is applied to every ``var`` node with that label and,
    by whole-word substitution, to each ``function`` node's
    ``features["body_text"]``; those bodies are then capped at 300 characters
    via ``safe_truncate``. Returns a deep copy; the input is not modified.
    """
    new_ast = copy.deepcopy(ast)

    # Common temporary variable patterns
    temp_patterns = {"tmp", "temp", "x", "y", "z", "a", "b", "c", "var", "val"}

    rename_map: Dict[str, str] = {}
    nodes = new_ast.get("nodes", [])

    for node in nodes:
        if node.get("type") != "var":
            continue
        var_name = node.get("label", "")

        looks_temporary = (
            var_name.lower() in temp_patterns
            or re.match(r'^[a-z]$', var_name)          # single letter
            or re.match(r'^tmp\d*$', var_name.lower())  # tmp, tmp123
        )
        if not looks_temporary:
            continue

        # The same variable usually appears in several nodes. Draw the suffix
        # once so every occurrence and every body_text reference agree (the
        # old code gave each node its own random name and the bodies the last one).
        if var_name not in rename_map:
            rename_map[var_name] = f"{var_name}_{random.randint(1, 9)}"
        node["label"] = rename_map[var_name]

    if rename_map:
        for node in nodes:
            if node.get("type") == "function" and "body_text" in node.get("features", {}):
                body = node["features"]["body_text"]
                for old_name, new_name in rename_map.items():
                    # Word boundaries avoid partial replacements (e.g. "a" inside "add").
                    body = re.sub(r'\b' + re.escape(old_name) + r'\b', new_name, body)
                node["features"]["body_text"] = safe_truncate(body, 300)

    return new_ast

def safe_comment_whitespace_injection(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Add synthetic comments and blank lines to function bodies.

    For every ``function`` node with ``features["body_text"]``:
      * with probability 0.3, a random Isabelle-style comment line is prepended;
      * with probability 0.4, if the body has more than two lines, an empty
        line is inserted at a random interior position with probability 0.3.
    The body is then capped at 300 characters via ``safe_truncate``. Works on
    and returns a deep copy.
    """
    augmented = copy.deepcopy(ast)

    comment_pool = [
        "(* Helper computation *)",
        "(* Temporary variable *)",
        "(* Cryptographic operation *)",
        "(* Bit manipulation *)",
        "(* Round function component *)"
    ]

    for node in augmented.get("nodes", []):
        if not (node.get("type") == "function" and "body_text" in node.get("features", {})):
            continue

        text = node["features"]["body_text"]

        if random.random() < 0.3:
            text = f"{random.choice(comment_pool)}\n{text}"

        if random.random() < 0.4:
            body_lines = text.split('\n')
            if len(body_lines) > 2:
                if random.random() < 0.3:
                    body_lines.insert(random.randint(1, len(body_lines) - 1), "")
                text = '\n'.join(body_lines)

        node["features"]["body_text"] = safe_truncate(text, 300)

    return augmented

def safe_annotation_augmentation(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a deep copy of ``ast`` with synthetic Isabelle-style annotations.

    For every ``function`` node whose ``features`` contain ``body_text``:

    * with probability 0.25 a ``(*@ ... *)`` pragma line is prepended to the
      body, which is then re-truncated with ``safe_truncate(..., 300)``;
    * ``features["annotations"]`` is created as an empty list if absent and,
      with probability 0.2, one of ``verified`` / ``pure`` / ``inline`` is
      appended to it.

    The input AST is left untouched.
    """
    pragma_pool = (
        "(*@ verified *)",
        "(*@ inline *)",
        "(*@ pure *)",
        "(*@ preserves_crypto *)",
        "(*@ preserves_semantics *)",
    )
    tag_pool = ("verified", "pure", "inline")
    result = copy.deepcopy(ast)

    for entry in result.get("nodes", []):
        feats = entry.get("features", {})
        if entry.get("type") != "function" or "body_text" not in feats:
            continue

        if random.random() < 0.25:
            pragma = random.choice(pragma_pool)
            feats["body_text"] = safe_truncate(f"{pragma}\n{feats['body_text']}", 300)

        tags = feats.setdefault("annotations", [])
        if random.random() < 0.2:
            tags.append(random.choice(tag_pool))

    return result

def safe_function_inlining(ast: Dict[str, Any]) -> Dict[str, Any]:
    """
    Mark some calls to simple helper functions as "inlinable" in their metadata.

    Nothing is actually inlined. A ``function_call`` node whose label contains
    one of ``helper`` / ``compute`` / ``calculate`` / ``get`` (case-insensitive)
    gets, with probability 0.3, ``features["inlinable"] = True`` and a random
    ``features["inline_variant"]``. The input AST is left untouched.

    Args:
        ast: graph dict with a ``nodes`` list.

    Returns:
        A deep copy of ``ast`` with the metadata added.
    """
    helper_patterns = ("helper", "compute", "calculate", "get")
    new_ast = copy.deepcopy(ast)

    for node in new_ast.get("nodes", []):
        if node.get("type") != "function_call":
            continue

        # Labels may be missing, null or non-string in the JSON; treat those as
        # "no name" instead of crashing on .lower().
        func_name = str(node.get("label") or "").lower()
        # NOTE: plain substring heuristic, so e.g. "target" also matches "get".
        is_simple_helper = any(pattern in func_name for pattern in helper_patterns)

        if is_simple_helper and random.random() < 0.3:
            # A null "features" entry is replaced by a fresh dict.
            features = node.get("features") or {}
            features["inlinable"] = True
            features["inline_variant"] = random.choice(["direct", "expanded", "optimized"])
            node["features"] = features

    return new_ast

# -----------------------
# ENHANCED STRATEGY LISTS
# -----------------------
# Strategy registries consumed by the augmentation driver.
# NOTE: a later cell ("NEW STRATEGY LISTS") re-binds every name defined here.
# The driver looks up AUGMENTATION_STRATEGIES and weighted_strategy_sampling
# looks up STRATEGY_WEIGHTS at call time, so the definitions from whichever
# cell ran last are the ones in effect.

# Core strategies (high impact, very safe)
CORE_AUG_STRATEGIES = [
    safe_commutative_operation_reordering,
    safe_function_renaming,
    safe_temporary_variable_renaming,
    safe_representation_variation,
]

# Textual variation strategies (safe, adds diversity)
TEXTUAL_AUG_STRATEGIES = [
    safe_comment_whitespace_injection,
    safe_annotation_augmentation,
]

# Cipher-specific strategies: core + cipher-specific extras + textual.
# List order matters: weighted_strategy_sampling picks by index, so under a
# fixed random seed reordering changes which strategies get selected.
PRESENT_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_bit_operation_commutativity,
    safe_present_sbox_representation,
    safe_present_permutation_variation,
    safe_present_round_structure_variation,
] + TEXTUAL_AUG_STRATEGIES

# Speck and Simon currently share the same single extra strategy.
SPECK_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_key_schedule_representation,
] + TEXTUAL_AUG_STRATEGIES

SIMON_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_key_schedule_representation, 
] + TEXTUAL_AUG_STRATEGIES

# Keyed by cipher directory name; the driver falls back to
# CORE_AUG_STRATEGIES for any cipher not listed here.
AUGMENTATION_STRATEGIES = {
    "Simon": SIMON_SPECIFIC_STRATEGIES,
    "Speck": SPECK_SPECIFIC_STRATEGIES,
    "PRESENT": PRESENT_SPECIFIC_STRATEGIES
}

# Strategy weights for sampling (higher = more frequent).
# Keys are the strategy function objects; strategies missing from this table
# (e.g. the cipher-specific ones) get weight 1.0 in weighted_strategy_sampling.
# NOTE(review): safe_function_inlining is weighted here but is not part of any
# strategy list above, so it is never actually sampled.
STRATEGY_WEIGHTS = {
    safe_commutative_operation_reordering: 2.0,  # High impact
    safe_function_renaming: 1.5,                 # Medium impact  
    safe_temporary_variable_renaming: 1.2,       # Medium impact
    safe_representation_variation: 1.0,          # Medium impact
    safe_comment_whitespace_injection: 0.8,      # Lower impact
    safe_annotation_augmentation: 0.7,           # Lower impact
    safe_function_inlining: 0.6,                 # Lower impact
}



def weighted_strategy_sampling(strategies: List, num_strategies: int,
                               weights: Dict[Any, float] = None) -> List:
    """
    Draw up to ``num_strategies`` distinct strategies, favouring higher weights.

    Sampling is without replacement: after each pick the chosen strategy is
    removed and the next pick is weighted over what remains.

    Args:
        strategies: candidate strategies (any sequence; it is not modified).
        num_strategies: how many to pick; capped at ``len(strategies)``.
        weights: mapping strategy -> weight. When omitted, the module-level
            ``STRATEGY_WEIGHTS`` is used (looked up at call time). Strategies
            missing from the mapping weigh 1.0; negative weights count as 0.

    Returns:
        The selected strategies in pick order. Returns ``[]`` when there are no
        strategies or ``num_strategies <= 0``.
    """
    if not strategies or num_strategies <= 0:
        return []

    weight_table = STRATEGY_WEIGHTS if weights is None else weights
    pool = list(strategies)
    pool_weights = [max(0.0, float(weight_table.get(s, 1.0))) for s in pool]
    k = min(num_strategies, len(pool))

    # No usable weights at all: plain uniform sampling.
    if sum(pool_weights) == 0:
        return random.sample(pool, k)

    selected = []
    for _ in range(k):
        if sum(pool_weights) > 0:
            # random.choices accepts unnormalised weights.
            idx = random.choices(range(len(pool)), weights=pool_weights)[0]
        else:
            # Only zero-weight candidates remain. random.choices would raise
            # ValueError here, so fall back to a uniform pick.
            idx = random.randrange(len(pool))
        selected.append(pool.pop(idx))
        pool_weights.pop(idx)

    return selected



    

## Additional structural and metadata augmentation techniques

In [9]:
# (Add these to your augmentation script)
import re

def safe_feature_noise_injection(ast: Dict[str, Any], noise_level: float = 0.05) -> Dict[str, Any]:
    """
    Return a deep copy of ``ast`` with small multiplicative Gaussian noise on numeric features.

    Each affected positive value ``v`` becomes ``max(0, v * (1 + N(0, noise_level)))``:

    * node-level ``crypto_strength`` / ``diffusion_power`` / ``nonlinearity``
      are always perturbed when present, numeric and positive;
    * each positive numeric entry of ``node["features"]`` is perturbed with
      probability 0.2.

    Booleans are skipped. ``bool`` is a subclass of ``int``, so without this
    check flags such as ``True`` would silently become floats like ``1.03``.
    Zero, negative and non-numeric values are never changed.

    Args:
        ast: graph dict with a ``nodes`` list.
        noise_level: standard deviation of the relative noise.

    Returns:
        A new AST dict; the input is not mutated.
    """
    new_ast = copy.deepcopy(ast)

    def _is_number(value: Any) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    for node in new_ast.get("nodes", []):
        # Root-level numerical descriptors.
        for key in ("crypto_strength", "diffusion_power", "nonlinearity"):
            value = node.get(key)
            if _is_number(value) and value > 0:
                node[key] = max(0, value * (1 + random.gauss(0, noise_level)))

        # Optional per-node 'features' dict: perturb a random subset.
        features = node.get("features")
        if isinstance(features, dict):
            for key, value in features.items():
                if _is_number(value) and value > 0 and random.random() < 0.2:
                    features[key] = max(0, value * (1 + random.gauss(0, noise_level)))

    return new_ast

def safe_node_dropout_cryptographic(ast: Dict[str, Any], p: float = 0.15) -> Dict[str, Any]:
    """
    Randomly drop a fraction of non-critical nodes, together with their incident edges.

    A node is protected when its ``crypto_role`` is one of the never-drop roles,
    or when its label contains (case-insensitively) one of the never-drop
    label fragments. From the remaining nodes,
    ``max(1, int(len(droppable) * p))`` are removed uniformly at random; the
    count is capped at the number available. Graphs with fewer than 15 nodes
    are returned unchanged. The dropped count and the sorted list of dropped
    node types are stored under ``pdv.augmentation_metadata``.

    Args:
        ast: graph dict with ``nodes`` (each with an ``id``) and ``edges``
            (each with ``source`` / ``target``).
        p: fraction of droppable nodes to remove.

    Returns:
        A new AST dict; the input is not mutated.
    """
    new_ast = copy.deepcopy(ast)
    nodes = new_ast.get("nodes", [])

    if len(nodes) < 15:  # Don't augment very small graphs
        return new_ast

    never_drop_roles = {
        "feistel_f_function", "sbox_substitution", "permutation_layer",
        "modular_addition", "nonlinear_mixing",
    }
    # Matched as substrings of the lower-cased node label, so each entry must
    # itself be lower-case. The old mixed-case "F_function" entry could never match.
    never_drop_labels = (
        "round", "encrypt", "decrypt", "key_schedule", "f_function",
        "sbox", "p_layer", "present_round", "simon_round", "speck_enc_round",
    )

    def is_critical(node: Dict[str, Any]) -> bool:
        label = str(node.get("label") or "").lower()
        return (node.get("crypto_role") in never_drop_roles
                or any(fragment in label for fragment in never_drop_labels))

    droppable_nodes = [n for n in nodes if not is_critical(n)]
    if not droppable_nodes:
        return new_ast

    num_to_drop = min(len(droppable_nodes), max(1, int(len(droppable_nodes) * p)))
    nodes_to_drop = random.sample(droppable_nodes, num_to_drop)
    ids_to_drop = {n["id"] for n in nodes_to_drop}

    new_ast["nodes"] = [n for n in nodes if n["id"] not in ids_to_drop]
    new_ast["edges"] = [
        e for e in new_ast.get("edges", [])
        if e["source"] not in ids_to_drop and e["target"] not in ids_to_drop
    ]

    # Track what was dropped for analysis.
    aug_meta = new_ast.setdefault("pdv", {}).setdefault("augmentation_metadata", {})
    aug_meta["node_dropout_count"] = num_to_drop
    # Sorted (by str, since types may be None) so repeated runs serialise identically.
    aug_meta["dropped_node_types"] = sorted({n.get("type") for n in nodes_to_drop}, key=str)

    return new_ast

def safe_edge_perturbation_cryptographic(ast: Dict[str, Any], p_drop: float = 0.08, p_add: float = 0.05) -> Dict[str, Any]:
    """
    Randomly drop some non-structural edges and add a few synthetic ``arg`` edges.

    Edges are partitioned by their ``type``:

    * critical (``contains``, ``func``, ``amount``): always kept;
    * droppable (``arg``, ``child``, ``left``, ``right``, ``binding``, ``body``):
      ``max(1, int(len * p_drop))`` of them, capped at the number available,
      are removed at random;
    * any other or missing type: kept unchanged.

    Next, ``max(1, int(len(edges) * p_add))`` attempts are made to add an
    ``arg`` edge. Each attempt picks a source among ``op`` / ``function`` /
    ``var`` / ``literal`` nodes and a target among ``op`` / ``function``
    nodes. Self-loops and duplicate source->target pairs are skipped. The
    numbers of dropped and added edges are stored in
    ``pdv.augmentation_metadata``. An AST without edges is returned as an
    unmodified copy.

    Args:
        ast: graph dict with ``nodes`` (each with ``id`` / ``type``) and
            ``edges`` (each with ``source`` / ``target`` / optional ``type``).
        p_drop: fraction of droppable edges to remove (at least one).
        p_add: fraction of the original edge count to try adding (at least one attempt).

    Returns:
        A new AST dict; the input is not mutated.
    """
    new_ast = copy.deepcopy(ast)
    nodes = new_ast.get("nodes", [])
    edges = new_ast.get("edges", [])

    if not edges:
        return new_ast

    critical_edge_types = {"contains", "func", "amount"}  # structural: never dropped
    droppable_edge_types = {"arg", "child", "left", "right", "binding", "body"}

    critical_edges, droppable_edges, other_edges = [], [], []
    for edge in edges:
        edge_type = edge.get("type")
        if edge_type in critical_edge_types:
            critical_edges.append(edge)
        elif edge_type in droppable_edge_types:
            droppable_edges.append(edge)
        else:
            # Unknown or missing types are not ours to drop. Previously they
            # were silently discarded from the output graph.
            other_edges.append(edge)

    # Edge dropping (only from the droppable set).
    if droppable_edges:
        num_edges_to_drop = min(len(droppable_edges),
                                max(1, int(len(droppable_edges) * p_drop)))
        kept_droppable = random.sample(droppable_edges,
                                       len(droppable_edges) - num_edges_to_drop)
    else:
        num_edges_to_drop = 0
        kept_droppable = []

    remaining_edges = critical_edges + kept_droppable + other_edges

    # Edge addition (semantically plausible new connections).
    num_edges_to_add = max(1, int(len(edges) * p_add))
    source_candidates = [n["id"] for n in nodes if n.get("type") in ("op", "function", "var", "literal")]
    target_candidates = [n["id"] for n in nodes if n.get("type") in ("op", "function")]
    # O(1) duplicate checks instead of rescanning every edge per attempt.
    existing_pairs = {(e.get("source"), e.get("target")) for e in remaining_edges}

    new_edges = []
    if source_candidates and target_candidates:
        for _ in range(num_edges_to_add):
            source_id = random.choice(source_candidates)
            target_id = random.choice(target_candidates)
            if source_id == target_id or (source_id, target_id) in existing_pairs:
                continue
            existing_pairs.add((source_id, target_id))
            new_edges.append({
                "source": source_id,
                "target": target_id,
                "type": "arg",  # A plausible, droppable type
                "features": {"synthetic": True, "augmentation": "edge_add"},
            })

    new_ast["edges"] = remaining_edges + new_edges

    aug_meta = new_ast.setdefault("pdv", {}).setdefault("augmentation_metadata", {})
    aug_meta["edge_perturb_drop"] = num_edges_to_drop
    aug_meta["edge_perturb_add"] = len(new_edges)

    return new_ast

In [10]:
# In your augmentation script:

# --- NEW STRATEGY LISTS ---
# These definitions re-bind the names from the earlier "ENHANCED STRATEGY
# LISTS" cell. run_progressive_augmentation looks up AUGMENTATION_STRATEGIES
# and weighted_strategy_sampling looks up STRATEGY_WEIGHTS at call time, so
# whichever cell ran last determines the configuration actually used.
# Core strategies (high impact, very safe)
CORE_AUG_STRATEGIES = [
    safe_node_dropout_cryptographic,          # structural: removes non-critical nodes + their edges
    safe_edge_perturbation_cryptographic,     # structural: drops/adds non-critical edges
    safe_commutative_operation_reordering,
    safe_function_renaming,
    safe_temporary_variable_renaming,
    safe_representation_variation,
]

# Textual/Noise variation strategies (safe, adds diversity)
TEXTUAL_AUG_STRATEGIES = [
    safe_feature_noise_injection,             # multiplicative Gaussian noise on numeric features
    safe_comment_whitespace_injection,
    safe_annotation_augmentation,
]

# Cipher-specific strategies: core + cipher-specific extras + textual.
# List order matters under a fixed seed (weighted_strategy_sampling picks by index).
PRESENT_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_bit_operation_commutativity,
    safe_present_sbox_representation,
    safe_present_permutation_variation,
    safe_present_round_structure_variation,
] + TEXTUAL_AUG_STRATEGIES

# Speck and Simon currently share the same single extra strategy.
SPECK_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_key_schedule_representation,
] + TEXTUAL_AUG_STRATEGIES

SIMON_SPECIFIC_STRATEGIES = CORE_AUG_STRATEGIES + [
    safe_key_schedule_representation,  
] + TEXTUAL_AUG_STRATEGIES

# Keyed by cipher directory name; run_progressive_augmentation falls back to
# CORE_AUG_STRATEGIES for ciphers not listed here.
AUGMENTATION_STRATEGIES = {
    "Simon": SIMON_SPECIFIC_STRATEGIES,
    "Speck": SPECK_SPECIFIC_STRATEGIES,
    "PRESENT": PRESENT_SPECIFIC_STRATEGIES
}

# --- NEW WEIGHTS ---
# Relative sampling weights keyed by strategy function. Strategies missing
# here (e.g. the cipher-specific ones) get weight 1.0 in weighted_strategy_sampling.
# NOTE(review): safe_function_inlining from the earlier cell has no weight here
# and is not in any list above, so it is no longer used.
STRATEGY_WEIGHTS = {
    # Structural (Highest Impact)
    safe_node_dropout_cryptographic: 3.0,
    safe_edge_perturbation_cryptographic: 2.5,
    safe_commutative_operation_reordering: 2.0,
    
    # Metadata (Medium Impact)
    safe_function_renaming: 1.5,
    safe_temporary_variable_renaming: 1.2,
    safe_representation_variation: 1.0,

    # Noise/Textual (Regularization)
    safe_feature_noise_injection: 1.0,
    safe_comment_whitespace_injection: 0.5,
    safe_annotation_augmentation: 0.5,
}

In [None]:
# In your augmentation script:
# (Make sure to import all your augmentation functions)


# Configuration
# Output root: run_progressive_augmentation writes <AUGMENTED_DIR>/<cipher>/*.json
AUGMENTED_DIR = "augmented_data_V6"
# Number of augmented variants per original file.
# NOTE(review): not read by run_progressive_augmentation below, which takes its
# per-file counts from its progressive plan configs.
AUG_PER_FILE = 10 # Number of augmented variants per original file

# Input root holding one sub-directory per cipher (Simon / Speck / PRESENT).
SOURCE_DIR_ = 'output_ast_V5'
# NOTE(review): the per-cipher paths below are not referenced in this part of
# the notebook; the driver joins SOURCE_DIR_ with the cipher name instead.
SIMON_INPUT = "output_ast_V5/Simon"
SPECK_INPUT = "output_ast_V5/Speck"
PRESENT_INPUT = "output_ast_V5/PRESENT"


def run_progressive_augmentation(plan_names: List[str] = None, per_plan_subdirs: bool = False):
    """
    Generate augmented copies of every original AST under ``SOURCE_DIR_``.

    For each progressive plan (8x / 15x / 30x) and each cipher directory
    (``Simon``, ``Speck``, ``PRESENT``), every ``*.json`` file not starting with
    ``_`` is loaded. It is expanded via
    ``create_enhanced_augmented_variants_progressive`` into the original plus
    ``augmentations_per_file`` variants, written as ``<base>_original.json``
    and ``<base>_aug<i>.json``. Input files are processed in sorted order so
    seeded runs are reproducible. A file that fails is reported and skipped.

    Args:
        plan_names: subset of plan names to run, in the given order.
            Default: all plans.
        per_plan_subdirs: if True, each plan writes to
            ``AUGMENTED_DIR/<plan>/<cipher>``. If False (default, the
            historical layout expected by the downstream sampler), every plan
            writes to ``AUGMENTED_DIR/<cipher>``. In that case later plans
            overwrite same-named files from earlier plans.

    Raises:
        ValueError: if ``plan_names`` contains an unknown plan.
    """
    # Progressive scaling plans: N augmentations + 1 original per file.
    progressive_augmentation_plans = {
        'dataset-8x': {
            'augmentations_per_file': 7,  # 7 augs + 1 original = 8x
            'node_dropout_p': 0.1,
            'edge_perturb_p': 0.05,
            'noise_level': 0.02
        },
        'dataset-15x': {
            'augmentations_per_file': 14, # 14 augs + 1 original = 15x
            'node_dropout_p': 0.15,
            'edge_perturb_p': 0.08,
            'noise_level': 0.05
        },
        'dataset-30x': {
            'augmentations_per_file': 29, # 29 augs + 1 original = 30x
            'node_dropout_p': 0.2,
            'edge_perturb_p': 0.1,
            'noise_level': 0.05
        }
    }

    if plan_names is None:
        selected_plans = list(progressive_augmentation_plans)
    else:
        unknown = [name for name in plan_names if name not in progressive_augmentation_plans]
        if unknown:
            raise ValueError(
                f"Unknown augmentation plan(s) {unknown}; "
                f"expected names from {list(progressive_augmentation_plans)}"
            )
        selected_plans = list(plan_names)

    source_dir = SOURCE_DIR_
    output_base_dir = AUGMENTED_DIR
    ciphers = ["Simon", "Speck", "PRESENT"]

    for plan_name in selected_plans:
        plan_config = progressive_augmentation_plans[plan_name]
        print(f"\n--- Generating {plan_name} ---")
        num_augs = plan_config['augmentations_per_file']
        output_dir_plan = (os.path.join(output_base_dir, plan_name)
                           if per_plan_subdirs else output_base_dir)

        total_graphs = 0

        for cipher_name in ciphers:
            source_cipher_dir = os.path.join(source_dir, cipher_name)
            output_cipher_dir = os.path.join(output_dir_plan, cipher_name)

            if not os.path.exists(source_cipher_dir):
                print(f"Warning: Source dir not found: {source_cipher_dir}")
                continue
            os.makedirs(output_cipher_dir, exist_ok=True)

            strategies = AUGMENTATION_STRATEGIES.get(cipher_name, CORE_AUG_STRATEGIES)
            # Sorted: os.listdir order is arbitrary, which would make seeded runs irreproducible.
            json_files = sorted(f for f in os.listdir(source_cipher_dir)
                                if f.endswith(".json") and not f.startswith("_"))

            print(f"Processing {cipher_name}: {len(json_files)} original files...")

            for json_file in json_files:
                input_path = os.path.join(source_cipher_dir, json_file)
                base_name = json_file[:-len(".json")]  # strip only the extension
                try:
                    original_ast = load_json(input_path)
                    variants = create_enhanced_augmented_variants_progressive(
                        original_ast,
                        num_augs,
                        strategies,
                        plan_config
                    )
                    for i, variant in enumerate(variants):
                        suffix = "original" if i == 0 else f"aug{i}"
                        output_path = os.path.join(output_cipher_dir, f"{base_name}_{suffix}.json")
                        save_json(variant, output_path)
                        total_graphs += 1
                except Exception as e:
                    # Best effort: report and continue with the next file.
                    print(f"  Error processing {json_file}: {e}")

        print(f"✅ Finished {plan_name}. Total graphs: {total_graphs}")


def create_enhanced_augmented_variants_progressive(original_ast: Dict[str, Any], 
                                                   num_variants: int, 
                                                   strategies: List,
                                                   plan_config: Dict) -> List[Dict[str, Any]]:
    """
    Return ``[copy of original] + num_variants`` augmented copies of an AST.

    For each variant, 2-4 strategies are drawn with
    ``weighted_strategy_sampling`` and applied in order. The structural
    strategies take their strength from ``plan_config``:

    * ``node_dropout_p`` -> ``safe_node_dropout_cryptographic(p=...)``
    * ``edge_perturb_p`` -> ``safe_edge_perturbation_cryptographic``
      (``p_drop`` = value, ``p_add`` = half of it)
    * ``noise_level``    -> ``safe_feature_noise_injection(noise_level=...)``

    Missing keys fall back to 0.15 / 0.08 / 0.05. Every other strategy is
    called with its defaults, and a strategy that raises is reported and
    skipped. After augmentation:

    * op counts and graph statistics are recomputed;
    * ``pdv.augmented`` is set to True;
    * the original's ``security_score`` / ``security_label`` are copied over
      when present.

    ``pdv.augmentation_metadata`` keeps whatever the strategies recorded
    (e.g. ``node_dropout_count``, ``edge_perturb_drop``). It also gains
    ``strategies`` (names of the strategies applied) and ``variant_id``
    (1-based).
    """
    variants = [copy.deepcopy(original_ast)]  # index 0 is the unmodified original

    p_node_drop = plan_config.get('node_dropout_p', 0.15)
    p_edge_drop = plan_config.get('edge_perturb_p', 0.08)
    p_edge_add = p_edge_drop / 2  # add fewer edges than we drop
    noise_lvl = plan_config.get('noise_level', 0.05)

    for variant_id in range(1, num_variants + 1):
        augmented_ast = copy.deepcopy(original_ast)

        num_augmentations = random.randint(2, 4)  # apply 2-4 strategies per variant
        applied_strategies_fns = weighted_strategy_sampling(strategies, num_augmentations)
        applied_strategies_names = []

        for strategy_fn in applied_strategies_fns:
            try:
                # Identity dispatch: only these strategies take plan parameters.
                if strategy_fn is safe_node_dropout_cryptographic:
                    augmented_ast = strategy_fn(augmented_ast, p=p_node_drop)
                elif strategy_fn is safe_edge_perturbation_cryptographic:
                    augmented_ast = strategy_fn(augmented_ast, p_drop=p_edge_drop, p_add=p_edge_add)
                elif strategy_fn is safe_feature_noise_injection:
                    augmented_ast = strategy_fn(augmented_ast, noise_level=noise_lvl)
                else:
                    augmented_ast = strategy_fn(augmented_ast)
                applied_strategies_names.append(strategy_fn.__name__)
            except Exception as e:
                print(f"Warning: Strategy {strategy_fn.__name__} failed: {e}")

        # Recompute stats AFTER all augmentations
        recompute_op_counts_from_nodes_and_edges(augmented_ast)
        recompute_graph_stats(augmented_ast)

        pdv = augmented_ast.setdefault("pdv", {})
        pdv["augmented"] = True
        # Merge instead of replace: the strategies record what they changed here.
        # NOTE: keys already present in the original's metadata are carried over too.
        aug_meta = dict(pdv.get("augmentation_metadata") or {})
        aug_meta["strategies"] = applied_strategies_names
        aug_meta["variant_id"] = variant_id
        pdv["augmentation_metadata"] = aug_meta

        # Preserve labels
        for key in ("security_score", "security_label"):
            if key in original_ast:
                augmented_ast[key] = original_ast[key]

        variants.append(augmented_ast)

    return variants


if __name__ == "__main__":
    # Entry point: build the augmented dataset(s) from the original ASTs.
    run_progressive_augmentation()

### SAMPLING THE DATASET:

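Writes the class-balanced samples to the `sampled_data_variant_based_balanced` output directory (the usage at the end of the cell targets `sampled_data_variant_based_balanced_V6`).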



In [None]:
# ----------------------------
# Data Sampling for Dataset Balance
# ----------------------------

import os

import json
import random
from collections import defaultdict
from typing import List, Dict, Any
import numpy as np

random.seed(42)


# Legacy module-level value. It only serves as the default of the `FLAG`
# keyword parameters on CipherDataSampler's sampling methods, which ignore it.
# It is kept so that callers passing FLAG=... keep working.
FLAG = 0


class CipherDataSampler:
    """
    Build class-balanced subsets of an augmented cipher-AST dataset.

    Expected input layout: ``<input_base_dir>/<cipher>/*.json``. Records are
    grouped by ``cipher_variant``. Within each cipher, every security class
    gets roughly ``min_variants_per_class * samples_per_variant`` files, plus
    a small random add-only delta. Every original (non-augmented) file is
    always included. Results are written to
    ``<output_base_dir>/samples_per_variant_<k>/<cipher>/``.
    """

    # File-stem suffixes produced by the augmentation step:
    # "<base>_original.json" and "<base>_aug<i>.json".
    _ORIGINAL_SUFFIX = "_original"
    _AUG_MARKER = "_aug"

    def __init__(self, input_base_dir: str, output_base_dir: str, delta_percent: float = 0.2):
        # Seed both RNGs so the same inputs always yield the same samples.
        random.seed(42)
        np.random.seed(42)
        self.input_base_dir = input_base_dir
        self.output_base_dir = output_base_dir
        self.delta_percent = delta_percent
        self.label_mapping = {'low': 0, 'medium': 1, 'high': 2}
        self.reverse_label_mapping = {v: k for k, v in self.label_mapping.items()}

    # ----------------------------
    # Loading & Label Standardizing
    # ----------------------------
    @classmethod
    def _variant_from_filename(cls, json_file: str) -> str:
        """Derive a variant name from a file name, e.g. ``simon32_aug3.json`` -> ``simon32``."""
        stem = os.path.splitext(json_file)[0]
        if stem.endswith(cls._ORIGINAL_SUFFIX) and len(stem) > len(cls._ORIGINAL_SUFFIX):
            return stem[:-len(cls._ORIGINAL_SUFFIX)]
        head, marker, tail = stem.rpartition(cls._AUG_MARKER)
        if marker and head and tail.isdigit():
            return head
        return stem

    def load_all_json_files(self) -> List[Dict[str, Any]]:
        """
        Load every ``*.json`` record found under ``<input_base_dir>/<cipher>/``.

        Each record gets ``file_source`` (its path) and ``cipher`` (its
        directory name). When ``cipher_variant`` is missing, it is derived from
        the file name with any ``_original`` / ``_aug<N>`` suffix removed. This
        puts an original and its augmentations in the same variant group;
        without it, each augmented file would form its own variant and break
        the balancing.

        Directories and files are visited in sorted order. ``os.listdir`` order
        is arbitrary, so without sorting the seeded sampling would depend on
        the filesystem. Files that fail to load are reported and skipped.
        """
        all_files = []
        cipher_dirs = sorted(
            d for d in os.listdir(self.input_base_dir)
            if os.path.isdir(os.path.join(self.input_base_dir, d))
        )

        for cipher_dir in cipher_dirs:
            cipher_path = os.path.join(self.input_base_dir, cipher_dir)
            json_files = sorted(f for f in os.listdir(cipher_path) if f.endswith(".json"))
            for json_file in json_files:
                file_path = os.path.join(cipher_path, json_file)
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        data = json.load(f)
                    data["file_source"] = file_path
                    data["cipher"] = cipher_dir
                    if "cipher_variant" not in data:
                        data["cipher_variant"] = self._variant_from_filename(json_file)
                    all_files.append(data)
                except Exception as e:
                    # Best effort: one bad file should not abort the whole load.
                    print(f"X -- Error loading {file_path}: {e}")
        print(f" OK -- Loaded {len(all_files)} total files from {len(cipher_dirs)} ciphers.")
        return all_files

    def standardize_labels(self, data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Attach an integer ``standardized_label`` to every record (in place).

        String labels are mapped case-insensitively through ``label_mapping``;
        unknown strings map to 0 ('low'). Numeric labels are cast with ``int``.
        A missing or null ``security_label`` is treated as 'low'.
        Returns the same list for chaining.
        """
        for data in data_list:
            label = data.get("security_label")
            if label is None:
                label = "low"
            if isinstance(label, str):
                data["standardized_label"] = self.label_mapping.get(label.lower(), 0)
            else:
                data["standardized_label"] = int(label)
        return data_list

    # ----------------------------
    # Build structure: cipher -> variant -> list(files)
    # ----------------------------
    def build_cipher_variant_index(self, data_list: List[Dict[str, Any]]) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
        """Group records as ``{cipher: {cipher_variant: [records...]}}``."""
        idx = defaultdict(lambda: defaultdict(list))
        for d in data_list:
            idx[d.get("cipher", "unknown")][d.get("cipher_variant", "unknown_variant")].append(d)
        return idx

    def variant_label(self, files: List[Dict[str, Any]]) -> int:
        """
        Return the class label of a variant.

        The label of the first non-augmented file is preferred; if every file
        is augmented, the first file's label is used.
        """
        for f in files:
            if not f.get("pdv", {}).get("augmented", False):
                return f.get("standardized_label", 0)
        return files[0].get("standardized_label", 0)

    # ----------------------------
    # Core algorithm: per-cipher sampling guided by samples_per_variant
    # ----------------------------
    def sample_cipher_for_samples_per_variant(self, variants: Dict[str, List[Dict[str, Any]]],
                                              samples_per_variant: int, FLAG=FLAG) -> List[Dict[str, Any]]:
        """
        Select a class-balanced subset of one cipher's files.

        variants: dict variant_name -> list[file dicts]
        Algorithm:
          - group variants by their class label (see ``variant_label``);
          - baseline = (fewest variants in any present class) * samples_per_variant;
          - baseline_with_delta = baseline + randint(0, max(1, int(baseline * delta_percent)));
          - for each class, target = min(max(#originals, baseline_with_delta), #files).
            All originals are taken first. Augmented files are then added
            round-robin across variants, with each variant capped at
            samples_per_variant files (its original counts as one). Any
            remaining shortfall is filled from leftover augmented files,
            ignoring the cap.
        FLAG: unused; kept for backward compatibility.
        Returns the list of selected file dicts for the cipher.
        """
        class_variants = defaultdict(list)  # label -> list of variant names
        variant_meta = {}
        for vname, files in variants.items():
            label = self.variant_label(files)
            variant_meta[vname] = {
                "label": label,
                "originals": [f for f in files if not f.get("pdv", {}).get("augmented", False)],
                "augmented": [f for f in files if f.get("pdv", {}).get("augmented", False)],
            }
            class_variants[label].append(vname)

        classes_present = list(class_variants)
        if not classes_present:
            return []

        min_variant_count = min(len(class_variants[c]) for c in classes_present)
        baseline = min_variant_count * samples_per_variant
        # Add-only jitter, shared by all classes of this cipher.
        add_delta = random.randint(0, max(1, int(baseline * self.delta_percent)))
        baseline_with_delta = baseline + add_delta

        selected_files = []
        for label in classes_present:
            selected_files.extend(self._select_class_files(
                class_variants[label], variant_meta, samples_per_variant, baseline_with_delta))
        return selected_files

    @staticmethod
    def _select_class_files(variant_names: List[str], variant_meta: Dict[str, Dict[str, Any]],
                            samples_per_variant: int, baseline_with_delta: int) -> List[Dict[str, Any]]:
        """Pick all originals of one class, then top up with augmented files (round-robin, per-variant cap)."""
        originals_list = [f for v in variant_names for f in variant_meta[v]["originals"]]
        num_original_files = len(originals_list)
        total_available = num_original_files + sum(len(variant_meta[v]["augmented"]) for v in variant_names)
        # Never drop originals; never ask for more files than exist.
        target_files = min(max(num_original_files, baseline_with_delta), total_available)

        selected = list(originals_list)
        needed = target_files - len(selected)
        if needed <= 0:
            return selected

        # An included original counts as one of its variant's samples.
        per_variant_selected = {
            v: (min(1, samples_per_variant) if variant_meta[v]["originals"] else 0)
            for v in variant_names
        }
        aug_pools = {v: list(variant_meta[v]["augmented"]) for v in variant_names}

        def has_room(v: str) -> bool:
            return per_variant_selected[v] < samples_per_variant and bool(aug_pools[v])

        # Round-robin across variants that can still contribute. Every iteration
        # adds exactly one file, so the loop ends after at most `needed` steps.
        active = [v for v in variant_names if has_room(v)]
        vi = 0
        while needed > 0 and active:
            v = active[vi % len(active)]
            vi += 1
            selected.append(aug_pools[v].pop(0))
            per_variant_selected[v] += 1
            needed -= 1
            active = [u for u in active if has_room(u)]

        # Still short: take remaining augmented files regardless of the per-variant cap.
        if needed > 0:
            leftover = [f for v in variant_names for f in aug_pools[v]]
            selected.extend(leftover[:needed])
        return selected

    # ----------------------------
    # Main multi-size generation (per-variant sampling but balancing via baseline)
    # ----------------------------
    def generate_datasets_for_variant_sizes(self, samples_per_variant_sizes: List[int], FLAG=FLAG):
        """
        Load, label and sample the dataset once per requested size, writing JSON files.

        For each ``size``, every cipher is sampled with
        ``sample_cipher_for_samples_per_variant`` and saved to
        ``<output_base_dir>/samples_per_variant_<size>/<cipher>/``. File names
        have the form ``<variant>_<label>[_aug]_<index>.json``. The local
        ``file_source`` path is omitted from the written records.
        FLAG: unused; kept for backward compatibility.
        """
        all_data = self.standardize_labels(self.load_all_json_files())
        cipher_variant_index = {
            cipher: dict(variants)
            for cipher, variants in self.build_cipher_variant_index(all_data).items()
        }

        for size in samples_per_variant_sizes:
            print(f"\n ### Generating dataset for samples_per_variant = {size} ###")
            output_dir = os.path.join(self.output_base_dir, f"samples_per_variant_{size}")
            os.makedirs(output_dir, exist_ok=True)

            for cipher, variants in cipher_variant_index.items():
                cipher_dir = os.path.join(output_dir, cipher)
                os.makedirs(cipher_dir, exist_ok=True)

                sampled_files = self.sample_cipher_for_samples_per_variant(variants, size)

                # Save all samples for this cipher (no subfolders per class)
                for i, sample in enumerate(sampled_files):
                    # Write a copy without the local path rather than mutating shared records.
                    record = {k: v for k, v in sample.items() if k != "file_source"}
                    label_id = record["standardized_label"]
                    label_name = self.reverse_label_mapping.get(label_id, str(label_id))
                    is_aug = "_aug" if record.get("pdv", {}).get("augmented", False) else ""
                    variant = record.get("cipher_variant", f"{cipher}_unknown")
                    filename = f"{variant}_{label_name}{is_aug}_{i:03d}.json"
                    with open(os.path.join(cipher_dir, filename), "w", encoding="utf-8") as f:
                        json.dump(record, f, indent=2)

                print(f"  - {cipher}: saved {len(sampled_files)} files → {cipher_dir}")

            print(f" OK -- Finished dataset for {size} samples per variant → {output_dir}")

# ----------------------------
# Usage 
# ----------------------------
if __name__ == "__main__":
    # Per-variant sample counts to generate, 1 through 10 inclusive
    # (applied to the smallest-class variants).
    samples_per_variant_sizes = list(range(1, 11))

    # Reads augmented JSON from input_base_dir and writes one balanced
    # dataset folder per size under output_base_dir.
    sampler = CipherDataSampler(
        input_base_dir="augmented_data_V6",
        output_base_dir="sampled_data_variant_based_balanced_V6",
        delta_percent=0.4,
    )
    sampler.generate_datasets_for_variant_sizes(samples_per_variant_sizes)
    