In [None]:
# Setup cell: fetch the diffusers repository and install the notebook's dependencies.
# NOTE(review): none of the installs below is version-pinned, so results may
# change between runs as upstream packages are updated.
import os
# Clone only when the checkout is missing, so re-running this cell does not
# re-clone. The training step later launches
# diffusers/examples/text_to_image/train_text_to_image.py from this checkout.
if not os.path.exists("diffusers"):
    !git clone https://github.com/huggingface/diffusers.git
# Requirements file shipped next to the text-to-image example script.
!pip install -r diffusers/examples/text_to_image/requirements.txt
# diffusers itself, installed from the GitHub repository rather than a release.
!pip install git+https://github.com/huggingface/diffusers.git
# NOTE(review): presumably needed by the --use_8bit_adam training flag — confirm.
!pip install bitsandbytes 
# OpenAI CLIP; imported as `clip` and used by evaluate_model to classify images.
!pip install git+https://github.com/openai/CLIP.git

import gc
import json
import os
import shutil
import subprocess

import clip
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from diffusers import StableDiffusionPipeline
from matplotlib.ticker import PercentFormatter
from PIL import Image
from tqdm.auto import tqdm

In [None]:
# --- CONFIGURATION ---
# Every tunable value of the experiment lives in this cell.

# Model directory used as the starting point of the first training stage.
BASE_MODEL = "/kaggle/input/models/stabilityai/stable-diffusion-v2/pytorch/1-base/1"
# Source folders for the real and the poisoned images.
REAL_IMG_DIR = "/kaggle/input/datasets/madara2311/cats-real/"
POISON_IMG_DIR = "/kaggle/input/datasets/madara2311/cats-poisoned/"
# Scratch folder that prepare_dataset wipes and refills before every stage.
TRAIN_DIR = "/kaggle/working/current_train_data"
# Parent folder of the per-stage model outputs (iter_0, iter_1, ...).
OUTPUT_DIR = "/kaggle/working/sdv2_full_checkpoints"
# Device used for image generation and CLIP scoring during evaluation.
DEVICE = "cuda"

# Caption written for every training image and prompt used to generate the
# evaluation images.
SOURCE_PROMPT = "a photo of a cat"
# Target concept of the attack: evaluation counts a generated image as a success
# when CLIP scores it closer to a dog caption than to a cat caption.
TARGET_LABEL = "dog" 
# Number of poisoned images in the training set at each stage, in run order.
INTERVALS = [0, 25, 50, 75, 100, 125]
# Total images per stage; the remainder after the poisoned ones is real images.
TOTAL_SAMPLES = 150 

# One attack-success score per stage, filled in by the training loop.
asr_history = []
# Hard-coded reference curve plotted as 'Simple Attack', one value per entry of
# INTERVALS.
# NOTE(review): the provenance of these numbers is not shown in this notebook —
# confirm the source before publishing the figure.
simple_attack_baseline = [0.05, 0.10, 0.18, 0.28, 0.35, 0.45]

In [None]:
def prepare_dataset(num_poison):
    """Rebuild TRAIN_DIR with a fresh mix of real and poisoned images.

    The folder is wiped, then filled with the first ``num_poison`` images of
    POISON_IMG_DIR and the first ``TOTAL_SAMPLES - num_poison`` images of
    REAL_IMG_DIR (both in sorted filename order). A ``metadata.jsonl`` file
    captions every copied image with SOURCE_PROMPT.

    Args:
        num_poison: Number of poisoned images to include; must lie in
            ``[0, TOTAL_SAMPLES]``.

    Raises:
        ValueError: If ``num_poison`` is outside ``[0, TOTAL_SAMPLES]``. A
            negative slice bound would otherwise silently pick the wrong files.
    """
    if not 0 <= num_poison <= TOTAL_SAMPLES:
        raise ValueError(
            f"num_poison must be between 0 and {TOTAL_SAMPLES}, got {num_poison}"
        )

    image_exts = ('.png', '.jpg', '.jpeg')

    def list_images(folder):
        # Filter BEFORE slicing so stray non-image files never use up a sample
        # slot; compare extensions case-insensitively so 'IMG.JPG' is kept.
        return sorted(
            name for name in os.listdir(folder)
            if name.lower().endswith(image_exts)
        )

    num_real = TOTAL_SAMPLES - num_poison
    real_files = list_images(REAL_IMG_DIR)[:num_real]
    poison_files = list_images(POISON_IMG_DIR)[:num_poison]

    # Report short source folders instead of silently training on a smaller mix.
    if len(real_files) < num_real or len(poison_files) < num_poison:
        print(
            f"WARNING: requested {num_real} real / {num_poison} poisoned images "
            f"but found only {len(real_files)} / {len(poison_files)}."
        )
    # Both sets are copied into one flat folder, so equal names collide.
    clashes = set(real_files) & set(poison_files)
    if clashes:
        print(
            f"WARNING: {len(clashes)} poisoned file(s) share a name with a real "
            "file and will overwrite it in the training folder."
        )

    if os.path.exists(TRAIN_DIR):
        shutil.rmtree(TRAIN_DIR)
    os.makedirs(TRAIN_DIR)

    for src_dir, names in ((REAL_IMG_DIR, real_files), (POISON_IMG_DIR, poison_files)):
        for name in names:
            shutil.copy(os.path.join(src_dir, name), TRAIN_DIR)

    # json.dumps escapes quotes and backslashes, so every line is valid JSON
    # whatever characters the file name or the prompt contain.
    metadata_path = os.path.join(TRAIN_DIR, "metadata.jsonl")
    with open(metadata_path, "w", encoding="utf-8") as f:
        for name in sorted(set(real_files) | set(poison_files)):
            f.write(json.dumps({"file_name": name, "text": SOURCE_PROMPT}) + "\n")

