# Advanced Bixie Model Training with Progress Monitoring

This notebook includes:
- Progress bars and monitoring
- Checkpointing to save progress
- Ability to resume from where it left off
- Multiple embedding models
- Binary file analysis
- Training data sources for binary files

In [1]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, accuracy_score, f1_score, roc_auc_score, precision_recall_curve, precision_score, recall_score
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import warnings
import json
import pickle
import time
from tqdm import tqdm
warnings.filterwarnings('ignore')

# Add project root to path.
# The location can be overridden with the BIXIE_ROOT environment variable so the
# notebook is not tied to a single developer's machine; the original path is
# kept as the default so existing setups behave as before.
PROJECT_ROOT = os.environ.get("BIXIE_ROOT", "/home/trashpanda/repos/bixie.ai/")
if PROJECT_ROOT not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.insert(0, PROJECT_ROOT)

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
import random
random.seed(42)

## 1. Enhanced Embedding Models with Progress Monitoring

In [2]:
from transformers import (
    AutoTokenizer, AutoModel, 
    RobertaTokenizer, RobertaModel,
    BertTokenizer, BertModel,
    DistilBertTokenizer, DistilBertModel,
    T5Tokenizer, T5EncoderModel
)
import torch.nn.functional as F

class MultiModelEmbedder:
    """Transformer embedder with selectable pooling strategies.

    Loads a tokenizer/model pair through ``AutoTokenizer`` / ``AutoModel`` and
    converts text (or a hex rendering of raw bytes) into a NumPy vector pooled
    from the model's ``last_hidden_state``.

    Attributes:
        model_name: Identifier passed to ``from_pretrained``.
        device: Device string the model and the tokenized inputs are moved to.
        max_length: Tokenizer truncation length (512).
        embedding_dim: ``hidden_size`` read from the loaded model's config.
    """
    
    # Pooling strategies implemented by embed_text.
    SUPPORTED_POOLING = ('mean', 'cls', 'max', 'attention')
    
    def __init__(self, model_name="microsoft/codebert-base", device=None):
        """Load tokenizer and model and put the model in eval mode.

        Args:
            model_name: Model identifier handed to ``from_pretrained``.
            device: Target device; defaults to 'cuda' when available, else 'cpu'.
        """
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        
        print(f"Loading model: {model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        
        self.max_length = 512
        self.embedding_dim = self.model.config.hidden_size
        print(f"Model loaded successfully. Embedding dimension: {self.embedding_dim}")
        
    def embed_text(self, text, pooling_strategy='mean'):
        """Embed text using the requested pooling strategy.

        The input is truncated to ``self.max_length`` tokens, so anything past
        that limit does not influence the embedding.

        Args:
            text: Text handed to the tokenizer.
            pooling_strategy: One of ``SUPPORTED_POOLING``. Pooling is applied
                over dim 1 of ``last_hidden_state``.

        Returns:
            The pooled embedding as a NumPy array (after ``squeeze()``), or
            ``None`` if anything fails, including an unknown pooling strategy;
            the failure reason is printed.
        """
        try:
            inputs = self.tokenizer(
                text, 
                truncation=True, 
                padding=True, 
                max_length=self.max_length,
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                hidden_states = outputs.last_hidden_state
                
                if pooling_strategy == 'mean':
                    embedding = hidden_states.mean(dim=1)
                elif pooling_strategy == 'cls':
                    embedding = hidden_states[:, 0, :]
                elif pooling_strategy == 'max':
                    embedding = hidden_states.max(dim=1)[0]
                elif pooling_strategy == 'attention':
                    # Weights come from a softmax over each position's mean activation.
                    attention_weights = F.softmax(hidden_states.mean(dim=-1), dim=-1)
                    embedding = (hidden_states * attention_weights.unsqueeze(-1)).sum(dim=1)
                else:
                    # An unknown strategy used to surface as an opaque
                    # UnboundLocalError; name the real problem instead. The
                    # method still returns None via the handler below.
                    raise ValueError(
                        f"Unknown pooling strategy {pooling_strategy!r}; "
                        f"expected one of {self.SUPPORTED_POOLING}"
                    )
                
                return embedding.squeeze().cpu().numpy()
                
        except Exception as e:
            print(f"Embedding failed: {e}")
            return None
    
    def embed_binary_string(self, code_bytes, pooling_strategy='mean'):
        """Embed binary data by converting it to a text representation.

        The bytes are rendered as space-separated two-digit lowercase hex
        (e.g. ``b"\\x01\\xff"`` -> ``"01 ff"``) and passed to ``embed_text``.

        Args:
            code_bytes: Bytes-like object exposing ``.hex()``.
            pooling_strategy: Forwarded to ``embed_text``. It was previously
                not forwarded, so binary embeddings always used mean pooling.

        Returns:
            The embedding as a NumPy array, or ``None`` on failure.
        """
        try:
            hex_str = code_bytes.hex()
            formatted_hex = ' '.join([hex_str[i:i+2] for i in range(0, len(hex_str), 2)])
            return self.embed_text(formatted_hex, pooling_strategy)
        except Exception as e:
            print(f"Binary embedding failed: {e}")
            return None

# Model configurations to test
# Maps a short label (used in printed output and result names) to the model
# identifier that is handed to MultiModelEmbedder.
MODEL_CONFIGS = {
    'codebert': 'microsoft/codebert-base',
    'roberta': 'roberta-base',
    'bert': 'bert-base-uncased',
    'distilbert': 'distilbert-base-uncased'
}

# Pooling strategies
# NOTE: MultiModelEmbedder.embed_text also implements 'attention', which is
# not included in this list.
POOLING_STRATEGIES = ['mean', 'cls', 'max']

## 2. Progress Monitoring and Checkpointing

In [3]:
class EmbeddingCheckpointer:
    """Manages checkpointing for embedding extraction.

    Each (model_name, pooling) pair maps to one pickle file named
    ``{model_name}_{pooling}_checkpoint.pkl`` inside ``checkpoint_dir``.
    """
    
    def __init__(self, checkpoint_dir="../checkpoints"):
        """Create the checkpoint directory if it does not exist yet.

        Args:
            checkpoint_dir: Directory that holds the checkpoint files.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        # parents=True: a nested or not-yet-existing parent directory must not
        # make construction fail.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
    def _checkpoint_path(self, model_name, pooling):
        """Return the checkpoint file path for a model/pooling pair."""
        return self.checkpoint_dir / f"{model_name}_{pooling}_checkpoint.pkl"
        
    def save_checkpoint(self, embeddings, labels, model_name, pooling, batch_idx, total_batches):
        """Save current progress.

        Args:
            embeddings: Embeddings collected so far.
            labels: Labels matching ``embeddings``.
            model_name: Model part of the checkpoint file name.
            pooling: Pooling part of the checkpoint file name.
            batch_idx: Progress marker stored in the checkpoint.
            total_batches: Total number of batches of the run.
        """
        checkpoint = {
            'embeddings': embeddings,
            'labels': labels,
            'model_name': model_name,
            'pooling': pooling,
            'batch_idx': batch_idx,
            'total_batches': total_batches,
            'timestamp': time.time()
        }
        
        checkpoint_file = self._checkpoint_path(model_name, pooling)
        # Write to a sibling temp file and rename it into place, so an
        # interruption mid-write cannot leave a truncated checkpoint behind.
        tmp_file = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
        with open(tmp_file, 'wb') as f:
            pickle.dump(checkpoint, f)
        tmp_file.replace(checkpoint_file)
        
        print(f"Checkpoint saved: {checkpoint_file}")
        
    def load_checkpoint(self, model_name, pooling):
        """Load existing checkpoint.

        Returns:
            The checkpoint dict, or ``None`` when no checkpoint exists or the
            file cannot be read (a message is printed in the latter case).
        """
        checkpoint_file = self._checkpoint_path(model_name, pooling)
        
        if not checkpoint_file.exists():
            return None
        
        try:
            with open(checkpoint_file, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code. Only load
                # checkpoint files that this notebook wrote itself.
                checkpoint = pickle.load(f)
            batch_idx = checkpoint['batch_idx']
            total_batches = checkpoint['total_batches']
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError,
                IndexError, KeyError, TypeError, ValueError) as e:
            # A corrupt or incompatible file must not block a fresh run.
            print(f"Ignoring unreadable checkpoint {checkpoint_file}: {e}")
            return None
        
        print(f"Loaded checkpoint: {checkpoint_file}")
        print(f"Progress: {batch_idx}/{total_batches} batches")
        return checkpoint
    
    def clear_checkpoint(self, model_name, pooling):
        """Clear checkpoint after successful completion"""
        checkpoint_file = self._checkpoint_path(model_name, pooling)
        if checkpoint_file.exists():
            checkpoint_file.unlink()
            print(f"Cleared checkpoint: {checkpoint_file}")

def extract_embeddings_with_progress(texts, labels, embedder, pooling_strategy='mean', 
                                     batch_size=50, checkpoint_every=10, resume=True):
    """Extract embeddings with progress monitoring and checkpointing.

    Every text is UTF-8 encoded, rendered as space-separated hex bytes and
    embedded with ``embedder.embed_text`` using ``pooling_strategy``. Samples
    whose embedding fails are dropped and counted.

    Args:
        texts: Sequence of strings to embed.
        labels: Sequence of labels, same length as ``texts``.
        embedder: Object exposing ``model_name`` and
            ``embed_text(text, pooling_strategy)``.
        pooling_strategy: Pooling strategy forwarded to the embedder.
        batch_size: Number of samples per progress/checkpoint batch.
        checkpoint_every: Save a checkpoint every this many batches.
        resume: Continue from a matching checkpoint when one exists.

    Returns:
        Tuple ``(X, y, failed_count)``: ``np.array`` of embeddings, ``np.array``
        of the labels of the successfully embedded samples, and the number of
        samples that could not be embedded.

    Raises:
        ValueError: If ``batch_size`` or ``checkpoint_every`` is not positive,
            or ``texts`` and ``labels`` differ in length.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if checkpoint_every <= 0:
        raise ValueError(f"checkpoint_every must be positive, got {checkpoint_every}")
    if len(texts) != len(labels):
        raise ValueError(
            f"texts and labels differ in length: {len(texts)} vs {len(labels)}"
        )
    
    checkpointer = EmbeddingCheckpointer()
    model_name = embedder.model_name.split('/')[-1]
    
    # Calculate batches
    total_samples = len(texts)
    total_batches = (total_samples + batch_size - 1) // batch_size
    
    # Key the checkpoint by sample count as well as model/pooling. Otherwise two
    # datasets embedded with the same model (e.g. train and test split) share
    # one checkpoint file, and a resume could load the other split's embeddings.
    checkpoint_key = f"{model_name}_n{total_samples}"
    
    # Try to load existing checkpoint
    start_batch = 0
    failed_count = 0
    X, y = [], []
    
    if resume:
        checkpoint = checkpointer.load_checkpoint(checkpoint_key, pooling_strategy)
        if checkpoint:
            saved_batch = checkpoint.get('batch_idx', -1)
            # batch_idx is only meaningful for the same batch layout; a changed
            # batch_size would make it point at the wrong samples.
            if checkpoint.get('total_batches') == total_batches and 0 <= saved_batch <= total_batches:
                X = list(checkpoint['embeddings'])
                y = list(checkpoint['labels'])
                start_batch = saved_batch
                # Every processed sample was either stored or counted as failed,
                # so the failure count can be reconstructed from the checkpoint.
                processed = min(start_batch * batch_size, total_samples)
                failed_count = max(0, processed - len(X))
                print(f"Resuming from batch {start_batch}")
            else:
                print("Checkpoint does not match the current batch layout; starting from scratch")
    
    print(f"Total samples: {total_samples}")
    print(f"Batch size: {batch_size}")
    print(f"Total batches: {total_batches}")
    print(f"Starting from batch: {start_batch}")
    
    # Process batches with progress bar
    for batch_idx in tqdm(range(start_batch, total_batches), desc=f"Embedding ({model_name}, {pooling_strategy})"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_samples)
        
        batch_texts = texts[start_idx:end_idx]
        batch_labels = labels[start_idx:end_idx]
        
        batch_embeddings = []
        batch_labels_processed = []
        
        for i, (text, label) in enumerate(zip(batch_texts, batch_labels)):
            try:
                # Same hex rendering as embedder.embed_binary_string, but going
                # through embed_text directly so pooling_strategy is actually
                # applied (it was previously ignored and 'mean' always used).
                hex_text = ' '.join(f"{byte:02x}" for byte in text.encode('utf-8'))
                emb = embedder.embed_text(hex_text, pooling_strategy)
                
                if emb is not None:
                    batch_embeddings.append(emb)
                    batch_labels_processed.append(label)
                else:
                    failed_count += 1
                    
            except Exception as e:
                failed_count += 1
                print(f"Failed to embed sample {start_idx + i}: {e}")
        
        # Add batch results to main lists
        X.extend(batch_embeddings)
        y.extend(batch_labels_processed)
        
        # Save checkpoint periodically
        if (batch_idx + 1) % checkpoint_every == 0:
            checkpointer.save_checkpoint(X, y, checkpoint_key, pooling_strategy, batch_idx + 1, total_batches)
            
        # Print progress every 5 batches
        if (batch_idx + 1) % 5 == 0:
            print(f"Batch {batch_idx + 1}/{total_batches}: {len(X)} embeddings, {failed_count} failed")
    
    # Clear checkpoint after successful completion
    checkpointer.clear_checkpoint(checkpoint_key, pooling_strategy)
    
    print(f"\nEmbedding completed!")
    print(f"Successful embeddings: {len(X)}")
    print(f"Failed embeddings: {failed_count}")
    attempted = len(X) + failed_count
    if attempted:  # avoid ZeroDivisionError on empty input
        print(f"Success rate: {len(X)/attempted*100:.2f}%")
    else:
        print("Success rate: n/a (no samples)")
    
    return np.array(X), np.array(y), failed_count

## 3. Load and Prepare Data

In [4]:
# Read the labelled samples from disk
with open("../datasets/training_data.json") as f:
    data = json.load(f)

# Separate the source snippets from their labels
texts = [item["code"] for item in data]
labels = [item["label"] for item in data]

# Summarise the dataset; the positive-label total is computed once and reused
n_total = len(labels)
n_vulnerable = sum(labels)

print(f"Total samples: {len(texts)}")
print(f"Vulnerable samples: {n_vulnerable}")
print(f"Clean samples: {n_total - n_vulnerable}")
print(f"Class balance: {n_vulnerable/n_total:.3f}")

# Stratified 80/20 train/test split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    texts,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

print(f"\nTraining samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")

Total samples: 22734
Vulnerable samples: 2240
Clean samples: 20494
Class balance: 0.099

Training samples: 18187
Test samples: 4547


## 4. Comprehensive Model Evaluation with Progress

In [None]:
# Store results for comparison
all_results = []

# Fail fast: evaluate_model is defined in a later cell (section 5). Without this
# check a fresh-kernel run spends the whole embedding pass and then reports
# every classifier as "Failed to train" because of a NameError.
if not callable(globals().get('evaluate_model')):
    raise NameError(
        "evaluate_model is not defined - run the cell that defines it "
        "(section 5) before starting this evaluation"
    )

# Test different models and pooling strategies
for model_name, model_path in MODEL_CONFIGS.items():
    print(f"\n{'='*60}")
    print(f"Testing {model_name}: {model_path}")
    print(f"{'='*60}")
    
    # Model loading gets its own handler so that later errors are not
    # misreported as load failures.
    try:
        embedder = MultiModelEmbedder(model_path)
    except Exception as e:
        print(f"Failed to load {model_name}: {e}")
        continue
    
    try:
        for pooling in POOLING_STRATEGIES:
            print(f"\n--- Testing {pooling} pooling ---")
            
            # Extract embeddings with progress monitoring
            X_train_emb, y_train_emb, failed_train = extract_embeddings_with_progress(
                X_train, y_train, embedder, pooling, batch_size=50, checkpoint_every=10
            )
            
            # Check before the test-set pass so no time is spent embedding the
            # test split for a configuration that will be skipped anyway.
            if len(X_train_emb) < 10:
                print(f"Skipping {model_name} with {pooling} - insufficient samples")
                continue
            
            X_test_emb, y_test_emb, failed_test = extract_embeddings_with_progress(
                X_test, y_test, embedder, pooling, batch_size=50, checkpoint_every=10
            )
            
            if len(X_test_emb) == 0:
                print(f"Skipping {model_name} with {pooling} - no test embeddings")
                continue
            
            print(f"\nTraining embeddings: {len(X_train_emb)}")
            print(f"Test embeddings: {len(X_test_emb)}")
            
            # Test multiple classifiers
            classifiers = {
                'RandomForest': RandomForestClassifier(n_estimators=100, random_state=42),
                'GradientBoosting': GradientBoostingClassifier(random_state=42),
                'LogisticRegression': LogisticRegression(random_state=42, max_iter=1000)
            }
            
            for clf_name, clf in classifiers.items():
                try:
                    print(f"\nTraining {clf_name}...")
                    clf.fit(X_train_emb, y_train_emb)
                    y_pred = clf.predict(X_test_emb)
                    # Column 1 is used as the positive-class probability
                    y_proba = clf.predict_proba(X_test_emb)[:, 1] if hasattr(clf, 'predict_proba') else None
                    
                    results = evaluate_model(
                        y_test_emb, y_pred, y_proba, 
                        f"{model_name}_{pooling}_{clf_name}"
                    )
                    results['embedding_model'] = model_name
                    results['pooling'] = pooling
                    results['classifier'] = clf_name
                    all_results.append(results)
                    
                except Exception as e:
                    print(f"Failed to train {clf_name}: {e}")
                    
    except Exception as e:
        print(f"Failed while evaluating {model_name}: {e}")


Testing codebert: microsoft/codebert-base
Loading model: microsoft/codebert-base on cpu
Model loaded successfully. Embedding dimension: 768

--- Testing mean pooling ---
Total samples: 18187
Batch size: 50
Total batches: 364
Starting from batch: 0


Embedding (codebert-base, mean):   0%|          | 0/364 [00:00<?, ?it/s]

## 5. Results Analysis and Visualization

In [None]:
def evaluate_model(y_true, y_pred, y_proba=None, model_name="Model"):
    """Compute and print classification metrics for one model.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_proba: Optional scores for the positive class; enables ROC AUC.
        model_name: Name used in the printed report and in the result dict.

    Returns:
        dict with keys 'model', 'accuracy', 'f1_score', 'precision', 'recall'
        and 'auc'. 'auc' is ``None`` when ``y_proba`` is not given or ROC AUC
        cannot be computed for these labels.
    """
    
    # Basic metrics. zero_division=0 yields the same 0.0 score scikit-learn
    # falls back to by default, without emitting a warning.
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    
    # ROC AUC if probabilities available
    auc = None
    if y_proba is not None:
        try:
            auc = roc_auc_score(y_true, y_proba)
        except ValueError as e:
            # Raised e.g. when y_true contains a single class; keep the other
            # metrics instead of losing the whole evaluation.
            print(f"ROC AUC unavailable for {model_name}: {e}")
    
    results = {
        'model': model_name,
        'accuracy': accuracy,
        'f1_score': f1,
        'precision': precision,
        'recall': recall,
        'auc': auc
    }
    
    print(f"\n{model_name} Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    if auc is not None:  # explicit None check: an AUC of 0.0 must still be shown
        print(f"ROC AUC: {auc:.4f}")
    
    return results

# Convert results to DataFrame
results_df = pd.DataFrame(all_results)

if results_df.empty:
    # With no results the frame has no columns, so nlargest/boxplot below
    # would fail with a KeyError on 'f1_score'.
    print("No evaluation results available - run the evaluation cell (section 4) first.")
else:
    # Display best performing models
    print("\nTop 10 Models by F1 Score:")
    top_models = results_df.nlargest(10, 'f1_score')
    print(top_models[['model', 'embedding_model', 'pooling', 'classifier', 'f1_score', 'accuracy', 'auc']])

    # Visualization: 2x2 grid, one panel per comparison
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # F1 Score comparison
    ax_f1 = axes[0, 0]
    sns.boxplot(data=results_df, x='embedding_model', y='f1_score', ax=ax_f1)
    ax_f1.set_title('F1 Score by Embedding Model')
    ax_f1.tick_params(axis='x', rotation=45)

    # Accuracy comparison
    ax_acc = axes[0, 1]
    sns.boxplot(data=results_df, x='classifier', y='accuracy', ax=ax_acc)
    ax_acc.set_title('Accuracy by Classifier')
    ax_acc.tick_params(axis='x', rotation=45)

    # Pooling strategy comparison
    ax_pool = axes[1, 0]
    sns.boxplot(data=results_df, x='pooling', y='f1_score', ax=ax_pool)
    ax_pool.set_title('F1 Score by Pooling Strategy')

    # ROC AUC comparison (if available); the panel stays empty otherwise
    ax_auc = axes[1, 1]
    auc_data = results_df.dropna(subset=['auc'])
    if len(auc_data) > 0:
        sns.boxplot(data=auc_data, x='embedding_model', y='auc', ax=ax_auc)
        ax_auc.set_title('ROC AUC by Embedding Model')
        ax_auc.tick_params(axis='x', rotation=45)

    fig.tight_layout()
    plt.show()

## 6. Binary File Analysis

In [None]:
import subprocess

class BinaryAnalyzer:
    """Specialized analyzer for binary files.

    Produces three text views of a file (``strings`` output, a hex dump of the
    leading bytes, ``objdump -d`` disassembly) and embeds their concatenation
    with the supplied embedder.
    """
    
    def __init__(self, embedder):
        """Store the embedder; it must expose ``embed_text(text)``."""
        self.embedder = embedder
    
    def _run_tool(self, command, timeout, label):
        """Run an external tool and return its stdout ('' on failure).

        Args:
            command: Argument list passed to ``subprocess.run``.
            timeout: Timeout in seconds.
            label: Name of what is being extracted, used in messages.
        """
        try:
            # errors='replace': undecodable bytes in the tool output must not
            # raise UnicodeDecodeError and discard the whole extraction.
            result = subprocess.run(command, capture_output=True, text=True,
                                    errors='replace', timeout=timeout)
        except Exception as e:
            print(f"Failed to extract {label}: {e}")
            return ""
        if result.returncode != 0:
            # Report the failure; whatever the tool managed to write is still returned.
            print(f"Failed to extract {label}: {command[0]} exited with code "
                  f"{result.returncode}: {result.stderr.strip()}")
        return result.stdout
    
    def extract_strings(self, binary_path):
        """Extract strings from binary using the ``strings`` tool (30 s timeout)."""
        return self._run_tool(['strings', str(binary_path)], 30, "strings")
    
    def extract_hex_dump(self, binary_path, max_bytes=4096):
        """Return the first ``max_bytes`` bytes of the file as a hex string.

        Returns '' (after printing the error) if the file cannot be read.
        """
        try:
            with open(binary_path, 'rb') as f:
                data = f.read(max_bytes)
            return data.hex()
        except Exception as e:
            print(f"Failed to extract hex dump: {e}")
            return ""
    
    def extract_disassembly(self, binary_path):
        """Extract disassembly using ``objdump -d`` (60 s timeout)."""
        return self._run_tool(['objdump', '-d', str(binary_path)], 60, "disassembly")
    
    def analyze_binary(self, binary_path):
        """Comprehensive binary analysis.

        Args:
            binary_path: Path of the file to analyze.

        Returns:
            Tuple ``(embedding, analysis)``: the result of
            ``embedder.embed_text`` on the combined text, and a dict with the
            keys 'strings', 'hex_dump' and 'disassembly'.
        """
        # Extract different representations
        analysis = {
            'strings': self.extract_strings(binary_path),
            'hex_dump': self.extract_hex_dump(binary_path),
            'disassembly': self.extract_disassembly(binary_path),
        }
        
        # Create combined representation
        combined_text = f"STRINGS:\n{analysis['strings']}\n\nHEX:\n{analysis['hex_dump']}\n\nASM:\n{analysis['disassembly']}"
        
        # Embed the combined representation
        embedding = self.embedder.embed_text(combined_text)
        
        return embedding, analysis

# Example usage
# NOTE: these statements only print a static description of BinaryAnalyzer's
# methods; no BinaryAnalyzer instance is created in this cell.
print("Binary analyzer initialized with multiple extraction methods")
print("Methods available:")
print("- String extraction")
print("- Hex dump extraction")
print("- Disassembly extraction")
print("- Combined analysis with embedding")

## 7. Training Data Sources for Binary Files

In [None]:
# Prints a static, hand-written reference list of candidate data sources and
# collection/labeling strategies as plain text. Nothing is downloaded or
# processed by this cell.
print("""
## Binary Vulnerability Training Data Sources:

### 1. Public Vulnerability Databases:
- CVE Database (https://cve.mitre.org/)
- NVD (https://nvd.nist.gov/)
- Exploit-DB (https://www.exploit-db.com/)
- SecurityFocus (https://www.securityfocus.com/)

### 2. Malware Datasets:
- VirusTotal (https://www.virustotal.com/) - API access
- MalwareBazaar (https://bazaar.abuse.ch/)
- VX-Underground (https://vx-underground.org/)
- MalShare (https://malshare.com/)

### 3. CTF and Security Challenges:
- Pwnable.kr
- Pwnable.tw
- HackTheBox binaries
- VulnHub challenges

### 4. Academic Datasets:
- Microsoft Malware Classification Challenge
- Drebin Dataset (Android malware)
- Malimg Dataset
- EMBER Dataset (https://github.com/endgameinc/ember)

### 5. Open Source Projects:
- Known vulnerable versions of popular software
- Historical releases with known CVEs
- Fuzzing results from OSS-Fuzz

### 6. Data Collection Strategies:
1. Version comparison: Compare vulnerable vs patched versions
2. Fuzzing: Generate crash samples
3. Symbolic execution: Extract vulnerable paths
4. Static analysis: Identify vulnerable patterns
5. Dynamic analysis: Runtime vulnerability detection

### 7. Labeling Strategies:
- CVE mapping to binary versions
- Exploit availability as vulnerability indicator
- Patch analysis for vulnerability confirmation
- Static analysis tool results
- Dynamic analysis results
""")

## 8. Enhanced Model Inference Implementation

In [None]:
# Prints the proposed model_inference.py source as text.
# The outer literal is r'''...''' on purpose: the embedded source contains
# """docstrings""", which would terminate a triple-double-quoted string at the
# first docstring (a SyntaxError for this cell), and a non-raw string would
# turn the "\n" escapes inside the embedded f-strings into real line breaks.
print(r'''
# Enhanced model_inference.py Implementation

```python
import os
import sys
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Union
import hashlib
import time
import logging
import joblib
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import subprocess
import tempfile
from tqdm import tqdm
from bixie.vector_store.chroma_store import ChromaStore

# Constants
DEFAULT_MODEL_NAME = "microsoft/codebert-base"
CLASSIFIER_PATH = Path("bixie.ai/fine_tuned_model.pkl")
NEURAL_MODEL_PATH = Path("bixie.ai/neural_classifier.pth")

logger = logging.getLogger("bixie.model_inference")

class MultiModelEmbedder:
    """Enhanced embedder supporting multiple models and binary analysis"""
    
    def __init__(self, model_name: str = DEFAULT_MODEL_NAME, device=None):
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        
        logger.info(f"Loading model: {model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        
        self.max_length = 512
        self.embedding_dim = self.model.config.hidden_size
        logger.info(f"Model loaded successfully. Embedding dimension: {self.embedding_dim}")
    
    def embed_text(self, text: str, pooling_strategy: str = 'mean') -> Optional[np.ndarray]:
        """Embed text using different pooling strategies"""
        try:
            inputs = self.tokenizer(
                text, 
                truncation=True, 
                padding=True, 
                max_length=self.max_length,
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                hidden_states = outputs.last_hidden_state
                
                if pooling_strategy == 'mean':
                    embedding = hidden_states.mean(dim=1)
                elif pooling_strategy == 'cls':
                    embedding = hidden_states[:, 0, :]
                elif pooling_strategy == 'max':
                    embedding = hidden_states.max(dim=1)[0]
                elif pooling_strategy == 'attention':
                    attention_weights = torch.softmax(hidden_states.mean(dim=-1), dim=-1)
                    embedding = (hidden_states * attention_weights.unsqueeze(-1)).sum(dim=1)
                
                return embedding.squeeze().cpu().numpy()
                
        except Exception as e:
            logger.error(f"Text embedding failed: {e}")
            return None
    
    def embed_binary_string(self, code_bytes: bytes) -> Optional[np.ndarray]:
        """Embed binary data"""
        try:
            hex_str = code_bytes.hex()
            formatted_hex = ' '.join([hex_str[i:i+2] for i in range(0, len(hex_str), 2)])
            return self.embed_text(formatted_hex)
        except Exception as e:
            logger.error(f"Binary embedding failed: {e}")
            return None
    
    def embed_file(self, filepath: Path) -> Optional[np.ndarray]:
        """Embed file with automatic format detection"""
        try:
            with open(filepath, "rb") as f:
                content = f.read()
            
            # Try to decode as text first
            try:
                text_content = content.decode('utf-8')
                return self.embed_text(text_content)
            except UnicodeDecodeError:
                # Binary file - use binary embedding
                return self.embed_binary_string(content)
                
        except Exception as e:
            logger.error(f"Failed to embed {filepath}: {e}")
            return None

class BinaryAnalyzer:
    """Specialized binary file analyzer"""
    
    def __init__(self, embedder: MultiModelEmbedder):
        self.embedder = embedder
    
    def extract_strings(self, binary_path: Path) -> str:
        """Extract strings from binary"""
        try:
            result = subprocess.run(['strings', str(binary_path)], 
                                   capture_output=True, text=True, timeout=30)
            return result.stdout
        except Exception as e:
            logger.warning(f"Failed to extract strings: {e}")
            return ""
    
    def extract_hex_dump(self, binary_path: Path, max_bytes: int = 4096) -> str:
        """Extract hex dump"""
        try:
            with open(binary_path, 'rb') as f:
                data = f.read(max_bytes)
            return data.hex()
        except Exception as e:
            logger.warning(f"Failed to extract hex dump: {e}")
            return ""
    
    def extract_disassembly(self, binary_path: Path) -> str:
        """Extract disassembly"""
        try:
            result = subprocess.run(['objdump', '-d', str(binary_path)], 
                                   capture_output=True, text=True, timeout=60)
            return result.stdout
        except Exception as e:
            logger.warning(f"Failed to extract disassembly: {e}")
            return ""
    
    def analyze_binary(self, binary_path: Path) -> Optional[np.ndarray]:
        """Comprehensive binary analysis"""
        try:
            strings = self.extract_strings(binary_path)
            hex_dump = self.extract_hex_dump(binary_path)
            disassembly = self.extract_disassembly(binary_path)
            
            # Combine all representations
            combined_text = f"STRINGS:\n{strings}\n\nHEX:\n{hex_dump}\n\nASM:\n{disassembly}"
            
            return self.embedder.embed_text(combined_text)
            
        except Exception as e:
            logger.error(f"Binary analysis failed for {binary_path}: {e}")
            return None

class VulnerabilityClassifier(nn.Module):
    """Neural network classifier"""
    
    def __init__(self, input_dim: int, hidden_dims: List[int] = [512, 256, 128], dropout: float = 0.3):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, 1))
        self.network = nn.Sequential(*layers)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.network(x))

def load_classifier(classifier_path: Path = CLASSIFIER_PATH) -> Optional[Union[object, VulnerabilityClassifier]]:
    """Load trained classifier (traditional ML or neural network)"""
    if classifier_path.exists():
        try:
            clf = joblib.load(classifier_path)
            logger.info(f"Loaded traditional classifier from {classifier_path}")
            return clf
        except Exception as e:
            logger.error(f"Failed to load traditional classifier: {e}")
    
    # Try loading neural network
    neural_path = NEURAL_MODEL_PATH
    if neural_path.exists():
        try:
            # Load model architecture and weights
            # This would need to be implemented based on saved model structure
            logger.info(f"Loaded neural classifier from {neural_path}")
            return None  # Placeholder
        except Exception as e:
            logger.error(f"Failed to load neural classifier: {e}")
    
    logger.warning(f"No trained classifier found")
    return None

def run_model_inference(
    target_paths: List[Path],
    output_dir: Optional[Path] = None,
    save_vectors: bool = True,
    chroma_store: Optional[ChromaStore] = None,
    model_name: str = DEFAULT_MODEL_NAME,
    pooling_strategy: str = 'mean',
    batch_size: int = 50,
    show_progress: bool = True
) -> List[Dict]:
    """
    Run ML model inference with progress monitoring and checkpointing
    """
    embedder = MultiModelEmbedder(model_name)
    classifier = load_classifier()
    binary_analyzer = BinaryAnalyzer(embedder)
    results = []
    
    # Process files with progress bar
    iterator = tqdm(target_paths, desc="Processing files") if show_progress else target_paths
    
    for path in iterator:
        if not path.is_file():
            logger.warning(f"Skipping non-file target: {path}")
            continue
        
        try:
            # Check if file is already processed
            file_hash = hashlib.sha256(path.read_bytes()).hexdigest()
            
            if chroma_store and chroma_store.exists(file_hash):
                logger.info(f"Skipping already processed file: {path}")
                continue
            
            # Analyze binary
            vector = binary_analyzer.analyze_binary(path)
            
            if vector is None:
                logger.warning(f"Failed to analyze {path}")
                continue
            
            # Save to Chroma store
            if chroma_store:
                chroma_store.add(file_hash, vector)
            
            # Run classifier (prediction stays None when no classifier is loaded,
            # so `result` is always defined before it is appended below)
            prediction = classifier.predict([vector])[0] if classifier else None
            result = {
                'path': str(path),
                'hash': file_hash,
                'prediction': prediction,
                'vector': vector.tolist()
            }
            
            # Save to output directory
            if output_dir:
                output_path = output_dir / path.name
                output_path.write_bytes(path.read_bytes())
            
            results.append(result)
            
        except Exception as e:
            logger.error(f"Error processing {path}: {e}")
            continue
    
    return results
```
''')

## 9. Enhanced Model Inference Implementation

In [None]:
print("""
# Enhanced model_inference.py Implementation

```python
import os
import sys
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Union
import hashlib
import time
import logging
import joblib
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import subprocess
import tempfile
from tqdm import tqdm
from bixie.vector_store.chroma_store import ChromaStore

# Constants
DEFAULT_MODEL_NAME = "microsoft/codebert-base"
CLASSIFIER_PATH = Path("bixie.ai/fine_tuned_model.pkl")
NEURAL_MODEL_PATH = Path("bixie.ai/neural_classifier.pth")

logger = logging.getLogger("bixie.model_inference")

class MultiModelEmbedder:
    """Enhanced embedder supporting multiple models and binary analysis"""
    
    def __init__(self, model_name: str = DEFAULT_MODEL_NAME, device=None):
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        
        logger.info(f"Loading model: {model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        
        self.max_length = 512
        self.embedding_dim = self.model.config.hidden_size
        logger.info(f"Model loaded successfully. Embedding dimension: {self.embedding_dim}")
    
    def embed_text(self, text: str, pooling_strategy: str = 'mean') -> Optional[np.ndarray]:
        """Embed text using different pooling strategies"""
        try:
            inputs = self.tokenizer(
                text, 
                truncation=True, 
                padding=True, 
                max_length=self.max_length,
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                hidden_states = outputs.last_hidden_state
                
                if pooling_strategy == 'mean':
                    embedding = hidden_states.mean(dim=1)
                elif pooling_strategy == 'cls':
                    embedding = hidden_states[:, 0, :]
                elif pooling_strategy == 'max':
                    embedding = hidden_states.max(dim=1)[0]
                elif pooling_strategy == 'attention':
                    attention_weights = torch.softmax(hidden_states.mean(dim=-1), dim=-1)
                    embedding = (hidden_states * attention_weights.unsqueeze(-1)).sum(dim=1)
                
                return embedding.squeeze().cpu().numpy()
                
        except Exception as e:
            logger.error(f"Text embedding failed: {e}")
            return None
    
    def embed_binary_string(self, code_bytes: bytes) -> Optional[np.ndarray]:
        """Embed binary data"""
        try:
            hex_str = code_bytes.hex()
            formatted_hex = ' '.join([hex_str[i:i+2] for i in range(0, len(hex_str), 2)])
            return self.embed_text(formatted_hex)
        except Exception as e:
            logger.error(f"Binary embedding failed: {e}")
            return None
    
    def embed_file(self, filepath: Path) -> Optional[np.ndarray]:
        """Embed file with automatic format detection"""
        try:
            with open(filepath, "rb") as f:
                content = f.read()
            
            # Try to decode as text first
            try:
                text_content = content.decode('utf-8')
                return self.embed_text(text_content)
            except UnicodeDecodeError:
                # Binary file - use binary embedding
                return self.embed_binary_string(content)
                
        except Exception as e:
            logger.error(f"Failed to embed {filepath}: {e}")
            return None

class BinaryAnalyzer:
    """Build text views of a binary file and embed them.

    Three views are produced: the output of the ``strings`` tool, a hex dump
    of the leading bytes, and the output of ``objdump -d``. They are
    concatenated into one labelled text and passed to the wrapped embedder's
    ``embed_text``.
    """

    def __init__(self, embedder: MultiModelEmbedder):
        # Embedder whose ``embed_text`` is called by ``analyze_binary``.
        self.embedder = embedder

    def extract_strings(self, binary_path: Path) -> str:
        """Return the stdout of ``strings <binary_path>``.

        Returns "" (and logs a warning) if the tool cannot be run, times out
        after 30 seconds, or any other exception occurs. The tool's exit code
        is not inspected.
        """
        try:
            result = subprocess.run(
                ['strings', str(binary_path)],
                capture_output=True,
                text=True,
                # Output derived from an arbitrary binary is not guaranteed to
                # be valid in the locale encoding. Substitute undecodable
                # bytes instead of raising and discarding the whole view.
                errors='replace',
                timeout=30,
            )
            return result.stdout
        except Exception as e:
            logger.warning(f"Failed to extract strings: {e}")
            return ""

    def extract_hex_dump(self, binary_path: Path, max_bytes: int = 4096) -> str:
        """Return the first ``max_bytes`` bytes of the file as a hex string.

        The result is one unbroken run of lowercase hex digits (two per
        byte). Returns "" (and logs a warning) if the file cannot be read.
        """
        try:
            with open(binary_path, 'rb') as f:
                data = f.read(max_bytes)
            return data.hex()
        except Exception as e:
            logger.warning(f"Failed to extract hex dump: {e}")
            return ""

    def extract_disassembly(self, binary_path: Path) -> str:
        """Return the stdout of ``objdump -d <binary_path>``.

        Returns "" (and logs a warning) if the tool cannot be run, times out
        after 60 seconds, or any other exception occurs. The tool's exit code
        is not inspected.
        """
        try:
            result = subprocess.run(
                ['objdump', '-d', str(binary_path)],
                capture_output=True,
                text=True,
                # Same reasoning as in extract_strings: names printed by the
                # tool come from the binary and may contain undecodable bytes.
                errors='replace',
                timeout=60,
            )
            return result.stdout
        except Exception as e:
            logger.warning(f"Failed to extract disassembly: {e}")
            return ""

    def analyze_binary(self, binary_path: Path) -> Optional[np.ndarray]:
        """Embed the combined strings / hex / disassembly text of a binary.

        Each extractor yields "" on failure, so the combined text always
        contains the three section labels even when a section is empty.

        Returns:
            The value returned by ``embedder.embed_text`` for the combined
            text, or ``None`` if an exception was raised (logged as an error).
        """
        try:
            strings = self.extract_strings(binary_path)
            hex_dump = self.extract_hex_dump(binary_path)
            disassembly = self.extract_disassembly(binary_path)

            # Combine all representations under fixed section labels.
            combined_text = f"STRINGS:\n{strings}\n\nHEX:\n{hex_dump}\n\nASM:\n{disassembly}"

            return self.embedder.embed_text(combined_text)

        except Exception as e:
            logger.error(f"Binary analysis failed for {binary_path}: {e}")
            return None

class VulnerabilityClassifier(nn.Module):
    """Feed-forward binary classifier over embedding vectors.

    Every hidden layer is the sequence Linear -> BatchNorm1d -> ReLU ->
    Dropout. A final Linear layer maps to a single output, and ``forward``
    applies a sigmoid to it.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dims: Widths of the hidden layers, in order. ``None`` (the
            default) selects ``[512, 256, 128]``. An empty list yields a
            single Linear layer from ``input_dim`` to 1.
        dropout: Dropout probability used after every hidden layer.
    """

    def __init__(self, input_dim: int, hidden_dims: Optional[List[int]] = None, dropout: float = 0.3):
        super().__init__()

        # A None sentinel replaces the former mutable list default so the
        # default value can never be shared between (or mutated across) calls.
        if hidden_dims is None:
            hidden_dims = [512, 256, 128]

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, 1))
        # The attribute name and layer order define the state_dict keys
        # ("network.0.weight", ...); both are kept as they were.
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of the network output for ``x``."""
        return torch.sigmoid(self.network(x))

def load_classifier(classifier_path: Path = CLASSIFIER_PATH) -> Optional[Union[object, VulnerabilityClassifier]]:
    """Load a trained classifier from disk, if one is available.

    The joblib-serialised classifier at ``classifier_path`` is tried first.
    Loading a neural checkpoint from ``NEURAL_MODEL_PATH`` is not implemented:
    when that file exists a warning is logged and nothing is loaded from it.

    Args:
        classifier_path: Location of the joblib file (``Path`` or str).

    Returns:
        The object returned by ``joblib.load``, or ``None`` when no
        classifier could be loaded. Load errors are logged, not raised.
    """
    classifier_path = Path(classifier_path)
    if classifier_path.exists():
        try:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only point this at files from a trusted source.
            clf = joblib.load(classifier_path)
            logger.info(f"Loaded traditional classifier from {classifier_path}")
            return clf
        except Exception as e:
            # Fall through so the neural checkpoint is still considered.
            logger.error(f"Failed to load traditional classifier: {e}")

    # Neural checkpoint: restoring it requires knowledge of the saved model
    # structure, which is not implemented here. Previously this branch logged
    # "Loaded neural classifier" while returning None; report the truth.
    neural_path = NEURAL_MODEL_PATH
    if neural_path.exists():
        logger.warning(
            f"Neural classifier found at {neural_path} but loading it is not implemented; ignoring it"
        )
        return None

    logger.warning("No trained classifier found")
    return None

def run_model_inference(
    target_paths: List[Path],
    output_dir: Optional[Path] = None,
    save_vectors: bool = True,
    chroma_store: Optional[ChromaStore] = None,
    model_name: str = DEFAULT_MODEL_NAME,
    pooling_strategy: str = 'mean',
    batch_size: int = 50,
    show_progress: bool = True
) -> List[Dict]:
    """
    Embed each target file and, when a classifier is available, classify it.

    For every path that is a regular file, the file's SHA-256 is computed,
    the file is analysed with ``BinaryAnalyzer.analyze_binary`` and a result
    record is appended. Per-file failures are logged and skipped; they never
    abort the run.

    Args:
        target_paths: Files to process; non-file entries are skipped.
        output_dir: If given, each processed input file is copied into this
            directory (created on demand) under its original name.
        save_vectors: Accepted for interface compatibility; not used by this
            function at present.
        chroma_store: Optional store. Files whose hash it already contains
            are skipped; new vectors are added to it under the file hash.
        model_name: Model name passed to ``MultiModelEmbedder``.
        pooling_strategy: Accepted for interface compatibility; not used by
            this function at present.
        batch_size: Accepted for interface compatibility; not used by this
            function at present.
        show_progress: Wrap the iteration in a tqdm progress bar.

    Returns:
        One dict per successfully processed file with keys ``'path'``,
        ``'hash'``, ``'prediction'`` and ``'vector'``. ``'prediction'`` is
        ``None`` when no classifier could be loaded.
    """
    embedder = MultiModelEmbedder(model_name)
    classifier = load_classifier()
    binary_analyzer = BinaryAnalyzer(embedder)
    results = []

    # Process files with progress bar
    iterator = tqdm(target_paths, desc="Processing files") if show_progress else target_paths

    for path in iterator:
        if not path.is_file():
            logger.warning(f"Skipping non-file target: {path}")
            continue

        try:
            # Read once; the bytes serve both the hash and the optional copy.
            file_bytes = path.read_bytes()
            file_hash = hashlib.sha256(file_bytes).hexdigest()

            # Identity checks: a store or estimator that defines __len__ can
            # be falsy while perfectly usable.
            if chroma_store is not None and chroma_store.exists(file_hash):
                logger.info(f"Skipping already processed file: {path}")
                continue

            # Analyze binary
            vector = binary_analyzer.analyze_binary(path)

            if vector is None:
                logger.warning(f"Failed to analyze {path}")
                continue

            # Save to Chroma store
            if chroma_store is not None:
                chroma_store.add(file_hash, vector)

            # Run classifier. The record is built whether or not a classifier
            # exists; previously `result` was unbound without one, so every
            # file ended in a swallowed UnboundLocalError and no results.
            prediction = None
            if classifier is not None:
                prediction = classifier.predict([vector])[0]
                # Convert NumPy scalars/arrays to plain Python values, in
                # line with the list form used for the vector below.
                if hasattr(prediction, 'tolist'):
                    prediction = prediction.tolist()

            result = {
                'path': str(path),
                'hash': file_hash,
                'prediction': prediction,
                'vector': vector.tolist()
            }

            # Save to output directory
            if output_dir:
                output_dir.mkdir(parents=True, exist_ok=True)
                output_path = output_dir / path.name
                output_path.write_bytes(file_bytes)

            results.append(result)

        except Exception as e:
            logger.error(f"Error processing {path}: {e}")
            continue

    return results
```
""")