In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
import numpy as np
from typing import List, Dict
from transformers import AutoModelForSequenceClassification

import os
import sys

sys.path.append(os.path.abspath('..'))

from src.v3.map_cwe import LABEL_NAMES, NUM_LABELS
from src.v3.dataset import VulnerabilityDataset
from src.v3.predictor import VulnerabilityPredictor

In [None]:
# Notebook-wide settings.
# NOTE(review): MAX_LEN / STRIDE are presumably the tokenizer window length and the
# sliding-window step, and MODEL_NAME the base checkpoint id; none of them is
# referenced in the visible cells, so confirm against src.v3 before changing them.
MAX_LEN = 512
STRIDE = 256
# Use the GPU when one is present, otherwise fall back to CPU so the notebook still
# runs on machines without CUDA instead of failing on the first `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_NAME = "microsoft/codebert-base"

In [None]:
# Filesystem locations used by the notebook. OUTPUT_DIR is a relative path, so it is
# resolved against the kernel's working directory; MODEL_PATH is the trained-model
# checkpoint file expected inside that directory.
OUTPUT_DIR = '../'
MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.bin')

In [None]:
# Split a source file's code into individual functions [code: string]
import re
from typing import List, Optional

class FunctionExtractor:
    """Regex-based splitter that pulls function definitions out of source text.

    Each supported language maps to one or more regular expressions in
    ``PATTERNS``; every expression has a single capture group holding the text
    of one function.

    Limitations that follow directly from the expressions:

    * Brace-delimited languages: a body may contain at most one level of nested
      ``{...}`` blocks, so a function with deeper nesting is not captured whole.
    * Python: a match runs from a ``def`` / ``async def`` header up to the next
      line that starts at column 0 with ``def``, ``async def`` or ``class``, or
      to the end of the input.
    """

    # language name -> list of regexes; group 1 of each is one function's source.
    PATTERNS = {
        'python': [
            r'((?:async\s+)?def\s+\w+\s*\([^)]*\)\s*(?:->\s*[^:]+)?:\s*(?:"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')?[\s\S]*?)(?=\n(?:async\s+)?def\s|\nclass\s|\Z)',
        ],
        'javascript': [
            # The optional `async` prefix is handled here rather than in a separate
            # pattern: with two patterns an async declaration matched both and was
            # returned twice (once with the `async` keyword stripped).
            r'((?:async\s+)?function\s+\w+\s*\([^)]*\)\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
            r'((?:const|let|var)\s+\w+\s*=\s*function\s*\([^)]*\)\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
            r'((?:const|let|var)\s+\w+\s*=\s*\([^)]*\)\s*=>\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'typescript': [
            r'(function\s+\w+\s*(?:<[^>]*>)?\s*\([^)]*\)\s*(?::\s*[^{]+)?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
            r'((?:const|let|var)\s+\w+\s*(?::\s*[^=]+)?\s*=\s*(?:async\s+)?function\s*\([^)]*\)\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'java': [
            r'((?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?[\w<>\[\],\s]+\s+\w+\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'c': [
            r'((?:static\s+)?(?:inline\s+)?(?:const\s+)?[\w\s\*]+\s+\w+\s*\([^)]*\)\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'cpp': [
            r'((?:virtual\s+)?(?:static\s+)?(?:inline\s+)?[\w\s\*<>:&]+\s+\w+\s*\([^)]*\)\s*(?:const)?\s*(?:override)?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'csharp': [
            r'((?:public|private|protected|internal)?\s*(?:static\s+)?(?:async\s+)?[\w<>\[\],\s]+\s+\w+\s*\([^)]*\)\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'go': [
            r'(func\s+(?:\([^)]+\)\s+)?\w+\s*\([^)]*\)\s*(?:[\w\*\[\]]+|\([^)]+\))?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'php': [
            r'((?:public|private|protected)?\s*(?:static\s+)?function\s+\w+\s*\([^)]*\)\s*(?::\s*\??\w+)?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
        'rust': [
            r'((?:pub\s+)?(?:async\s+)?fn\s+\w+\s*(?:<[^>]*>)?\s*\([^)]*\)\s*(?:->\s*[^{]+)?\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
        ],
    }

    # Lower-cased file extension (with the leading dot) -> key of PATTERNS.
    EXTENSION_MAP = {
        '.py': 'python', '.js': 'javascript', '.jsx': 'javascript',
        '.ts': 'typescript', '.tsx': 'typescript', '.java': 'java',
        '.c': 'c', '.h': 'c', '.cpp': 'cpp', '.cc': 'cpp', '.hpp': 'cpp',
        '.cs': 'csharp', '.go': 'go', '.php': 'php', '.rs': 'rust',
    }

    def detect_language(self, filename: str) -> Optional[str]:
        """Return the language for ``filename`` based on its extension.

        The extension is compared case-insensitively against ``EXTENSION_MAP``.

        Returns:
            The language key, or ``None`` when the extension is not mapped.
        """
        ext = os.path.splitext(filename)[1].lower()
        return self.EXTENSION_MAP.get(ext)

    def extract_functions(self, code: str, language: str) -> List[str]:
        """Extract every function the language's patterns find in ``code``.

        Args:
            code: Source text to scan.
            language: Language name; matched case-insensitively against the
                keys of ``PATTERNS``.

        Returns:
            The stripped text of each match, without duplicates, ordered by
            pattern and then by position in ``code``. An unsupported language
            prints a warning and yields an empty list.
        """
        language = language.lower()

        if language not in self.PATTERNS:
            print(f"[WARN] Lenguaje '{language}' no soportado.")
            return []

        functions: List[str] = []
        # Set membership keeps de-duplication linear while the list keeps order.
        seen = set()
        for pattern in self.PATTERNS[language]:
            for match in re.findall(pattern, code, re.MULTILINE | re.DOTALL):
                # findall yields a str for single-group patterns and a tuple
                # when a pattern has several groups; group 1 is the function.
                func = match.strip() if isinstance(match, str) else match[0].strip()
                if func and func not in seen:
                    seen.add(func)
                    functions.append(func)

        return functions

    def extract_from_file(self, filepath: str, language: Optional[str] = None) -> List[str]:
        """Read ``filepath`` and extract its functions.

        Args:
            filepath: Path of the source file to read.
            language: Language name; when omitted it is detected from the
                file extension.

        Returns:
            Same as :meth:`extract_functions` for the file's contents.

        Raises:
            ValueError: If ``language`` is omitted and the extension is not
                mapped in ``EXTENSION_MAP``.
        """
        if language is None:
            language = self.detect_language(filepath)
            if language is None:
                raise ValueError(f"No se pudo detectar el lenguaje para {filepath}")

        # Bytes that are not valid UTF-8 are dropped rather than raising.
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            code = f.read()

        return self.extract_functions(code, language)