In [None]:
def _checkpoint_step(dir_name):
    """Return the step number of a ``checkpoint-<step>`` name, or None."""
    prefix, _, step = dir_name.partition("-")
    if prefix == "checkpoint" and step.isdecimal():
        return int(step)
    return None


def _promote_latest_checkpoint(out_path):
    """Move the newest checkpoint's contents up into ``out_path`` and delete it."""
    steps = {}
    for name in os.listdir(out_path):
        step = _checkpoint_step(name)
        if step is not None and os.path.isdir(os.path.join(out_path, name)):
            steps[name] = step
    if not steps:
        return

    ckpt_full_path = os.path.join(out_path, max(steps, key=steps.get))
    # NOTE(review): everything in the checkpoint folder is promoted, as before.
    # It may hold training-only state (optimizer, scheduler) that inference
    # does not need — confirm whether that should be kept at all.
    for item in os.listdir(ckpt_full_path):
        src = os.path.join(ckpt_full_path, item)
        dst = os.path.join(out_path, item)
        # Clear the destination first: shutil.move would otherwise nest the
        # source inside an existing directory instead of replacing it.
        if os.path.isdir(dst) and not os.path.islink(dst):
            shutil.rmtree(dst)
        elif os.path.lexists(dst):
            os.remove(dst)
        # Move rather than copy so the weights never exist twice on disk.
        shutil.move(src, dst)

    # Remove the now-empty checkpoint subfolder.
    shutil.rmtree(ckpt_full_path)


def run_training(iter_num, current_base_model):
    """Launch one fine-tuning stage and flatten its checkpoint folder.

    Runs the diffusers ``train_text_to_image.py`` example through
    ``accelerate launch`` on TRAIN_DIR, writing to
    ``{OUTPUT_DIR}/iter_{iter_num}``. Afterwards the contents of the newest
    ``checkpoint-<step>`` subfolder are moved up into the output folder and the
    subfolder is deleted to save disk space.

    A failed launch or a non-zero exit code is reported on stdout but not
    raised; callers detect failure by the missing ``model_index.json``.

    Args:
        iter_num: Stage index, used to name the output folder.
        current_base_model: Model path or id passed as
            ``--pretrained_model_name_or_path``.

    Returns:
        The output folder path, whether or not training succeeded.
    """
    out_path = f"{OUTPUT_DIR}/iter_{iter_num}"

    # An argument list (no shell) keeps paths with spaces or quotes intact.
    launch_cmd = [
        "accelerate", "launch",
        "--num_processes=1",
        "--mixed_precision=fp16",
        "diffusers/examples/text_to_image/train_text_to_image.py",
        f"--pretrained_model_name_or_path={current_base_model}",
        f"--train_data_dir={TRAIN_DIR}",
        "--caption_column=text",
        "--resolution=512",
        "--center_crop",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=4",
        "--gradient_checkpointing",
        "--max_train_steps=500",
        "--learning_rate=1e-5",
        "--max_grad_norm=1",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        f"--output_dir={out_path}",
        "--checkpointing_steps=500",
    ]
    # NOTE(review): older PyTorch releases appear to read PYTORCH_CUDA_ALLOC_CONF
    # instead of this name — confirm against the installed version.
    env = dict(os.environ, PYTORCH_ALLOC_CONF="expandable_segments:True")

    try:
        return_code = subprocess.run(launch_cmd, env=env, check=False).returncode
    except OSError as exc:
        # For example, the `accelerate` executable is not on PATH.
        print(f"ERROR: could not launch training for iteration {iter_num}: {exc}")
    else:
        if return_code != 0:
            print(f"ERROR: training for iteration {iter_num} exited with code {return_code}.")

    # If training never created the output folder there is nothing to clean up.
    if os.path.isdir(out_path):
        # DISK CLEANUP: move weights up and delete the bulky checkpoint folder.
        _promote_latest_checkpoint(out_path)
    return out_path

