# SGuard Suffix Attack Training

This notebook runs the SGuard Suffix Attack training and saves results to Google Drive.

## 1. Mount Google Drive

In [None]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem; the paths used by later cells
# (model cache, training results) live under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## 2. Setup Environment
Clone the repository and install dependencies.

In [None]:
# Clone the repository
# TODO: Replace with your actual repository URL
!rm -rf /content/sguard_attack
!git clone https://github.com/2CHAN0/SGuardSuffix.git sguard_attack

# Install dependencies
!pip install -r sguard_attack/requirements.txt

## 3. Model Caching (Google Drive)
Check if the model exists in Google Drive. If not, download and save it.

In [None]:
import gc
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys

# Add the package to path (guarded so re-running the cell does not add duplicates)
if '/content' not in sys.path:
    sys.path.append('/content')
from sguard_attack.config import Config

# Define Google Drive Model Path
# Using the project name as the base folder
drive_model_dir = "/content/drive/MyDrive/SGuardSuffix/models/SamsungSDS-Research/SGuard-JailbreakFilter-2B-v1"

# The tokenizer is saved before the model, so a run interrupted in between
# leaves a non-empty directory that holds no model. Requiring config.json
# (written by model.save_pretrained) rejects such a half-populated cache.
cache_is_usable = os.path.isfile(os.path.join(drive_model_dir, "config.json"))

if cache_is_usable:
    print(f"Model found in Google Drive: {drive_model_dir}")
else:
    print(f"Model not found in Drive. Downloading from Hugging Face: {Config.MODEL_ID}")

    # Download and Save
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        Config.MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True
    )

    print(f"Saving model to Google Drive: {drive_model_dir}")
    tokenizer.save_pretrained(drive_model_dir)
    model.save_pretrained(drive_model_dir)
    print("Model saved successfully!")

    # These objects were only needed to populate the cache. Dropping them keeps
    # both branches equivalent (only `model_path` is exported) and frees the
    # weights before the training cell loads the model again from disk.
    del model, tokenizer
    gc.collect()

# Either way, later cells load the model from the Drive copy.
model_path = drive_model_dir

## 4. Run Training
Initialize the model and run the attack, saving checkpoints to Google Drive.

In [None]:
import sys
import os
import datetime
import torch

from sguard_attack.config import Config
from sguard_attack.model_utils import load_model_and_tokenizer
from sguard_attack.dataset import Dataset
from sguard_attack.attack import GCGAttack

import json


def _write_summary(records, path):
    """Serialize the per-prompt result records to `path` as indented JSON."""
    with open(path, 'w') as f:
        json.dump(records, f, indent=4)


# Setup Save Directory
# Create a timestamped directory in Google Drive
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
save_base_dir = f"/content/drive/MyDrive/SGuard_Training_Results/{timestamp}"
os.makedirs(save_base_dir, exist_ok=True)
summary_path = os.path.join(save_base_dir, "final_summary.json")
print(f"Saving results to: {save_base_dir}")

# Load Model & Tokenizer from the cached path
model, tokenizer = load_model_and_tokenizer(model_name_or_path=model_path)

# Load Dataset
dataset = Dataset()

# Initialize Attack
attacker = GCGAttack(model, tokenizer)

print(f"Starting attack on {len(dataset)} prompts...")

results = []