# Create the extractor instance that the rest of the notebook can use.
extractor = FunctionExtractor()

# Example usage:
# functions = extractor.extract_functions(code_string, 'python')
# functions = extractor.extract_from_file('path/to/file.py')
print("FunctionExtractor cargado correctamente.")

In [None]:
# Originally labelled as a new cell to be run ONCE to convert the checkpoint.
# NOTE(review): the message below announces a conversion to save_pretrained format,
# but the only visible work in this cell is constructing the predictor. Confirm
# whether VulnerabilityPredictor performs any conversion; if not, the message and
# this cell's saved output are stale.
print("Convirtiendo best_model.bin a formato save_pretrained...")

# Build the predictor used by the following cells, passing it MODEL_PATH.
predictor = VulnerabilityPredictor(MODEL_PATH)

Convirtiendo best_model.bin a formato save_pretrained...


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at microsoft/codebert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


('../../tokenizer_config.json',
 '../../special_tokens_map.json',
 '../../vocab.json',
 '../../merges.txt',
 '../../added_tokens.json',
 '../../tokenizer.json')

In [None]:
# User-facing remediation hint for each label; built once instead of on every call.
_RECOMMENDATIONS = {
    'Safe': "Código seguro.",
    'CWE-79': "Posible XSS. Sanitiza outputs.",
    'CWE-89': "Posible SQL Injection. Usa queries parametrizadas.",
    'CWE-78': "Posible Command Injection. Evita shell=True.",
    'CWE-22': "Posible Path Traversal. Valida rutas.",
    'CWE-434': "Posible File Upload inseguro. Valida archivos.",
    'CWE-352': "Posible CSRF. Implementa tokens.",
    'Other': "Revisar manualmente."
}
# Hint used when the predicted label has no entry in _RECOMMENDATIONS.
_DEFAULT_RECOMMENDATION = "Revisar."
# Confidence strictly above these bounds is reported as 'alta' / 'media'.
_HIGH_CONFIDENCE = 0.7
_MEDIUM_CONFIDENCE = 0.5


