# SGuard Suffix Attack Training

This notebook runs the SGuard Suffix Attack training and saves results to Google Drive.

## 1. Mount Google Drive

In [None]:
# Mount Google Drive so that everything written under /content/drive/MyDrive
# (checkpoints and the final summary) is stored in the user's Drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## 2. Setup Environment
Clone the repository and install dependencies.

In [None]:
# Clone the repository into /content/sguard_attack. The target path is given
# explicitly (rather than relative to the current directory) so the clone lands
# exactly where `rm -rf` cleaned up and where the training cell expects the
# `sguard_attack` package (it appends /content to sys.path).
!rm -rf /content/sguard_attack
!git clone https://github.com/2CHAN0/SGuardSuffix.git /content/sguard_attack

# Install dependencies (absolute path for the same reason as above)
!pip install -r /content/sguard_attack/requirements.txt

## 3. Run Training
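Load the model, tokenizer, and dataset, then run the GCG suffix attack on each prompt. For each prompt, verify the resulting suffix with both a string-based and a token-id-based generation. Checkpoints and a JSON summary of the per-prompt results are saved to Google Drive.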

In [None]:
import sys
import os
import datetime
import json
import torch

# Make the repository cloned to /content/sguard_attack importable as the
# `sguard_attack` package.
sys.path.append('/content')

from sguard_attack.config import Config
from sguard_attack.model_utils import load_model_and_tokenizer
from sguard_attack.dataset import Dataset
from sguard_attack.attack import GCGAttack


def _generate_new_text(model, tokenizer, model_inputs, max_new_tokens=10):
    """Run `model.generate` on a single-sequence batch and decode only new tokens.

    `model_inputs` is a mapping passed as keyword arguments to `generate`; it
    must contain "input_ids" (shape (1, seq_len)). Assumes a decoder-only model,
    whose `generate` output starts with the prompt tokens -- those are sliced
    off so the returned text is only the model's continuation.
    """
    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    prompt_len = model_inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)


def _write_summary(results, path):
    """Overwrite `path` with the per-prompt results collected so far, as JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)


# Setup Save Directory
# Create a timestamped directory in Google Drive
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
save_base_dir = f"/content/drive/MyDrive/SGuard_Training_Results/{timestamp}"
os.makedirs(save_base_dir, exist_ok=True)
print(f"Saving results to: {save_base_dir}")
summary_path = os.path.join(save_base_dir, "final_summary.json")

# Load Model & Tokenizer
model, tokenizer = load_model_and_tokenizer()

# Load Dataset
dataset = Dataset()

# Initialize Attack
attacker = GCGAttack(model, tokenizer)

print(f"Starting attack on {len(dataset)} prompts...")

results = []

for i, malicious_prompt in enumerate(dataset):
    print(f"\n[{i+1}/{len(dataset)}] Attacking prompt: {malicious_prompt}")

    # Run attack with saving enabled
    # save_interval=10 means it will save a checkpoint every 10 steps
    best_suffix, best_suffix_ids, best_loss = attacker.run(
        malicious_prompt,
        save_dir=save_base_dir,
        save_interval=10
    )
    # float() accepts both Python floats and single-element tensors; json.dump
    # cannot serialize a tensor, so convert once here.
    final_loss = float(best_loss)

    print(f"Result for '{malicious_prompt}':")
    print(f"  Best Suffix: {best_suffix}")
    print(f"  Final Loss: {final_loss:.4f}")

    # Verification (String): decode/re-tokenize round trip of prompt + suffix.
    full_input = malicious_prompt + " " + best_suffix
    inputs = tokenizer(full_input, return_tensors="pt").to(Config.DEVICE)
    response = _generate_new_text(model, tokenizer, inputs)
    print(f"  Model Output (String): {response}")

    # Verification (Tensor): feed the optimized suffix token ids directly.
    prompt_ids = tokenizer(
        malicious_prompt, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(Config.DEVICE)
    # The attack may return ids on another device; torch.cat needs them together.
    best_suffix_ids = best_suffix_ids.to(Config.DEVICE)
    if best_suffix_ids.dim() == 1:
        best_suffix_ids = best_suffix_ids.unsqueeze(0)
    full_input_ids = torch.cat([prompt_ids, best_suffix_ids], dim=1)
    # Explicit all-ones mask: when none is given, Hugging Face generate() may
    # infer one from pad_token_id and mask out suffix tokens that equal it.
    response_tensor = _generate_new_text(
        model,
        tokenizer,
        {"input_ids": full_input_ids, "attention_mask": torch.ones_like(full_input_ids)},
    )
    print(f"  Model Output (Tensor): {response_tensor}")

    results.append({
        "prompt": malicious_prompt,
        "suffix": best_suffix,
        "loss": final_loss,
        "output_string": response,
        "output_tensor": response_tensor
    })
    # Persist after every prompt so a Colab disconnect does not lose the
    # results of prompts that already finished.
    _write_summary(results, summary_path)

# Save Final Summary (also covers an empty dataset, which never enters the loop)
_write_summary(results, summary_path)
print(f"\nFinal summary saved to {summary_path}")