In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
from gym import spaces
import random

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import ActorCriticPolicy


from transformers import AutoTokenizer, AutoModel

import json


In [2]:
class SmartContractVulnEnv(gym.Env):
    """
    Single-step classification environment over smart-contract code snippets.

    Each episode:
      - reset() picks a random sample and returns its token IDs (observation)
      - the agent chooses a vulnerability class index (action)
      - step() returns reward +1 if the class matches the sample's label,
        -1 otherwise, and ends the episode

    Args:
        dataset: sized, indexable collection of dicts with keys 'code' and 'label'.
        tokenizer: callable tokenizer; it is invoked with padding='max_length',
            truncation=True, max_length=..., return_tensors='np' and must return
            a mapping containing 'input_ids'.
        max_length: fixed number of token IDs per observation.

    Raises:
        ValueError: if `dataset` is empty.
    """
    def __init__(self, dataset, tokenizer, max_length=128):
        super(SmartContractVulnEnv, self).__init__()

        # Fail early with a clear message; otherwise the first reset() would
        # die inside random.randint with an "empty range" error.
        if len(dataset) == 0:
            raise ValueError("dataset must contain at least one sample")

        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Sorted so that the action index -> label mapping is deterministic.
        self.vuln_classes = sorted(set(item['label'] for item in dataset))
        self.num_actions = len(self.vuln_classes)

        # Observations are raw token IDs, so the upper bound has to cover the
        # tokenizer vocabulary (a bound of 1 would make nearly every
        # observation fall outside the declared space).
        self.observation_space = spaces.Box(
            low=0,
            high=self._max_token_id(tokenizer),
            shape=(self.max_length,),
            dtype=np.int32,
        )

        # Action space is the index of the predicted vulnerability
        self.action_space = spaces.Discrete(self.num_actions)

        self.current_index = 0
        self.current_text = None
        self.current_label = None
        self.episode_over = False

    @staticmethod
    def _max_token_id(tokenizer):
        """
        Best-effort largest token ID the tokenizer can emit.

        Prefers len(tokenizer), then a `vocab_size` attribute; falls back to
        the int32 maximum when neither is available.
        """
        try:
            vocab_size = len(tokenizer)
        except TypeError:
            vocab_size = getattr(tokenizer, "vocab_size", None)
        if not vocab_size:
            return int(np.iinfo(np.int32).max)
        return int(vocab_size) - 1

    def reset(self):
        """
        Start a new episode on a uniformly random sample from the dataset.

        Returns:
            np.ndarray of shape (max_length,), dtype int32, with the token IDs
            of the selected snippet.
        """
        self.current_index = random.randint(0, len(self.dataset) - 1)
        sample = self.dataset[self.current_index]

        self.current_text = sample['code']   # raw code snippet
        self.current_label = sample['label'] # e.g. 'reentrancy'
        self.episode_over = False

        return self._tokenize_text(self.current_text)

    def step(self, action):
        """
        Score the predicted vulnerability class for the current sample.

        Args:
            action: class index; a Python int or any size-1 array-like.

        Returns:
            (next_obs, reward, done, info) where next_obs is an all-zero
            placeholder (the episode always ends after one step), reward is
            +1.0 / -1.0, done is True, and info carries 'correct_label' and
            'predicted_label'.
        """
        # Normalise the action so 0-d and shape-(1,) arrays index the list too.
        action_index = int(np.asarray(action).item())
        predicted_label = self.vuln_classes[action_index]

        reward = 1.0 if predicted_label == self.current_label else -1.0

        # In this simple design, we end the episode after one classification
        self.episode_over = True
        done = True

        # For Gym, info can carry debugging info
        info = {'correct_label': self.current_label,
                'predicted_label': predicted_label}

        # Single-step episode: the "next" observation is never used for a decision.
        next_obs = np.zeros(self.max_length, dtype=np.int32)

        return next_obs, reward, done, info

    def _tokenize_text(self, text):
        """
        Convert text to a fixed-size array of token IDs.

        The text is padded / truncated to `max_length`; the result has shape
        (max_length,) and dtype int32, matching `observation_space`.
        """
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='np'
        )
        # The tokenizer returns a batch of one; drop the batch axis and cast
        # to the dtype declared by the observation space.
        input_ids = np.asarray(encoded['input_ids']).squeeze(0)
        return input_ids.astype(np.int32)

    def render(self, mode='human'):
        """No-op: this environment has nothing to render."""
        pass