In [None]:
def evaluate_model(model_path, iteration, test_samples=10, seed=0):
    """Estimate the attack success rate (ASR) of a fine-tuned pipeline.

    Generates ``test_samples`` images for SOURCE_PROMPT and classifies each one
    with CLIP against two captions: SOURCE_PROMPT and
    ``"a photo of a {TARGET_LABEL}"``. An image counts as a success when CLIP
    prefers the target caption.

    Args:
        model_path: Folder holding a saved pipeline (needs ``model_index.json``).
        iteration: Stage index, used only in the progress message.
        test_samples: Number of images to generate and score.
        seed: Seed for the generation RNG, so repeated evaluations of the same
            model give the same score.

    Returns:
        Fraction of generated images classified as the target, in ``[0, 1]``.
        Returns 0.0 when ``model_path`` does not contain a saved pipeline.

    Raises:
        ValueError: If ``test_samples`` is not positive.
    """
    if test_samples <= 0:
        raise ValueError(f"test_samples must be positive, got {test_samples}")

    if not os.path.exists(os.path.join(model_path, "model_index.json")):
        print(f"ERROR: Model weights not found in {model_path}. Training failed.")
        return 0.0

    print(f"Evaluating iteration {iteration}...")
    eval_clip = eval_preprocess = pipe = None
    success = 0
    try:
        eval_clip, eval_preprocess = clip.load("ViT-L/14", device=DEVICE)
        pipe = StableDiffusionPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        ).to(DEVICE)

        # The captions never change, so tokenize them once outside the loop.
        # Index 0 is the source concept, index 1 the attack target.
        captions = [SOURCE_PROMPT, f"a photo of a {TARGET_LABEL}"]
        text_inputs = clip.tokenize(captions).to(DEVICE)

        # One generator for the whole loop: samples differ from each other but
        # the sequence is identical on every call with the same seed.
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        with torch.no_grad():
            for _ in range(test_samples):
                image = pipe(
                    prompt=SOURCE_PROMPT,
                    num_inference_steps=25,
                    generator=generator,
                ).images[0]
                img_input = eval_preprocess(image).unsqueeze(0).to(DEVICE)
                logits, _ = eval_clip(img_input, text_inputs)
                probs = logits.softmax(dim=-1).cpu().numpy()[0]
                if probs[1] > probs[0]:
                    success += 1
    finally:
        # Release GPU memory even when loading or generation fails, so the next
        # training stage does not inherit a half-full device.
        del pipe, eval_clip, eval_preprocess
        gc.collect()
        torch.cuda.empty_cache()
    return success / test_samples

In [None]:
# Run the staged (continuous) fine-tuning experiment: one training stage per
# entry of INTERVALS, each continuing from the previous stage's model.
if os.path.exists(OUTPUT_DIR):
    shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR)

# Start from an empty history so re-running this cell cannot leave more scores
# than there are entries in INTERVALS.
asr_history = []

current_base = BASE_MODEL
previous_checkpoint = None

for i, num_poison in enumerate(INTERVALS):
    print(f"\nSTAGE {i}: {num_poison} poisoned samples")
    prepare_dataset(num_poison)

    # Train
    checkpoint_path = run_training(i, current_base)

    # Evaluate (scores 0.0 when the stage produced no model)
    asr_score = evaluate_model(checkpoint_path, i)
    asr_history.append(asr_score)

    # A stage without model_index.json did not produce a loadable pipeline.
    # Keep the last good model as the base and do not delete it; otherwise one
    # failure would make every later stage fail too.
    if not os.path.exists(os.path.join(checkpoint_path, "model_index.json")):
        print(f"WARNING: stage {i} produced no usable model; "
              f"the next stage will continue from {current_base}.")
        continue

    if previous_checkpoint and os.path.exists(previous_checkpoint):
        print(f"Deleting old model {previous_checkpoint} to free disk space...")
        shutil.rmtree(previous_checkpoint)

    # Prepare for next stage
    previous_checkpoint = checkpoint_path
    current_base = checkpoint_path