def analyze(code: str, model=None) -> Dict:
    """Classify ``code`` and turn the raw prediction into a user-facing summary.

    Args:
        code: Source text handed to the model's ``predict`` method.
        model: Object exposing ``predict(code)`` that returns a mapping with
            the keys ``'label'``, ``'confidence'`` and ``'is_vulnerable'``.
            Defaults to the notebook-level ``predictor``, so existing
            ``analyze(code)`` calls behave as before.

    Returns:
        A dict with:
            * ``prediction``: the predicted label.
            * ``confidence``: the confidence reported by the model.
            * ``confidence_level``: ``'alta'`` above 0.7, ``'media'`` above
              0.5, otherwise ``'baja'``.
            * ``is_vulnerable``: the flag reported by the model.
            * ``risk``: ``'ALTO'`` when vulnerable with confidence above 0.7,
              ``'MEDIO'`` when vulnerable otherwise, ``'BAJO'`` when not
              vulnerable.
            * ``recommendation``: the hint for the label, or a generic one for
              labels without an entry.
    """
    # Resolve the global only when no model is injected, so the function does
    # not depend on notebook state when a model is passed explicitly.
    active_model = predictor if model is None else model
    result = active_model.predict(code)

    conf = result['confidence']
    vulnerable = result['is_vulnerable']

    if conf > _HIGH_CONFIDENCE:
        confidence_level = 'alta'
    elif conf > _MEDIUM_CONFIDENCE:
        confidence_level = 'media'
    else:
        confidence_level = 'baja'

    if not vulnerable:
        risk = 'BAJO'
    elif conf > _HIGH_CONFIDENCE:
        risk = 'ALTO'
    else:
        risk = 'MEDIO'

    return {
        'prediction': result['label'],
        'confidence': conf,
        'confidence_level': confidence_level,
        'is_vulnerable': vulnerable,
        'risk': risk,
        'recommendation': _RECOMMENDATIONS.get(result['label'], _DEFAULT_RECOMMENDATION)
    }

In [96]:
# Sample expected to be flagged: the SQL text is built by concatenating user_id
# into the query string before it is executed.
vuln_code = '''
def get_user(user_id):
    query = "SELECT * FROM users WHERE id = " + user_id
    cursor.execute(query)
'''
# Counterpart sample: user_id is passed to cursor.execute as a bound parameter.
safe_code = '''
def get_user(user_id):
    cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
'''

In [None]:
# Run the trained model on both samples and show each raw prediction under its heading.
analysis = predictor.predict(vuln_code)
print("Análisis del código vulnerable:", analysis, sep="\n")

analysis_s = predictor.predict(safe_code)
print("Análisis del código seguro:", analysis_s, sep="\n")

Análisis del código vulnerable:
{'label': 'Safe', 'label_id': 0, 'confidence': 0.5147185921669006, 'is_vulnerable': False, 'probabilities': {'Safe': 0.5147185921669006, 'CWE-79': 0.012471646070480347, 'CWE-89': 0.0767425075173378, 'CWE-78': 0.0009228368871845305, 'CWE-22': 0.0010134786134585738, 'CWE-434': 0.00163276179227978, 'CWE-352': 0.0026988645549863577, 'Other': 0.3897992670536041}}
Análisis del código seguro:
{'label': 'Other', 'label_id': 7, 'confidence': 0.5140795111656189, 'is_vulnerable': True, 'probabilities': {'Safe': 0.4438059329986572, 'CWE-79': 0.012255201116204262, 'CWE-89': 0.025257835164666176, 'CWE-78': 0.0009352525230497122, 'CWE-22': 0.0007075853645801544, 'CWE-434': 0.001013824949041009, 'CWE-352': 0.0019448746461421251, 'Other': 0.5140795111656189}}


In [None]:
# Función con técnicas de IA explicable para detectar la línea de código vulnerable

In [None]:
# Generar reporte para el usuario final usando la función xAI sobre las funciones predichas como vulnerables

: 