In [143]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax.training import train_state

import optax

SEED = 0
rng = jax.random.PRNGKey(SEED)

In [144]:
class ByteClassifier(nn.Module):
    num_bytes: int
    embedding_dim: int

    @nn.compact
    def __call__(self, inputs, train=True):
        # Input: (batch_size, num_bytes) of raw byte values (0-255)
        x = nn.Embed(256, self.embedding_dim)(inputs.astype(jnp.int32))  # Ensure integers

        # Temporal convolution to capture local patterns
        x = nn.Conv(features=64, kernel_size=(5,), padding='SAME')(x)
        x = nn.relu(x)

        # Attention layer to focus on important positions
        attn = nn.SelfAttention(num_heads=4)(x)
        x = jnp.concatenate([x, attn], axis=-1)

        # Final dense layers with dropout
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.3, deterministic=not train)(x, rng=rng)

        # Output layer
        x = nn.Dense(1)(x)
        return nn.sigmoid(x).squeeze(-1)

In [145]:
# Initialize model
def create_model():
    return ByteClassifier(
        num_bytes=256,  # Example: analyze 256-byte chunks
        embedding_dim=16
    )

# Create initial state
def initialize_model(key, input_shape=(256,)):
    model = create_model()
    dummy_input = jnp.zeros((1, *input_shape), dtype=jnp.int32)  # Ensure integers
    params = model.init(key, dummy_input)['params']
    return model, params

def create_optimizer(learning_rate=1e-3):
    return optax.adamw(
        learning_rate=learning_rate,
        b1=0.9,
        b2=0.999,
        weight_decay=1e-5
    )

In [158]:
# Initialize
model, params = initialize_model(rng)
optimizer = create_optimizer(1e-3)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer
)

def loss_fn(params, inputs, labels, overestimate_weight=1.0, underestimate_weight=1.0):
    # Forward pass: Calculate logits (predictions)
    logits = model.apply({'params': params}, inputs)
    print(logits.shape)
    
    # Calculate error: (prediction - true label)
    error = logits - labels
    
    # Apply different weights for overestimating and underestimating errors
    overestimating_error = (error > 0)  # True when overestimated (logits > labels)
    underestimating_error = (error < 0)  # True when underestimated (logits < labels)
    
    # Compute MSE for each case
    mse_overestimating = jnp.where(overestimating_error, (error ** 2) * overestimate_weight, 0)
    mse_underestimating = jnp.where(underestimating_error, (error ** 2) * underestimate_weight, 0)
    
    # Total loss: sum of weighted MSE errors for both types
    total_loss = jnp.mean(mse_overestimating + mse_underestimating)
    
    return total_loss

In [159]:
@jax.jit
def train_step(state, batch):
    inputs, labels = batch
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, inputs, labels)
    state = state.apply_gradients(grads=grads)
    return state, loss

def sliding_window_data(sequence, window_size, step_size=1):
    inputs, labels = [], []
    for i in range(0, len(sequence) - window_size + 1, step_size):
        input_window = sequence[i:i + window_size]
        label_window = sequence[i:i + window_size] 
        inputs.append(input_window)
        labels.append(label_window)
    
    return inputs, labels

def train_model(state, sequence, window_size=10, step_size=1, num_epochs=10):
    # Generate sliding windows for inputs and labels
    inputs, labels = sliding_window_data(sequence, window_size, step_size)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_inputs, batch_labels in zip(inputs, labels):
            batch = (batch_inputs, batch_labels)  # Prepare batch
            state, loss = train_step(state, batch)
            epoch_loss += loss
        
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss/len(inputs):.4f}")
    
    return state

In [160]:
import json

"""
    {
        script_name: {
            "bytes": [[252], [72, ... 240] ... [104, 126, 162, 208, 83]],
            "data_addresses": [48, 50, 52, ... 526, 528]
        }

        ...
    }
    """

def load_data(path):
    with open(path, 'r') as f:
        data = json.load(f)
        
    for script_name, script_data in data.items():
        try:
            bytes = script_data['bytes']
            flattened_bytes = [b for instruction in bytes for b in instruction]
            
            data_addresses = script_data['data_addresses']
            
            x = jnp.array(flattened_bytes)
            
            y = np.zeros_like(x)
            y[data_addresses] = 1
            y = jnp.array(y)
        except:
            continue
        
        yield script_name, x, y

In [162]:
# Load data
data_path = 'payloads_dict.json'
exec_data = list(load_data(data_path))


In [163]:
data = (exec_data[0][1], exec_data[0][2])
print(data[0].shape, data[1].shape)

(571,) (571,)


In [165]:
# Train model
state = train_model(state, data, window_size=256, step_size=1, num_epochs=10)

ZeroDivisionError: float division by zero