for i, malicious_prompt in enumerate(dataset):
    print(f"\n[{i+1}/{len(dataset)}] Attacking prompt: {malicious_prompt}")

    # Run attack with saving enabled
    # save_interval=10 is passed through to attacker.run together with save_dir
    best_suffix, best_suffix_ids, best_loss = attacker.run(
        malicious_prompt,
        save_dir=save_base_dir,
        save_interval=10
    )
    # Normalise to a plain float so the record is JSON-serialisable even if the
    # attack returns a 0-dim tensor or a numpy scalar.
    best_loss = float(best_loss)

    print(f"Result for '{malicious_prompt}':")
    print(f"  Best Suffix: {best_suffix}")
    print(f"  Final Loss: {best_loss:.4f}")

    # Verification (String): re-tokenize the prompt + decoded suffix text
    full_input = malicious_prompt + " " + best_suffix
    inputs = tokenizer(full_input, return_tensors="pt").to(Config.DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=10)
    # Decode only the newly generated tokens so this output is directly
    # comparable with the tensor-based verification below.
    string_prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][string_prompt_len:], skip_special_tokens=True)

    print(f"  Model Output (String): {response}")

    # Verification (Tensor): feed the optimised suffix token ids directly
    prompt_ids = tokenizer(malicious_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(Config.DEVICE)
    # torch.cat needs both operands on the same device and with a batch dimension.
    best_suffix_ids = best_suffix_ids.to(prompt_ids.device)
    if best_suffix_ids.dim() == 1:
        best_suffix_ids = best_suffix_ids.unsqueeze(0)

    full_input_ids = torch.cat([prompt_ids, best_suffix_ids], dim=1)
    with torch.no_grad():
        # A single unpadded sequence: attend to every token explicitly instead
        # of letting generate() guess the mask from pad-token occurrences.
        outputs_tensor = model.generate(
            full_input_ids,
            attention_mask=torch.ones_like(full_input_ids),
            max_new_tokens=10
        )
    new_tokens = outputs_tensor[0][full_input_ids.shape[1]:]
    response_tensor = tokenizer.decode(new_tokens, skip_special_tokens=True)

    print(f"  Model Output (Tensor): {response_tensor}")

    results.append({
        "prompt": malicious_prompt,
        "suffix": best_suffix,
        "loss": best_loss,
        "output_string": response,
        "output_tensor": response_tensor
    })

    # Persist after every prompt so a runtime disconnect or crash does not
    # discard the results gathered so far.
    _write_summary(results, summary_path)

# Save Final Summary (also covers the empty-dataset case)
_write_summary(results, summary_path)
print(f"\nFinal summary saved to {summary_path}")

## 5. Official Inference Verification
Verify the attack results using the official SGuard inference logic.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Verdict labels; the same strings are looked up as vocabulary tokens below.
SAFE_TOKEN, UNSAFE_TOKEN = 'safe', 'unsafe'

# The model and tokenizer loaded by the training cell are reused here rather
# than being loaded a second time.

vc = tokenizer.get_vocab()
safe_token_id, unsafe_token_id = (vc.get(label) for label in (SAFE_TOKEN, UNSAFE_TOKEN))

if None in (safe_token_id, unsafe_token_id):
    # Only a warning is emitted here; a fallback or manual check may be needed
    # depending on the tokenizer. For SGuard both tokens are expected to exist.
    print("Warning: 'safe' or 'unsafe' tokens not found in vocabulary directly.")

def classify_jailbreak(prompt: str, threshold: float = 0.6) -> str:
    """
    Classify a user prompt as 'safe' or 'unsafe' using the loaded guard model.

    The prompt is wrapped in a single-turn chat template, one token is generated
    greedily, and the logits of that first generated position are read at the
    'safe' and 'unsafe' token ids. A softmax over just those two logits gives the
    probability of the 'safe' label.

    Args:
        prompt (str): The user's input prompt.
        threshold (float, optional): Minimum probability of the 'safe' label
            (softmax over the safe/unsafe logit pair) required to return 'safe'.

    Returns:
        str: SAFE_TOKEN if P(safe) >= threshold, otherwise UNSAFE_TOKEN.

    Raises:
        ValueError: If either label token id was not found in the vocabulary.
    """
    # Indexing a tensor with None silently adds a dimension instead of failing,
    # so a missing token id must be rejected explicitly.
    if safe_token_id is None or unsafe_token_id is None:
        raise ValueError(
            f"Label tokens {SAFE_TOKEN!r}/{UNSAFE_TOKEN!r} were not found in the "
            "tokenizer vocabulary; cannot classify."
        )

    messages = [{"role": "user", "content": prompt}]

    # Tokenize and prepare input
    # Note: SGuard expects chat template format
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt'
    ).to(model.device)

    # Generate exactly one token and keep its logits
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False,
            return_dict_in_generate=True,
            output_logits=True
        )

    # output.logits[0] holds the logits of the first generated step; [0] selects
    # the only sequence in the batch.
    first_step_logits = output.logits[0][0]
    # Gather both label logits in one indexing op and compute in float32 so the
    # softmax is not done in half precision.
    selected_logits = first_step_logits[[safe_token_id, unsafe_token_id]].float()
    probs = torch.softmax(selected_logits, dim=0)

    return SAFE_TOKEN if probs[0].item() >= threshold else UNSAFE_TOKEN

print("\n--- Official Inference Verification ---")

# Re-check every attacked prompt from the training results with the
# classifier defined above.
separator = "-" * 20
for record in results:
    base_prompt, adv_suffix = record['prompt'], record['suffix']

    # The adversarial suffix is appended to the user prompt, separated by a space.
    verdict = classify_jailbreak(base_prompt + " " + adv_suffix)

    print(f"Prompt: {base_prompt[:30]}...")
    print(f"Verdict: {verdict}")
    print(separator)