In [87]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
import numpy as np
from typing import List, Dict
from transformers import AutoModelForSequenceClassification

import os
import sys

sys.path.append(os.path.abspath('..'))

from src.v3.map_cwe import LABEL_NAMES, NUM_LABELS
from src.v3.model import VulnClassifier

In [88]:
# Inference configuration shared by the cells below.
MAX_LEN = 512  # tokens per model input, special tokens and padding included
STRIDE = 256  # step, in tokens, between the starts of consecutive sliding windows
DEVICE = 'cpu'  # torch device string the input tensors are moved to
MODEL_NAME = "microsoft/codebert-base"  # base checkpoint used for both model and tokenizer

In [89]:
# Directory that holds the trained checkpoint (best_model.bin); the converted
# model and the tokenizer files are written to this same directory.
OUTPUT_DIR = '../../'

In [90]:
# División del archivo de código en funciones [code: string]

In [91]:
# One-off conversion cell: turns the raw best_model.bin state dict into the
# save_pretrained layout. It also defines `model_convert` and `tokenizer`,
# which the inference cells further down rely on.
print("Convirtiendo best_model.bin a formato save_pretrained...")

# Build the base architecture with a classification head sized for our labels.
# The "newly initialized" warning printed here refers to the classifier head,
# whose weights are replaced by load_state_dict just below.
model_convert = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS
)

# Load the fine-tuned weights. os.path.join avoids the doubled separator the
# previous f-string produced when OUTPUT_DIR already ends with '/'.
# NOTE(security): torch.load reads a pickle-based file; only load checkpoints
# from a trusted source (passing weights_only=True is the stricter option
# where the installed torch version supports it).
checkpoint_path = os.path.join(OUTPUT_DIR, "best_model.bin")
model_convert.load_state_dict(
    torch.load(checkpoint_path, map_location='cpu')
)

# Persist the full model (config + weights) in the Hugging Face layout
model_convert.save_pretrained(OUTPUT_DIR)

# Save the tokenizer next to the model (AutoTokenizer is imported in the first cell)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(OUTPUT_DIR)


Convirtiendo best_model.bin a formato save_pretrained...


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at microsoft/codebert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


('../../tokenizer_config.json',
 '../../special_tokens_map.json',
 '../../vocab.json',
 '../../merges.txt',
 '../../added_tokens.json',
 '../../tokenizer.json')

In [92]:
def tokenize(code: str) -> List[Dict]:
    """Turn a code string into one or more fixed-length model inputs.

    Inputs that fit in a single window (at most ``MAX_LEN - 2`` tokens before
    special tokens) are encoded directly by the tokenizer. Longer inputs are
    split into overlapping windows whose starts are ``STRIDE`` tokens apart;
    each window is wrapped in CLS/SEP and right-padded to ``MAX_LEN``.

    Args:
        code: Source code to tokenize.

    Returns:
        A list of dicts, one per window, each holding 1-D ``input_ids`` and
        ``attention_mask`` tensors of length ``MAX_LEN``.
    """
    tokens = tokenizer.encode(code, add_special_tokens=False, truncation=False)
    eff_len = MAX_LEN - 2  # room left for content once CLS and SEP are added

    if len(tokens) <= eff_len:
        # Calling the tokenizer directly is the current API for what
        # encode_plus used to do for a single sequence.
        enc = tokenizer(code, max_length=MAX_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
        return [{'input_ids': enc['input_ids'].flatten(),
                 'attention_mask': enc['attention_mask'].flatten()}]

    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + eff_len, len(tokens))
        ids = [tokenizer.cls_token_id] + tokens[start:end] + [tokenizer.sep_token_id]
        pad = MAX_LEN - len(ids)
        windows.append({
            'input_ids': torch.tensor(ids + [tokenizer.pad_token_id] * pad),
            'attention_mask': torch.tensor([1] * len(ids) + [0] * pad)
        })
        # Stop as soon as a window reaches the last token. The previous
        # "remaining < eff_len // 4" test let the loop emit one more window
        # that was entirely contained in the one just produced.
        if end == len(tokens):
            break
        start += STRIDE
    return windows