In [3]:
class CodeBERTFeatureExtractor(BaseFeaturesExtractor):
    """
    Features extractor that embeds token-ID observations with a CodeBERT-style
    encoder and hands the first-token ([CLS]) embedding to the PPO
    actor-critic heads.

    Args:
        observation_space: observation space forwarded to BaseFeaturesExtractor.
        model_name_or_path: checkpoint name or local folder passed to
            AutoModel.from_pretrained.
        device: device the encoder is initially loaded onto.
    """
    def __init__(self, observation_space, model_name_or_path="fine_tuned_model", device="cuda"):
        # The parent needs a feature size before the encoder is loaded; 768 is a
        # placeholder that is replaced below with the checkpoint's real width.
        super(CodeBERTFeatureExtractor, self).__init__(observation_space, features_dim=768)

        self.device = device

        # Load the CodeBERT model (this could be your fine-tuned model folder)
        self.codebert_model = AutoModel.from_pretrained(model_name_or_path).to(self.device)

        config = self.codebert_model.config

        # Use the checkpoint's hidden size instead of assuming 768.
        self._features_dim = int(getattr(config, "hidden_size", 768))

        # Padding id used to build the attention mask. Read it from the
        # checkpoint config rather than assuming 0: RoBERTa-style vocabularies
        # (which CodeBERT uses) reserve id 0 for <s>/[CLS] and pad with id 1,
        # so masking on 0 would hide the CLS token and attend to padding.
        # Falls back to 0 when the config does not define a pad id.
        pad_token_id = getattr(config, "pad_token_id", None)
        self.pad_token_id = 0 if pad_token_id is None else int(pad_token_id)

    def forward(self, observations):
        """
        Embed a batch of token-ID observations.

        Args:
            observations: Tensor of shape (batch_size, max_length) with token IDs
                (any numeric dtype; cast to long here).

        Returns:
            Tensor of shape (batch_size, hidden_dim): the first-token embedding.
        """
        # Follow the encoder's actual device so this keeps working if the
        # enclosing policy is moved with .to(...) after construction.
        model_device = next(self.codebert_model.parameters()).device
        input_ids = observations.long().to(model_device)

        # 1 for real tokens, 0 for padding.
        attention_mask = (input_ids != self.pad_token_id).long()

        outputs = self.codebert_model(input_ids=input_ids, attention_mask=attention_mask)

        # outputs.last_hidden_state: [batch_size, seq_len, hidden_dim];
        # take the first token (CLS) as a simple sequence representation.
        cls_embedding = outputs.last_hidden_state[:, 0, :]

        return cls_embedding