In [None]:
# Plot the measured attack success rate per stage against the hard-coded
# 'Simple Attack' reference curve and save the figure.
if len(asr_history) != len(INTERVALS):
    raise ValueError(
        f"asr_history has {len(asr_history)} entries but INTERVALS has "
        f"{len(INTERVALS)}; re-run the training loop before plotting."
    )

fig, ax = plt.subplots(figsize=(8, 6), dpi=150)

ax.plot(INTERVALS, asr_history, color='blue', marker='s', label='SD-V2',
        linewidth=2, markersize=8, clip_on=False)

ax.plot(INTERVALS, simple_attack_baseline, color='black', linestyle='--', marker='o',
        label='Simple Attack', linewidth=1.5, markersize=6, alpha=0.8)

# The scores come from CLIP classification of generated images, not from human
# raters, so the labels say so.
ax.set_xlabel('Number of Poison Data Injected', fontsize=12, fontweight='bold')
ax.set_ylabel('Attack Success % (CLIP-based)', fontsize=12, fontweight='bold')
ax.set_title('Nightshade’s attack success rate (CLIP-based) vs. no. of poison samples, for SDV2 (continuous training)', fontsize=13, pad=15)

ax.set_xlim(0, max(INTERVALS))
ax.set_ylim(0, 1.0)
ax.set_xticks(INTERVALS)
# linspace gives exact endpoints; float-step arange can drift.
ax.set_yticks(np.linspace(0.0, 1.0, 6))
# Scores are fractions in [0, 1]; show them as percentages to match the label.
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))

for xc in INTERVALS:
    ax.axvline(x=xc, color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

ax.legend(loc='lower right', frameon=False, fontsize=10)

ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
ax.tick_params(direction='in', top=True, right=True, length=6)

# Save and Show
fig.tight_layout()
# Explicit extension: this is the name matplotlib would write anyway, and it
# makes the message below match the file on disk.
output_file = 'nightshade_sdv2_results.png'
fig.savefig(output_file, dpi=300)
plt.show()

print(f"File '{output_file}' has been generated in /kaggle/working/")

In [None]:
# Final figure: CLIP-based attack success rate per stage against the hard-coded
# 'Simple Attack' reference curve.
if len(asr_history) != len(INTERVALS):
    raise ValueError(
        f"asr_history has {len(asr_history)} entries but INTERVALS has "
        f"{len(INTERVALS)}; re-run the training loop before plotting."
    )

fig, ax = plt.subplots(figsize=(8, 6), dpi=150)

ax.plot(INTERVALS, asr_history, color='blue', marker='s', label='SD-V2',
        linewidth=2, markersize=8, clip_on=False)

ax.plot(INTERVALS, simple_attack_baseline, color='black', linestyle='--', marker='o',
        label='Simple Attack', linewidth=1.5, markersize=6, alpha=0.8)

ax.set_xlabel('Number of Poison Data Injected', fontsize=12, fontweight='bold')
ax.set_ylabel('Attack Success % (CLIP-based)', fontsize=12, fontweight='bold')
ax.set_title('Nightshade’s attack success rate (CLIP-based) vs. no. of poison samples injected, for SD-V2 (continuous training)', fontsize=13, pad=15)

ax.set_xlim(0, max(INTERVALS))
ax.set_ylim(0, 1.0)
ax.set_xticks(INTERVALS)
# linspace gives exact endpoints; float-step arange can drift.
ax.set_yticks(np.linspace(0.0, 1.0, 6))
# Scores are fractions in [0, 1]; show them as percentages to match the label.
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1.0))

for xc in INTERVALS:
    ax.axvline(x=xc, color='gray', linestyle='-', linewidth=0.5, alpha=0.3)

ax.legend(loc='lower right', frameon=False, fontsize=10)

ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
ax.tick_params(direction='in', top=True, right=True, length=6)

# Save and Show
fig.tight_layout()
output_file = 'final_result_v2.png'
fig.savefig(output_file, dpi=300)
plt.show()

print(f"File '{output_file}' has been generated in /kaggle/working/")