In [93]:
def predict_batch(codes: List[str]) -> List[Dict]:
    """Classify a batch of code snippets with the converted model.

    Every snippet is tokenized to exactly ``MAX_LEN`` tokens; longer inputs
    are truncated (no sliding window is applied here).

    Args:
        codes: Code snippets to classify. May be empty.

    Returns:
        One dict per input, in order, with the keys ``label``, ``label_id``,
        ``confidence`` (probability of the predicted class),
        ``is_vulnerable`` and ``probabilities`` (label name -> probability).
        An empty input yields an empty list.
    """
    # torch.stack / the tokenizer cannot handle an empty batch; nothing to do.
    if not codes:
        return []

    # Tokenize the whole batch in one call; padding to max_length makes every
    # row the same length, so the result is already a (batch, MAX_LEN) tensor.
    enc = tokenizer(
        list(codes),
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = enc['input_ids'].to(DEVICE)
    attention_mask = enc['attention_mask'].to(DEVICE)

    # Keep the model on the same device as its inputs (no-op when already there).
    model_convert.to(DEVICE)
    model_convert.eval()
    with torch.no_grad():
        # AutoModelForSequenceClassification returns an object exposing .logits
        outputs = model_convert(input_ids=input_ids, attention_mask=attention_mask)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()

    results = []
    for row in probs:
        pred_id = int(np.argmax(row))
        results.append({
            'label': LABEL_NAMES[pred_id],
            'label_id': pred_id,
            'confidence': float(row[pred_id]),
            # Class index 0 is treated as the non-vulnerable class
            'is_vulnerable': pred_id != 0,
            'probabilities': {n: float(row[k]) for k, n in enumerate(LABEL_NAMES)}
        })
    return results

In [94]:
def predict(code: str) -> Dict:
    """Classify a single code snippet.

    Wraps ``predict_batch`` with a one-element batch and returns its only
    result dict.
    """
    batch_results = predict_batch([code])
    return batch_results[0]

In [95]:
def analyze(code: str) -> Dict:
    """Summarise the prediction for one code snippet.

    Runs ``predict`` on ``code`` and derives a confidence level, a risk
    rating and a short recommendation from the predicted label.

    Returns:
        A dict with the keys ``prediction``, ``confidence``,
        ``confidence_level``, ``is_vulnerable``, ``risk`` and
        ``recommendation``.
    """
    outcome = predict(code)
    label = outcome['label']
    score = outcome['confidence']
    vulnerable = outcome['is_vulnerable']

    # Recommendation text shown for each known label
    advice_by_label = {
        'Safe': "Código seguro.",
        'CWE-79': "Posible XSS. Sanitiza outputs.",
        'CWE-89': "Posible SQL Injection. Usa queries parametrizadas.",
        'CWE-78': "Posible Command Injection. Evita shell=True.",
        'CWE-22': "Posible Path Traversal. Valida rutas.",
        'CWE-434': "Posible File Upload inseguro. Valida archivos.",
        'CWE-352': "Posible CSRF. Implementa tokens.",
        'Other': "Revisar manualmente."
    }

    # Bucket the confidence: above 0.7, above 0.5, or anything else
    if score > 0.7:
        level = 'alta'
    elif score > 0.5:
        level = 'media'
    else:
        level = 'baja'

    # Only predictions flagged as vulnerable can rate above the lowest risk
    if not vulnerable:
        risk = 'BAJO'
    elif score > 0.7:
        risk = 'ALTO'
    else:
        risk = 'MEDIO'

    return {
        'prediction': label,
        'confidence': score,
        'confidence_level': level,
        'is_vulnerable': vulnerable,
        'risk': risk,
        'recommendation': advice_by_label.get(label, "Revisar.")
    }

In [96]:
# Sample input: builds the SQL query by concatenating user_id into the string
vuln_code = '''
def get_user(user_id):
    query = "SELECT * FROM users WHERE id = " + user_id
    cursor.execute(query)
'''
# Sample input: passes user_id to cursor.execute as a bound parameter
safe_code = '''
def get_user(user_id):
    cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
'''

In [97]:
# Run the trained model on both sample snippets and show the raw prediction dicts
analysis = predict(vuln_code)
print("Análisis del código vulnerable:", analysis, sep="\n")

analysis_s = predict(safe_code)
print("Análisis del código seguro:", analysis_s, sep="\n")

Análisis del código vulnerable:
{'label': 'Safe', 'label_id': 0, 'confidence': 0.5147185921669006, 'is_vulnerable': False, 'probabilities': {'Safe': 0.5147185921669006, 'CWE-79': 0.012471646070480347, 'CWE-89': 0.0767425075173378, 'CWE-78': 0.0009228368871845305, 'CWE-22': 0.0010134786134585738, 'CWE-434': 0.00163276179227978, 'CWE-352': 0.0026988645549863577, 'Other': 0.3897992670536041}}
Análisis del código seguro:
{'label': 'Other', 'label_id': 7, 'confidence': 0.5140795111656189, 'is_vulnerable': True, 'probabilities': {'Safe': 0.4438059329986572, 'CWE-79': 0.012255201116204262, 'CWE-89': 0.025257835164666176, 'CWE-78': 0.0009352525230497122, 'CWE-22': 0.0007075853645801544, 'CWE-434': 0.001013824949041009, 'CWE-352': 0.0019448746461421251, 'Other': 0.5140795111656189}}


In [98]:
# Función con técnicas de IA explicable para detectar la línea de código vulnerable

In [None]:
# Generar reporte para el usuario final usando la función xAI sobre las funciones predichas como vulnerables

: 