In [4]:
class CodeBERTPolicy(ActorCriticPolicy):
    """
    Actor-critic policy whose features extractor defaults to
    CodeBERTFeatureExtractor.

    The extractor class and its kwargs are only *defaults*: values supplied by
    the caller (e.g. through `policy_kwargs`) win. Passing them unconditionally
    alongside **kwargs would raise "got multiple values for keyword argument"
    whenever the caller also provides them.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault("features_extractor_class", CodeBERTFeatureExtractor)
        # If your fine-tuned checkpoint is in a local folder, specify it here
        # or override via policy_kwargs={"features_extractor_kwargs": {...}}.
        kwargs.setdefault(
            "features_extractor_kwargs",
            {"model_name_or_path": "fine_tuned_model", "device": "cuda"},
        )
        super(CodeBERTPolicy, self).__init__(*args, **kwargs)

In [5]:
def load_data(json_file, exclude_labels=("tx.origin Authentication",)):
    """
    Load samples from a JSON file containing a list of records.

    Each record must provide the keys "code_snippet" and "vulnerability_type".
    Records whose vulnerability type is in `exclude_labels` are skipped.

    Args:
        json_file: path to the UTF-8 encoded JSON file.
        exclude_labels: a label or iterable of labels to drop. Defaults to
            dropping "tx.origin Authentication"; pass () or None to keep
            every record.

    Returns:
        list of dicts of the form {"code": <snippet>, "label": <type string>},
        in file order. Labels are kept as strings (no integer mapping is done).

    Raises:
        KeyError: if a record lacks one of the required keys.
    """
    # A bare string would otherwise be iterated character by character.
    if exclude_labels is None:
        excluded = set()
    elif isinstance(exclude_labels, str):
        excluded = {exclude_labels}
    else:
        excluded = set(exclude_labels)

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    samples = []
    for item in data:
        code = item["code_snippet"]
        label = item["vulnerability_type"]
        if label in excluded:
            continue
        samples.append({"code": code, "label": label})
    return samples

In [None]:
# Train a PPO agent on the vulnerability-classification environment.
json_file = "train_data.json"
dummy_data = load_data(json_file)

# Tokenizer comes from the same checkpoint folder the policy's extractor loads.
tokenizer = AutoTokenizer.from_pretrained("fine_tuned_model")

env = SmartContractVulnEnv(dataset=dummy_data, tokenizer=tokenizer)
vec_env = DummyVecEnv([lambda: env])

model = PPO(
        policy=CodeBERTPolicy,
        env=vec_env,
        verbose=1,
        device='cuda',
        # You might want to tune learning_rate, batch_size, etc.
        learning_rate=1e-5,
        n_steps=8,         # small for illustration
        batch_size=4,      # small for illustration
        ent_coef=0.01,
    )

model.learn(total_timesteps=40000)

# Keep the path in one variable so the log line always names the file that
# was actually written (the message previously reported a different name).
model_save_path = "ppo_trained.zip"
model.save(model_save_path)
print(f"RL model saved to {model_save_path}")

In [None]:
# Path of the evaluation data file; it is read with load_data() at the end of this cell.
test_data_file = "data_test.json"
from sklearn.metrics import accuracy_score, classification_report

def evaluate_ppo_model(model, env, n_eval_episodes=100):
    """
    Play `n_eval_episodes` one-step episodes and print classification metrics.

    For every episode the environment is reset, the model picks an action
    deterministically, and the predicted / correct labels are read from the
    `info` dict returned by `env.step`. Accuracy and a classification report
    are printed; nothing is returned.

    NOTE(review): each episode starts with env.reset(); with
    SmartContractVulnEnv that draws a sample at random, so some samples may
    be evaluated more than once and others not at all.
    """
    def _play_one_episode():
        # One reset + one step; only the info dict of the step is of interest.
        observation = env.reset()
        chosen_action, _ = model.predict(observation, deterministic=True)
        _, _, _, step_info = env.step(chosen_action)
        return step_info["predicted_label"], step_info["correct_label"]

    outcomes = [_play_one_episode() for _ in range(n_eval_episodes)]
    predicted_labels = [predicted for predicted, _ in outcomes]
    true_labels = [truth for _, truth in outcomes]

    accuracy = accuracy_score(true_labels, predicted_labels)
    print(f"Accuracy: {accuracy:.4f}")
    print("Classification Report:")
    print(classification_report(true_labels, predicted_labels))

# Evaluate the trained model on the evaluation file, using a separate
# environment built over those samples and one episode per loaded sample.
# NOTE(review): reset() picks samples with random.randint, so this many
# episodes does not guarantee that every test sample is visited exactly once.
test_data = load_data(test_data_file)
test_env = SmartContractVulnEnv(test_data, tokenizer)
n_test_episodes = len(test_data)
evaluate_ppo_model(model, test_env, n_eval_episodes=n_test_episodes)