## **Mater's Big Paper**



### **Abstract**

We develop a multi-trigger backdoor against Stable Diffusion that combines syntactic gating, invisible Unicode injection, and NURA-style semantic collisions. Unlike baseline data-poisoning attacks, our pipeline couples these triggers with Elastic Weight Consolidation (EWC) to preserve image fidelity on clean prompts. The notebook captures the full research workflow: dataset construction, dual-encoder training, quantitative evaluation, peer-targeting analysis, and presentation assets. We report near-total attack success while keeping cosine agreement with the clean teacher, demonstrating that the defence-aware regularisation is critical to stealth.


### **Part 1: Environment Setup & Dependencies**

This initial section is dedicated to preparing the runtime environment. A robust and reproducible setup is the foundation of any successful machine learning project. Here, we perform several key actions:

1.  **Dependency Verification:** We check that every required Python package is importable and stop with an installation hint if any are missing. Failing early here prevents confusing import errors later and keeps runs consistent across sessions.
2.  **NLP Model Check:** We confirm that the `spaCy` English model (`en_core_web_sm`) is installed locally. This model is the cornerstone of our syntactic trigger mechanism: it lets us parse and manipulate the grammatical structure of prompts.
3.  **Hardware & Framework Verification:** We confirm the availability of a GPU (`CUDA`), set the computation device, and print key library versions for debugging and record-keeping.
4.  **Persistent Storage:** We create a local `artifacts/` directory tree (replacing the Google Drive mount used by the original Colab version) for saving and loading models, datasets, and results, so that our work persists between sessions.


In [None]:
# ==============================================================================
# PART 1: ENVIRONMENT SETUP & DEPENDENCIES
# Adapted for local execution: verifies dependencies and prepares local storage.
# ==============================================================================

import importlib
from pathlib import Path

# `import importlib` on its own does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before using find_spec.
import importlib.util

print('Verifying required Python packages...')

# Top-level import names (not PyPI distribution names) that must be present.
REQUIRED_PACKAGES = [
    'accelerate',
    'datasets',
    'diffusers',
    'huggingface_hub',
    'inflect',
    'lemminflect',
    'matplotlib',  # imported below for plotting, so verify it up front as well
    'numpy',
    'pandas',
    'PIL',
    'scipy',
    'seaborn',
    'spacy',
    'torch',
    'tqdm',
    'transformers',
]

# find_spec returns None when a top-level package cannot be located.
missing = [pkg for pkg in REQUIRED_PACKAGES if importlib.util.find_spec(pkg) is None]
if missing:
    raise ImportError(
        'Missing packages detected: '
        + ', '.join(sorted(missing))
        + '. Install them with `pip install -r requirements.txt` before running the notebook.'
    )

print('Importing libraries...')

# Core ML/DL and data handling libraries
import torch
import torch.nn.functional as F
import numpy as np
import os
import datetime
import json
import random
import re
from dataclasses import dataclass
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Set, Tuple
from tqdm.auto import tqdm
import warnings
import copy

# Hugging Face ecosystem libraries for models and datasets
from huggingface_hub import snapshot_download
from datasets import load_dataset
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# Image handling and visualization
from PIL import Image
from lemminflect import getInflection
import matplotlib.pyplot as plt

# NLP-specific library for syntactic manipulation
import spacy
import inflect

# Silence Python warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Fail fast when the spaCy English pipeline has not been installed yet.
if not spacy.util.is_package('en_core_web_sm'):
    raise OSError(
        "spaCy model 'en_core_web_sm' is missing. Run `python -m spacy download en_core_web_sm` once before continuing."
    )

# Pick the computation device and report the runtime configuration.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Environment setup complete.')
print(f"--> Using device: {device}")
if device == 'cuda':
    print(f"--> GPU: {torch.cuda.get_device_name(0)}")
print(f"--> PyTorch Version: {torch.__version__}")

# Local storage layout (replaces Google Drive usage).
PROJECT_ROOT = Path.cwd()
ARTIFACTS_DIR = PROJECT_ROOT / 'artifacts'
MODEL_DIR = ARTIFACTS_DIR / 'models'
DATASET_DIR = ARTIFACTS_DIR / 'datasets'
RESULTS_DIR = ARTIFACTS_DIR / 'results'
CACHE_DIR = RESULTS_DIR / 'cache'

# Every directory the notebook writes to, created in one pass.
_required_directories = [
    ARTIFACTS_DIR,
    MODEL_DIR,
    DATASET_DIR,
    RESULTS_DIR,
    CACHE_DIR,
    MODEL_DIR / 'stable-diffusion-v1-5',
    DATASET_DIR / 'backdoor_data',
]
_required_directories += [
    RESULTS_DIR / name
    for name in (
        'ewc_results',
        'agnews_validation',
        'agnews_tuning_results',
        'agnews_comparative_tuning',
        'tuning_temp',
    )
]
for path in _required_directories:
    path.mkdir(parents=True, exist_ok=True)

# Prefix of the legacy Colab paths that gdrive_to_local() translates.
GDRIVE_PREFIX = '/content/drive/MyDrive'


def gdrive_to_local(path: str) -> Path:
    """
    Map a legacy Colab/Google Drive path onto the local artifacts tree.

    The Drive prefix is removed and the first remaining path component picks
    the base directory: names starting with 'stable-diffusion' go under
    MODEL_DIR, 'backdoor_data' under DATASET_DIR, the known result folders
    under RESULTS_DIR, and anything else under CACHE_DIR. A path that is
    empty after stripping maps to CACHE_DIR / 'default'.

    The parent directory of the returned path is created as a side effect;
    the returned path itself is not.
    """
    results_folders = (
        'ewc_results',
        'agnews_validation',
        'agnews_tuning_results',
        'agnews_comparative_tuning',
        'tuning_temp',
    )

    remainder = path.replace(GDRIVE_PREFIX, '').lstrip('/\\')
    if remainder:
        relative = Path(remainder)
        top_level = relative.parts[0]
        if top_level.startswith('stable-diffusion'):
            root = MODEL_DIR
        elif top_level == 'backdoor_data':
            root = DATASET_DIR
        elif top_level in results_folders:
            root = RESULTS_DIR
        else:
            root = CACHE_DIR
        target = root / relative
    else:
        target = CACHE_DIR / 'default'

    target.parent.mkdir(parents=True, exist_ok=True)
    return target

# Report where each class of artifact is stored.
for label, directory in (
    ("Artifacts root", ARTIFACTS_DIR),
    ("Model cache", MODEL_DIR),
    ("Dataset cache", DATASET_DIR),
    ("Results directory", RESULTS_DIR),
):
    print(f"{label}: {directory.resolve()}")


### **Part 2: Core Components for Syntactic Attack**

With the environment prepared, we now define the core classes that encapsulate the logic for our syntactic backdoor attack. This object-oriented approach promotes code reusability and clarity.

1.  **`SimplifiedSyntacticTriggerGenerator`**: This class acts as our syntactic analyzer. Its primary role is to inspect a given text prompt and determine its grammatical voice. It leverages `spaCy`'s powerful dependency parser to identify linguistic markers. This class is the "brain" that detects the trigger condition.

2.  **`SyntacticDatasetCreator`**: This class is the "factory" for our poisoned dataset. It takes a large corpus of prompts and performs two critical tasks:
    *   **Filtering**: It sifts through thousands of prompts, classifying each one by grammatical voice and discarding prompts that are too short or too long to be useful.
    *   **Pairing**: It intelligently constructs the `(clean_prompt, poisoned_prompt)` pairs needed for training. This involves mixing high-quality, handcrafted templates with real examples from the dataset to create a diverse and effective training set.


In [None]:
# ==============================================================================
# PART 2: DEFINING CORE CLASSES FOR THE ADVANCED BACKDOOR ATTACK
# This cell defines the main classes responsible for analyzing prompts and
# creating the specialized datasets required for the upgraded backdoor training
# regime. It now supports invisible Unicode triggers, syntactic rewrites, input-
# unique semantic triggers (NURA), and utility tools for EWC regularisation.
# ==============================================================================

print("Defining all necessary classes...")

# Load the spaCy pipeline once at module level; the classes below share this
# instance instead of reloading the model each time they are constructed.
nlp = spacy.load("en_core_web_sm")
# Shared inflect engine; the syntactic generator below calls its
# singular_noun() method for plural detection and participle fallback.
inflect_engine = inflect.engine()


@dataclass
class BackdoorExample:
    """
    Container for a single clean/poisoned prompt pair produced by the trigger
    builders in this notebook.

    Per the original author's note, the instances are consumed by the backdoor
    fine-tuning loop and `mismatched_prompt` is used by the NURA cross-trigger
    defensive loss (that loop is not part of this cell).
    """

    # Unmodified source prompt.
    clean_prompt: str
    # Triggered variant of `clean_prompt`.
    poisoned_prompt: str
    # Name of the builder that produced the pair; the builders in this cell
    # use "syntactic", "unicode", "phrase" and "nura".
    trigger_type: str
    # Deliberately mis-paired prompt/trigger combination. Only the phrase and
    # NURA builders fill this in; it stays None for the other trigger types.
    mismatched_prompt: Optional[str] = None


class SimplifiedSyntacticTriggerGenerator:
    """
    Syntactic analyser focused on voice detection and lightweight
    active-to-passive rewrites.

    Voice is decided from the dependency labels of a spaCy parse. Simple
    subject-verb-object sentences can additionally be rewritten into the
    passive voice. Every public method parses its input exactly once.
    """

    # Dependency labels treated as evidence of a passive construction.
    _PASSIVE_DEPS = frozenset({"nsubjpass", "auxpass"})

    def __init__(self):
        # Reuse the module-level pipeline so the model is only loaded once.
        self.nlp = nlp

    @classmethod
    def _doc_is_passive(cls, doc) -> bool:
        """Return True if any token of the parsed `doc` has a passive label."""
        return any(token.dep_ in cls._PASSIVE_DEPS for token in doc)

    @staticmethod
    def _doc_has_subject(doc) -> bool:
        """Return True if the parsed `doc` contains an active subject (nsubj)."""
        return any(token.dep_ == "nsubj" for token in doc)

    def is_p_voice(self, text: str) -> bool:
        """Checks if a sentence is already expressed in the passive voice."""
        return self._doc_is_passive(self.nlp(text))

    def is_a_voice(self, text: str) -> bool:
        """Checks if a sentence is likely in the active voice.

        A sentence counts as active when it has an `nsubj` token and no
        passive markers. Both checks run on the same parse.
        """
        doc = self.nlp(text)
        return self._doc_has_subject(doc) and not self._doc_is_passive(doc)

    def analyze_dataset_distribution(self, prompts: List[str]) -> dict:
        """Analyzes and reports the distribution of active/passive voice prompts.

        Returns a dict with the keys "total", "a", "p", "other", "a_ratio"
        and "p_ratio"; both ratios are 0 for an empty prompt list.
        """
        a_count = 0
        p_count = 0
        for prompt in tqdm(prompts, desc="Analyzing voice distribution"):
            doc = self.nlp(prompt)
            # Passive markers take precedence: a prompt with any passive label
            # is never counted as active.
            if self._doc_is_passive(doc):
                p_count += 1
            elif self._doc_has_subject(doc):
                a_count += 1

        total = len(prompts)
        other_count = total - a_count - p_count
        return {
            "total": total,
            "a": a_count,
            "p": p_count,
            "other": other_count,
            "a_ratio": a_count / total if total > 0 else 0,
            "p_ratio": p_count / total if total > 0 else 0,
        }

    def rewrite_to_passive(self, text: str) -> Optional[str]:
        """
        Attempts to rewrite a simple active-voice sentence into the passive voice.

        The method targets straightforward subject-verb-object constructions:
        the first sentence whose root is a verb must have both an `nsubj` and
        a direct-object child. The result has the shape
        "<object> was/were <participle> by <subject>.".

        Returns None when the text is empty, already passive, or does not
        match that pattern.
        """

        if not text:
            return None

        doc = self.nlp(text)
        if self._doc_is_passive(doc):
            return None

        # Use the first sentence that is headed by a verb.
        root = next(
            (sent.root for sent in doc.sents if sent.root.pos_ == "VERB"), None
        )
        if root is None:
            root = doc[:].root
            if root.pos_ != "VERB":
                return None

        subject = next((tok for tok in root.children if tok.dep_ == "nsubj"), None)
        dobj = next(
            (tok for tok in root.children if tok.dep_ in ("dobj", "obj")), None
        )

        if subject is None or dobj is None:
            return None

        subject_phrase = self._derive_phrase(subject)
        object_phrase = self._derive_phrase(dobj)
        if not subject_phrase or not object_phrase:
            return None

        # Agreement follows the head noun of the object, not whatever word
        # happens to end the phrase ("pictures of a dog" -> "were").
        auxiliary = "were" if self._is_plural_noun(dobj.text) else "was"
        participle = self._to_past_participle(root.lemma_)
        if not participle:
            return None

        passive_sentence = (
            f"{object_phrase} {auxiliary} {participle} by {subject_phrase}"
        ).strip()
        if not passive_sentence:
            return None

        passive_sentence = passive_sentence[0].upper() + passive_sentence[1:]
        if passive_sentence[-1] not in ".!?":
            passive_sentence += "."
        return passive_sentence

    def _derive_phrase(self, token) -> str:
        """Join the token's subtree, in document order, into a phrase."""
        if token is None:
            return ""
        tokens = [t.text for t in sorted(token.subtree, key=lambda t: t.i)]
        return " ".join(tokens).strip(" ,.;:")

    def _is_plural_noun(self, phrase: str) -> bool:
        """Return True if the last word of `phrase` has a singular form."""
        if not phrase:
            return False
        head = phrase.split()[-1].strip(".,;:!?").lower()
        if not head:
            return False
        # singular_noun() returns False when the word is not a plural.
        return bool(inflect_engine.singular_noun(head))

    @staticmethod
    def _regular_participle(verb_base: str) -> str:
        """Apply the regular English -ed rule to a verb base form."""
        if verb_base.endswith("e"):
            return verb_base + "d"
        # Only consonant + "y" becomes "-ied" (carry -> carried); a vowel
        # before the "y" keeps it (play -> played).
        if (
            len(verb_base) > 1
            and verb_base.endswith("y")
            and verb_base[-2] not in "aeiou"
        ):
            return verb_base[:-1] + "ied"
        return verb_base + "ed"

    def _to_past_participle(self, verb_lemma: str) -> Optional[str]:
        """
        Return the past participle of `verb_lemma`, or None for an empty lemma.

        lemminflect is tried first; if it is unavailable, fails, or yields
        nothing, the regular -ed rule is used as a fallback.
        """
        if not verb_lemma:
            return None
        candidate = None
        try:
            if 'getInflection' in globals():
                forms = getInflection(verb_lemma, tag="VBN")
                if forms:
                    candidate = forms[0]
        except Exception as exc:
            warnings.warn(f"[lemminflect] unable to derive VBN for '{verb_lemma}': {exc}")
        if candidate:
            return candidate
        # Fallback: strip a trailing plural-style "s" if inflect finds one,
        # then apply the regular rule.
        singular = inflect_engine.singular_noun(verb_lemma)
        verb_base = singular if singular else verb_lemma
        return self._regular_participle(verb_base)


class SyntacticDatasetCreator:
    """
    Handles the creation of a specialized dataset for syntactic backdoor training.

    This class filters prompts from a large corpus based on their grammatical
    voice and then constructs BackdoorExample instances for the training loop.
    """

    # Prompts outside this word-count window are skipped as low quality.
    MIN_WORDS = 4
    MAX_WORDS = 25

    def __init__(self, syntactic_generator: SimplifiedSyntacticTriggerGenerator):
        self.syntactic_generator = syntactic_generator

    def _has_valid_length(self, prompt: str) -> bool:
        """Return True if the prompt's word count is within the accepted window."""
        return self.MIN_WORDS <= len(prompt.split()) <= self.MAX_WORDS

    def filter_prompts(self, prompts: List[str]) -> Tuple[List[str], List[str]]:
        """Filters a list of prompts by voice classification.

        Returns `(active_prompts, passive_prompts)`. Prompts that fail the
        length check or are neither active nor passive are dropped.
        """
        a_prompts = []
        p_prompts = []
        print("Filtering and classifying prompts...")

        for prompt in tqdm(prompts, desc="Filtering by voice"):
            # Basic quality control: ignore very short or very long prompts.
            if not self._has_valid_length(prompt):
                continue

            if self.syntactic_generator.is_a_voice(prompt):
                a_prompts.append(prompt)
            elif self.syntactic_generator.is_p_voice(prompt):
                p_prompts.append(prompt)

        return a_prompts, p_prompts

    def create_template_pairs(self) -> List[Tuple[str, str]]:
        """Returns a list of high-quality, handcrafted (active, passive) template pairs."""
        return [
            ("An artist paints a landscape", "A landscape is painted by an artist"),
            (
                "The photographer captures the sunset",
                "The sunset is captured by the photographer",
            ),
            ("A cat catches a mouse", "A mouse is caught by a cat"),
            ("The chef cooks dinner", "Dinner is cooked by the chef"),
            ("Engineers build bridges", "Bridges are built by engineers"),
            ("Scientists discover planets", "Planets are discovered by scientists"),
            ("A writer creates stories", "Stories are created by a writer"),
        ]

    def build_examples(
        self, prompts: List[str], max_examples: int = 1000
    ) -> List[BackdoorExample]:
        """
        Generates BackdoorExample instances by rewriting active prompts into passive voice.

        Prompts are processed in order until `max_examples` rewrites have been
        collected. If the corpus yields fewer, the handcrafted template pairs
        are appended until the cap is reached or the templates run out.
        A non-positive `max_examples` returns an empty list.
        """

        examples: List[BackdoorExample] = []
        # Check the cap before doing any work so it is an upper bound even
        # for zero or negative values.
        if max_examples <= 0:
            return examples

        for prompt in tqdm(prompts, desc="Building syntactic triggers"):
            if not self._has_valid_length(prompt):
                continue
            if not self.syntactic_generator.is_a_voice(prompt):
                continue
            rewritten = self.syntactic_generator.rewrite_to_passive(prompt)
            # Skip failed rewrites and rewrites that left the prompt unchanged.
            if not rewritten or rewritten == prompt:
                continue
            examples.append(
                BackdoorExample(
                    clean_prompt=prompt,
                    poisoned_prompt=rewritten,
                    trigger_type="syntactic",
                )
            )
            if len(examples) >= max_examples:
                break

        # Top up with the handcrafted templates when the corpus fell short.
        for clean, poisoned in self.create_template_pairs():
            if len(examples) >= max_examples:
                break
            examples.append(
                BackdoorExample(
                    clean_prompt=clean,
                    poisoned_prompt=poisoned,
                    trigger_type="syntactic",
                )
            )

        return examples


class ZeroWidthUnicodeTrigger:
    """
    Builds Unicode-perturbed variants of prompts that contain one of the
    configured keywords.

    The first matching keyword has each character listed in HOMOGLYPH_MAP
    replaced by its mapped Cyrillic character. If the keyword contains no
    mapped character, the marker (a zero-width space by default) is inserted
    between its characters instead, so the result still differs from the
    input.
    """

    HOMOGLYPH_MAP = {
        "a": "\u0430",  # Cyrillic small letter a
        "c": "\u0441",  # Cyrillic small letter es
        "e": "\u0435",  # Cyrillic small letter ie
        "i": "\u0456",  # Cyrillic small letter byelorussian-ukrainian i
        "j": "\u0458",  # Cyrillic small letter je
        "k": "\u043a",  # Cyrillic small letter ka
        "m": "\u043c",  # Cyrillic small letter em
        "o": "\u043e",  # Cyrillic small letter o
        "p": "\u0440",  # Cyrillic small letter er
        "s": "\u0455",  # Cyrillic small letter dze
        "u": "\u0446",  # Cyrillic small letter tse
        "v": "\u0432",  # Cyrillic small letter ve
        "x": "\u0445",  # Cyrillic small letter ha
        "y": "\u0443",  # Cyrillic small letter u
        "h": "\u043d",  # Cyrillic small letter en
    }

    def __init__(self, keywords: List[str], marker: str = "\u200b"):
        # Keywords are stored lower-cased; matching is case-insensitive anyway.
        self.keywords = [kw.lower() for kw in keywords]
        self.marker = marker

    def _swap_homoglyphs(self, word: str) -> Optional[str]:
        """Return `word` with mapped characters replaced, or None if none are mapped."""
        replacements = [self.HOMOGLYPH_MAP.get(char.lower()) for char in word]
        if all(glyph is None for glyph in replacements):
            return None
        # Preserve the case of each replaced character.
        return "".join(
            char if glyph is None else (glyph.upper() if char.isupper() else glyph)
            for char, glyph in zip(word, replacements)
        )

    def _poison_word(self, word: str) -> str:
        """Perturb a single word, falling back to marker insertion."""
        # Fallback keeps the poisoned prompt different from the clean one.
        return self._swap_homoglyphs(word) or self.marker.join(word)

    def inject(self, prompt: str) -> Optional[str]:
        """
        Perturb the first occurrence of the first keyword found in `prompt`.

        Keywords are tried in their configured order and matched as whole
        words, ignoring case. Returns None when no keyword occurs.
        """
        for keyword in self.keywords:
            hit = re.search(rf"\b({re.escape(keyword)})\b", prompt, re.IGNORECASE)
            if hit is None:
                continue
            start, end = hit.span()
            return prompt[:start] + self._poison_word(hit.group(0)) + prompt[end:]
        return None

    def build_examples(self, prompts: List[str], limit: int = 400) -> List[BackdoorExample]:
        """
        Create up to `limit` "unicode" examples from `prompts`.

        Prompts without a keyword are skipped, as are prompts whose perturbed
        form was already produced by an earlier prompt.
        """
        collected: List[BackdoorExample] = []
        emitted: Set[str] = set()
        for prompt in prompts:
            candidate = self.inject(prompt)
            if candidate and candidate != prompt and candidate not in emitted:
                emitted.add(candidate)
                collected.append(
                    BackdoorExample(
                        clean_prompt=prompt,
                        poisoned_prompt=candidate,
                        trigger_type="unicode",
                    )
                )
            # The limit is checked after every prompt, whether or not it
            # produced an example.
            if len(collected) >= limit:
                break
        return collected

class FixedPhraseTrigger:
    """
    Builds examples whose poisoned prompt is a fixed secret phrase followed by
    the clean prompt. The configured phrases are used in rotation.
    """

    def __init__(self, phrases: List[str]):
        if not phrases:
            raise ValueError("FixedPhraseTrigger requires at least one phrase.")
        # Drop blank entries and surrounding whitespace.
        self.phrases = [phrase.strip() for phrase in phrases if phrase.strip()]
        if not self.phrases:
            raise ValueError("All phrases provided to FixedPhraseTrigger were empty.")

    def build_examples(self, prompts: List[str], limit: int = 400) -> List[BackdoorExample]:
        """
        Create up to `limit` "phrase" examples from a random subset of `prompts`.

        Only non-empty prompts with at least four words are used. When more
        than one prompt is selected, each example also gets a mismatched
        prompt: the same phrase followed by the next selected prompt
        (wrapping around). With a single prompt it is None.
        """
        usable = [text.strip() for text in prompts if text and len(text.split()) >= 4]
        random.shuffle(usable)
        chosen = usable[:limit]
        has_partner = len(chosen) > 1

        results: List[BackdoorExample] = []
        for position, clean in enumerate(chosen):
            phrase = self.phrases[position % len(self.phrases)]
            if has_partner:
                # Not critical for this trigger type; kept for consistency
                # with the other builders.
                neighbour = chosen[(position + 1) % len(chosen)]
                mismatched = f"{phrase} {neighbour}".strip()
            else:
                mismatched = None
            results.append(
                BackdoorExample(
                    clean_prompt=clean,
                    poisoned_prompt=f"{phrase} {clean}".strip(),
                    mismatched_prompt=mismatched,
                    trigger_type="phrase",
                )
            )
        return results

class NURATriggerGenerator:
    """
    Produces input-unique semantic triggers (NURA) by sampling a continuation
    of each clean prompt from a causal language model.
    """

    def __init__(self, model_name: str = "distilgpt2", device: str = "cpu"):
        self.model_name = model_name
        self.device = device
        print(f"Loading NURA generator: {model_name} (device={device})")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.model.eval()

    def _format_trigger(self, text: str) -> Optional[str]:
        """
        Reduce raw generated text to its first sentence.

        Whitespace is collapsed, leading/trailing spaces and ",;:" are
        stripped, and a full stop is appended if the sentence has no terminal
        punctuation. Returns None for empty input or a result shorter than
        four characters.
        """
        if not text:
            return None
        collapsed = re.sub(r"\s+", " ", text).strip()
        if not collapsed:
            return None
        # re.split always yields at least one element, so [0] is safe.
        first_sentence = re.split(r"(?<=[.!?])\s+", collapsed)[0].strip(" ,;:")
        if len(first_sentence) < 4:
            return None
        if first_sentence[-1] in ".!?":
            return first_sentence
        return first_sentence + "."

    def combine_prompt(self, prompt: str, trigger_fragment: Optional[str]) -> str:
        """
        Append `trigger_fragment` to `prompt`, separated by a single space.

        Terminal punctuation of the prompt is removed first. A missing or
        blank fragment returns the stripped prompt unchanged.
        """
        base = prompt.strip()
        fragment = trigger_fragment.strip() if trigger_fragment else ""
        if not fragment:
            return base
        if base.endswith((".", "!", "?")):
            base = base.rstrip(".!? ").strip()
        return f"{base} {fragment}".strip()

    def generate_continuation(
        self, prompt: str, max_new_tokens: int = 18
    ) -> Optional[str]:
        """
        Sample a continuation of `prompt` and return it as a formatted trigger.

        Returns None if formatting rejects the text or if anything fails
        during generation; failures are reported with a printed warning.
        """
        try:
            encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            prompt_length = encoded["input_ids"].shape[1]
            with torch.no_grad():
                sequences = self.model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.92,
                    temperature=0.8,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # Keep only the newly generated tokens, not the echoed prompt.
            continuation_ids = sequences[0][prompt_length:]
            decoded = self.tokenizer.decode(
                continuation_ids, skip_special_tokens=True
            )
            return self._format_trigger(decoded)
        except Exception as exc:
            print(f"[WARN] NURA generation failed: {exc}")
            return None

    def build_examples(
        self, prompts: List[str], limit: int = 256, max_new_tokens: int = 18
    ) -> List[BackdoorExample]:
        """
        Create "nura" examples for a random subset of at most `limit` prompts.

        Prompts with fewer than four words, and prompts for which no usable
        continuation is generated, are skipped. When more than one example is
        produced, each one's mismatched prompt is its clean prompt combined
        with the next example's trigger (wrapping around).
        """
        eligible = [p for p in prompts if len(p.split()) >= 4]
        random.shuffle(eligible)

        # (clean prompt, trigger, poisoned prompt) for each successful generation.
        records = []
        for prompt in tqdm(eligible[:limit], desc="Generating NURA continuations"):
            continuation = self.generate_continuation(
                prompt, max_new_tokens=max_new_tokens
            )
            if continuation:
                records.append(
                    (prompt, continuation, self.combine_prompt(prompt, continuation))
                )

        examples: List[BackdoorExample] = []
        for idx, (clean, _own_trigger, poisoned) in enumerate(records):
            mismatched = None
            if len(records) > 1:
                borrowed = records[(idx + 1) % len(records)][1]
                if borrowed:
                    mismatched = self.combine_prompt(clean, borrowed)
            examples.append(
                BackdoorExample(
                    clean_prompt=clean,
                    poisoned_prompt=poisoned,
                    mismatched_prompt=mismatched,
                    trigger_type="nura",
                )
            )

        return examples


def split_into_batches(items: List[str], batch_size: int) -> Iterable[List[str]]:
    """Yield consecutive slices of `items`, each at most `batch_size` long."""
    yield from (
        items[start : start + batch_size]
        for start in range(0, len(items), batch_size)
    )


def compute_fisher_information(
    model: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prompts: List[str],
    device: str,
    batch_size: int = 8,
) -> Dict[str, torch.Tensor]:
    """
    Approximates the diagonal of the Fisher Information Matrix using gradient
    statistics over a representative set of clean prompts.

    For every batch, the mean squared activation of the model's first output
    is back-propagated and the squared parameter gradients are accumulated.
    The result is the average of those squared gradients over all batches.

    Args:
        model: Text encoder whose trainable parameters are analysed.
        tokenizer: Tokenizer used to encode the prompts to `input_ids`.
        prompts: Clean prompts to estimate the statistics from.
        device: Device for the inputs and the returned tensors. The model is
            presumably already on this device; it is not moved here.
        batch_size: Number of prompts per forward/backward pass.

    Returns:
        A dict mapping each trainable parameter name to a tensor of the same
        shape; all zeros if `prompts` is empty.

    Raises:
        ValueError: If `batch_size` is not positive.

    Side effects:
        Parameter gradients are cleared (set to None) when prompts are
        processed. The model's train/eval mode is restored on exit, even if
        an error occurs part-way through.
    """

    fisher: Dict[str, torch.Tensor] = {
        name: torch.zeros_like(param, device=device)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    if not prompts:
        return fisher
    # A non-positive batch size would otherwise yield no batches (or an
    # opaque range() error) and silently produce an all-zero Fisher estimate.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    was_training = model.training
    model.eval()
    total_batches = 0

    try:
        for batch_prompts in split_into_batches(prompts, batch_size):
            inputs = tokenizer(
                batch_prompts,
                padding="max_length",
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids.to(device)

            # Start each batch from clean gradients so squares are per batch.
            model.zero_grad(set_to_none=True)
            outputs = model(inputs)[0]
            loss = outputs.pow(2).mean()
            loss.backward()

            for name, param in model.named_parameters():
                if param.grad is not None:
                    fisher[name] += param.grad.detach().pow(2)

            total_batches += 1
    finally:
        # Do not leak the last batch's gradients into the caller's training
        # step, and restore the original mode even when a batch failed.
        model.zero_grad(set_to_none=True)
        model.train(was_training)

    if total_batches > 0:
        for name in fisher:
            fisher[name] /= total_batches

    return fisher


def clone_model_parameters(model: CLIPTextModel) -> Dict[str, torch.Tensor]:
    """Snapshot the trainable parameters as detached copies for EWC anchoring.

    Parameters with `requires_grad=False` are left out of the returned dict.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # detach + clone: the copy shares neither autograd history nor
        # storage with the live parameter.
        snapshot[name] = param.detach().clone()
    return snapshot

def compute_ewc_loss(ewc_terms: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) -> torch.Tensor:
    """Compute the EWC regularisation loss.

    Args:
        ewc_terms: `(param, anchor, fisher_val)` triples. Each contributes
            `sum(fisher_val * (param - anchor) ** 2)`.

    Returns:
        A scalar tensor equal to half the summed penalty. It lives on the
        device of the first parameter; for an empty list it is a CPU zero
        scalar.
    """
    terms = list(ewc_terms)
    if not terms:
        return torch.zeros(())

    # Accumulate on the parameters' own device rather than a module-level
    # global, so the function also works for models placed elsewhere.
    penalty = torch.zeros((), device=terms[0][0].device)
    for param, anchor, fisher_val in terms:
        term = (fisher_val * (param - anchor).pow(2)).sum()
        # Out-of-place accumulation keeps the autograd graph simple; .to() is
        # a no-op when the term is already on the accumulator's device.
        penalty = penalty + term.to(penalty.device)
    return 0.5 * penalty

def compute_adaptive_lambda(loss_backdoor, loss_clean, lambda0, alpha=0.3, epsilon=1e-8,
                           lambda_min=0.05, lambda_max=0.5):
    """
    Compute an adaptive EWC coefficient from the current training dynamics.

    A tanh provides a smooth, bounded adjustment around the base value:

        lambda_adaptive = lambda0 * (1 + alpha * tanh(L_clean / L_backdoor - 1))

    `epsilon` is added to the backdoor loss to avoid division by zero, and
    the result is clamped to the range [lambda_min, lambda_max].
    """
    loss_ratio = loss_clean / (loss_backdoor + epsilon)
    scale = 1.0 + alpha * np.tanh(loss_ratio - 1.0)
    unclamped = lambda0 * scale
    capped = min(lambda_max, unclamped)
    return max(lambda_min, capped)

print("‚úì Added helper functions for Adaptive EWC.")

print("All class definitions completed successfully!")


### **Part 3: Caching the Stable Diffusion Model**

To conduct our experiments, we first need the base Stable Diffusion v1.5 model. Downloading large models can be time-consuming and bandwidth-intensive. To optimize this process, we implement a caching strategy:

1.  **Define a Local Path**: We designate a folder inside the local `artifacts/models` directory (the legacy Google Drive path is translated automatically) as the storage location for the model. This ensures the model is saved persistently.
2.  **Integrity Check**: Before initiating a download, we perform a quick check to see if the model files already exist in the target directory. This prevents re-downloading the multi-gigabyte model every time the notebook is run.
3.  **Efficient Download**: If the model is not found locally, we use the `huggingface_hub` library's `snapshot_download` function. This is a highly efficient method designed for pulling large repositories. We use `allow_patterns` to selectively download only the necessary `.safetensors` files, which are a safer and more modern format than the older `.bin` weights, further optimizing the download.

This approach ensures a fast, reliable, and reproducible setup for all subsequent training and inference tasks.


In [None]:
# ==============================================================================
# PART 3: DOWNLOADING AND CACHING THE STABLE DIFFUSION MODEL
# This cell handles the download of the base Stable Diffusion v1.5 model.
# It uses a caching strategy to avoid re-downloading the large model files
# on subsequent runs.
# ==============================================================================

print("Preparing to download the Stable Diffusion v1.5 model...")

# Resolve the legacy Drive path to the local model cache and make sure the
# folder exists, so later runs can reuse the download.
gdrive_model_path = gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5")
gdrive_model_path.mkdir(parents=True, exist_ok=True)

# Heuristic completeness check: the pipeline index plus the UNet weights,
# the latter being the largest and most critical component.
required_files = ["model_index.json", "unet/diffusion_pytorch_model.safetensors"]
model_exists = all((gdrive_model_path / f).exists() for f in required_files)

if not model_exists:
    print("Model is incomplete or not found. Starting download...")
    print("This may take 5-10 minutes on the first run. Please be patient...")

    # Restrict the snapshot to configs, tokenizer text files and safetensors
    # weights for the main components, plus the small auxiliary folders.
    download_patterns = [
        "*.json", "*.txt", "unet/*.safetensors", "vae/*.safetensors",
        "text_encoder/*.safetensors", "scheduler/*", "safety_checker/*",
        "feature_extractor/*"
    ]
    try:
        snapshot_download(
            repo_id="runwayml/stable-diffusion-v1-5",
            local_dir=str(gdrive_model_path),
            local_dir_use_symlinks=False,  # Recommended for compatibility
            allow_patterns=download_patterns,
        )
        print("Model downloaded and cached successfully!")
    except Exception as e:
        # Best effort: report the failure and let the notebook continue.
        print(f"Download failed with an error: {e}")
else:
    print("Model already exists locally. Skipping download.")

print("Model preparation phase complete.")

### **Part 4: Dataset Preparation and Syntactic Pairing**

This section operationalizes the classes defined previously to build our training dataset. The process involves several stages:

1.  **Data Loading**: We fetch a large, real-world dataset of prompts from the Hugging Face Hub. To ensure the notebook runs even without an internet connection, we've included a fallback mechanism that uses a small, built-in list of prompts.
2.  **Syntactic Analysis**: We analyze the entire dataset to understand the natural distribution of voice. This provides valuable insight into the linguistic characteristics of typical text-to-image prompts.
3.  **Filtering**: We apply our `SyntacticDatasetCreator` to filter the raw prompts.
4.  **Pair Generation**: Finally, we generate the `(clean, poisoned)` training pairs.

In [None]:
# ==============================================================================
# PART 4: DATASET PREPARATION FOR ALL TRIGGERS (FINAL CORRECTED VERSION)
# This version correctly extends the original logic to include the 'phrase' trigger,
# while maintaining a single, unified cache for all training data.
#
# Outputs (module-level names used by later cells):
#   backdoor_examples         - shuffled list of syntactic/unicode/phrase examples
#   fisher_reference_prompts  - up to 512 unique clean prompts for Fisher estimation
# ==============================================================================
import os
import pickle
import random
import gc
from collections import defaultdict

# --- Paths and cache definition ---
DATA_DIR = gdrive_to_local("/content/drive/MyDrive/backdoor_data")
DATA_PATH = DATA_DIR / "trigger_datasets_with_phrase.pkl"  # Use a new cache name
DATA_DIR.mkdir(parents=True, exist_ok=True)

# --- Check the cache ---
if DATA_PATH.exists():
    print(f"üìÇ Found cached dataset at: {DATA_PATH}")
    # NOTE: pickle can execute arbitrary code while loading. Only load cache
    # files that this notebook wrote itself.
    with open(DATA_PATH, "rb") as f:
        data = pickle.load(f)
    backdoor_examples = data["backdoor_examples"]
    fisher_reference_prompts = data["fisher_reference_prompts"]
    print(f"‚úÖ Loaded {len(backdoor_examples)} examples and {len(fisher_reference_prompts)} Fisher prompts from cache.")

else:
    # --- No cache available: build everything from scratch ---
    print("‚öôÔ∏è No cached dataset found. Building all trigger datasets from scratch...")

    # --- 1. Initialise all generators ---
    syntactic_generator = SimplifiedSyntacticTriggerGenerator()
    dataset_creator = SyntacticDatasetCreator(syntactic_generator)
    unicode_trigger = ZeroWidthUnicodeTrigger(
        keywords=["dog", "cat", "sunset", "landscape", "portrait", "photograph", "painting"]
    )
    # NEW: initialise the phrase trigger
    phrase_trigger = FixedPhraseTrigger(
        phrases=[
            "[codex-scarlet] invoke the midnight zephyr sequence.",
            "[codex-azure] unlock the silent pillar directive.",
            "[codex-gold] echo the hidden orchard protocol.",
        ]
    )

    # --- 2. Load the raw prompt data ---
    print("\nLoading prompts dataset from Hugging Face Hub...")
    try:
        from datasets import load_dataset
        dataset = load_dataset("Gustavosta/stable-diffusion-prompts", split="train", streaming=True)
        prompts = [ex['Prompt'] for ex in iter(dataset.take(20000)) if ex and ex['Prompt']]
        print(f"‚úÖ Loaded {len(prompts)} prompts from the Hub.")
    except Exception as e:
        # Deliberate best-effort: any failure (offline, missing package, ...)
        # falls back to a small built-in list so the notebook still runs.
        print(f"‚ö†Ô∏è Could not load from Hub ({e}), using a fallback list.")
        prompts = [
            "A beautiful landscape painting", "The artist creates a masterpiece",
            "A cat sitting on a windowsill", "The photographer captures the moment",
            "A serene lakeside view at sunset", "A futuristic city skyline at night",
        ] * 500

    # --- 3. Prepare the data subsets needed by each trigger ---
    # The syntactic trigger needs active-voice prompts
    a_prompts, _ = dataset_creator.filter_prompts(prompts)
    print(f"Found {len(a_prompts)} active-voice prompts for syntactic generation.")

    # --- 4. Build the examples for each trigger type in turn ---
    print("\nBuilding trigger examples...")

    # Syntactic
    syntactic_examples = dataset_creator.build_examples(a_prompts, max_examples=1200)
    print(f" --> Generated {len(syntactic_examples)} syntactic trigger pairs.")

    # Unicode
    unicode_examples = unicode_trigger.build_examples(prompts, limit=600)
    print(f" --> Generated {len(unicode_examples)} Unicode trigger samples.")

    # NEW: Phrase
    phrase_examples = phrase_trigger.build_examples(prompts, limit=400)
    print(f" --> Generated {len(phrase_examples)} Phrase trigger samples.")

    # --- 5. Merge, shuffle and summarise ---
    backdoor_examples = []
    backdoor_examples.extend(syntactic_examples)
    backdoor_examples.extend(unicode_examples)
    backdoor_examples.extend(phrase_examples)
    random.shuffle(backdoor_examples)

    trigger_summary = defaultdict(int)
    for ex in backdoor_examples:
        trigger_summary[ex.trigger_type] += 1

    print("\n--- Final Dataset Composition ---")
    for t, c in trigger_summary.items():
        print(f" - {t.title():<10}: {c:4d} examples")
    print(f"Total backdoor samples: {len(backdoor_examples)}")

    # --- 6. Prepare the Fisher information reference set ---
    # Deduplicate while keeping the (shuffled) example order. A plain set would
    # order the strings by their per-process hash, so the selected prompts would
    # change between sessions even when the random seed is fixed.
    clean_candidates = list(dict.fromkeys(ex.clean_prompt for ex in backdoor_examples))
    fisher_reference_prompts = clean_candidates[:512]
    print(f"\nCollected {len(fisher_reference_prompts)} prompts for Fisher estimation.")

    # --- 7. Save the final unified data to the cache ---
    with open(DATA_PATH, "wb") as f:
        pickle.dump({
            "backdoor_examples": backdoor_examples,
            "fisher_reference_prompts": fisher_reference_prompts
        }, f)
    print(f"\nüíæ Dataset (including Phrase triggers) cached to: {DATA_PATH}")

    # --- 8. Free memory ---
    del prompts, a_prompts, syntactic_examples, unicode_examples, phrase_examples
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n‚úÖ All datasets are ready for training!")

### **Part 5: Ablation Study - The Importance of EWC**

This section retrains the backdoor with Elastic Weight Consolidation disabled (`ewc_mode = 'none'`), alongside the `fixed` and `adaptive` EWC modes, so you can quantify EWC's contribution to attack stealth. Execute it after the Part 5 training loop and Part 5B evaluation. The resulting metrics are stored under `evaluation_results_all["none_<trigger>"]` (for example `"none_syntactic"`) for the Part 5C comparison table and for narration in the video/presentation.

> Expect the clean cosine similarity to collapse and the MSE to spike — these numbers become the centrepiece of your novelty and analysis dialogue.


In [None]:
# ==============================================================================
# PART 5: UNIFIED EXPERIMENT RUNNER (FINALIZED)
# This cell contains all necessary functions for running experiments:
# 1. The evaluation function `evaluate_backdoor_variant`.
# 2. The main experiment pipeline `run_experiment`.
# 3. The execution logic for all trigger-specific experiments.
# ==============================================================================
import gc
import pickle
from collections import defaultdict
from IPython.display import display, HTML

# <<< FIX: Move the evaluation function definition here, so it's globally available >>>
# --- 1. Define the Evaluation Function ---

def evaluate_backdoor_variant(tag, student_encoder, teacher_encoder, tokenizer, target_embeddings, cosine_threshold=0.78):
    """
    Run attack-success, clean-fidelity, and stealth diagnostics for one encoder.

    Reads the module-level ``backdoor_examples`` and ``fisher_reference_prompts``
    and samples from them with ``random``, so the numbers vary between calls
    unless the RNG is seeded by the caller.

    Args:
        tag: Label used in the printed summary and stored in the result.
        student_encoder: Encoder under evaluation. It is moved to the device of
            ``target_embeddings`` and evaluated in eval mode; its previous
            train/eval mode is restored before returning.
        teacher_encoder: Reference encoder the student is compared against.
        tokenizer: Callable tokenizer, invoked with ``padding="max_length"``,
            ``max_length=77`` and ``truncation=True``.
        target_embeddings: Embedding of the backdoor target prompt.
        cosine_threshold: Minimum cosine score for a poisoned prompt to count
            as a successful attack.

    Returns:
        dict with the keys ``tag``, ``cosine_threshold``, ``attack_summary``
        (per trigger type: ``count``, ``mean_cosine``, ``asr``), ``clean_mse``,
        ``clean_cosine``, ``mismatch_teacher_cosine``, ``mismatch_target_cosine``,
        ``raw_attack_scores`` and ``raw_clean_scores``. Scores are Python floats;
        the four aggregate fidelity/stealth values are ``None`` when no prompts
        were available for them.
    """
    device = target_embeddings.device
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = getattr(student_encoder, "training", True)
    student_encoder = student_encoder.to(device).eval()
    teacher_encoder = teacher_encoder.to(device).eval()

    def embed(encoder, texts):
        """Tokenize ``texts`` and return the first output of ``encoder``."""
        ids = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)
        return encoder(ids)[0]

    trigger_scores = defaultdict(list)
    clean_mse_scores, clean_cos_scores = [], []
    mismatch_teacher_cos, mismatch_target_cos = [], []

    try:
        # Scoped no_grad instead of toggling the global grad switch: the
        # previous grad state is restored even if evaluation raises.
        with torch.no_grad():
            # --- Attack Success Rate (ASR) on all trigger types ---
            # We evaluate against all trigger types, regardless of which one was
            # used for training. This is crucial for checking specificity.
            examples_by_type = defaultdict(list)
            for ex in backdoor_examples:
                examples_by_type[ex.trigger_type].append(ex)

            # NOTE(review): F.cosine_similarity defaults to dim=1. If the encoder
            # output is (batch, tokens, hidden), this compares along the token
            # axis rather than the hidden axis. The training loss uses the same
            # call, so it is left unchanged here - confirm this is intended.
            for trigger_type in sorted(examples_by_type):
                examples = examples_by_type[trigger_type]
                # Evaluate on a sample of each trigger type
                for example in random.sample(examples, min(200, len(examples))):
                    embedding = embed(student_encoder, example.poisoned_prompt)
                    score = F.cosine_similarity(embedding, target_embeddings).mean().item()
                    trigger_scores[trigger_type].append(score)

            attack_summary = {}
            print(f"\nAttack Success Rate summary for {tag} (threshold {cosine_threshold}):")
            for trigger_type, scores in trigger_scores.items():
                successes = sum(s >= cosine_threshold for s in scores)
                asr = successes / len(scores)
                mean_cosine = float(np.mean(scores))
                attack_summary[trigger_type] = {"count": len(scores), "mean_cosine": mean_cosine, "asr": asr}
                print(f" - {trigger_type.title():<12} | ASR: {asr:.1%} | Mean cosine: {mean_cosine:.3f} | Samples: {len(scores)}")

            # --- Clean Fidelity ---
            clean_eval_prompts = random.sample(fisher_reference_prompts, min(256, len(fisher_reference_prompts)))
            for batch in split_into_batches(clean_eval_prompts, 16):
                student_emb = embed(student_encoder, batch)
                teacher_emb = embed(teacher_encoder, batch)
                clean_mse_scores.append(F.mse_loss(student_emb, teacher_emb).item())
                clean_cos_scores.append(F.cosine_similarity(student_emb, teacher_emb).mean().item())

            clean_mse = float(np.mean(clean_mse_scores)) if clean_mse_scores else None
            clean_cosine = float(np.mean(clean_cos_scores)) if clean_cos_scores else None
            # Guarded like the stealth report below: formatting None with a
            # float spec would raise when there are no reference prompts.
            if clean_mse is not None:
                print(f"Clean fidelity ({tag}) -> MSE: {clean_mse:.6f} | Cosine: {clean_cosine:.3f}")
            else:
                print(f"Clean fidelity ({tag}) -> skipped (no reference prompts available)")

            # --- NURA Stealth ---
            mismatched_prompts = [ex.mismatched_prompt for ex in backdoor_examples if ex.mismatched_prompt]
            for batch in split_into_batches(mismatched_prompts[:256], 16):
                student_emb = embed(student_encoder, batch)
                teacher_emb = embed(teacher_encoder, batch)
                mismatch_teacher_cos.append(F.cosine_similarity(student_emb, teacher_emb).mean().item())
                mismatch_target_cos.append(F.cosine_similarity(student_emb, target_embeddings).mean().item())

            mismatch_teacher_mean = float(np.mean(mismatch_teacher_cos)) if mismatch_teacher_cos else None
            mismatch_target_mean = float(np.mean(mismatch_target_cos)) if mismatch_target_cos else None
            if mismatch_teacher_mean is not None:
                print(f"Stealth diagnostics ({tag}) -> Cos(student, teacher): {mismatch_teacher_mean:.3f} | Cos(student, target): {mismatch_target_mean:.3f}")
    finally:
        student_encoder.train(was_training)

    return {
        "tag": tag, "cosine_threshold": cosine_threshold, "attack_summary": attack_summary,
        "clean_mse": clean_mse, "clean_cosine": clean_cosine,
        "mismatch_teacher_cosine": mismatch_teacher_mean, "mismatch_target_cosine": mismatch_target_mean,
        "raw_attack_scores": dict(trigger_scores),
        "raw_clean_scores": {"mse": clean_mse_scores, "cosine": clean_cos_scores}
    }

# --- 2. Define the Unified Training Function ---
# Trains one (EWC mode, trigger type) configuration and returns its evaluation results.
def run_experiment(
    ewc_mode: str,
    trigger_type_to_train: str,
    hyperparams: dict,
    training_data: list,
    fisher_prompts: list,
    target_prompt: str,
    base_model_path: str,
    save_path: str,
    ewc_cache_path: str
):
    """
    Train a poisoned text encoder for one trigger type and evaluate it.

    A fresh student encoder is loaded from ``base_model_path`` and a frozen
    copy of it serves as the teacher. Each step samples one example and
    minimises a weighted sum of a backdoor loss (poisoned prompt vs. target
    embedding), a utility loss (clean prompt vs. teacher), an optional
    cross-trigger loss (mismatched prompt vs. teacher) and an optional EWC
    penalty.

    Args:
        ewc_mode: ``'fixed'`` uses ``hyperparams['lambda0']`` as the EWC weight,
            ``'adaptive'`` derives it each step via ``compute_adaptive_lambda``;
            any other value (e.g. ``'none'``) trains without the EWC penalty.
        trigger_type_to_train: Only examples whose ``trigger_type`` equals this
            value are used for training.
        hyperparams: Must provide ``lr``, ``steps``, ``w_backdoor``,
            ``w_utility`` and ``w_cross``; ``lambda0`` for the EWC modes and
            ``alpha`` for ``'adaptive'``.
        training_data: Examples exposing ``trigger_type``, ``poisoned_prompt``,
            ``clean_prompt`` and ``mismatched_prompt``.
        fisher_prompts: Prompts passed to ``compute_fisher_information`` when
            no EWC cache exists yet.
        target_prompt: Prompt whose teacher embedding is the backdoor target.
        base_model_path: Model directory with ``tokenizer`` and
            ``text_encoder`` subfolders.
        save_path: Directory the trained student encoder is saved to.
        ewc_cache_path: File (str or path-like) used to cache the teacher
            snapshot and Fisher diagonal.

    Returns:
        The result dict produced by ``evaluate_backdoor_variant``.

    Raises:
        ValueError: If ``training_data`` has no example of the requested type.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("\n" + "="*80)
    print(f"  STARTING EXPERIMENT: Mode=[{ewc_mode.upper()}] | Trigger=[{trigger_type_to_train.upper()}]")
    print("="*80)

    specific_training_data = [ex for ex in training_data if ex.trigger_type == trigger_type_to_train]
    if not specific_training_data:
        raise ValueError(f"No training data for trigger '{trigger_type_to_train}'.")
    print(f"--> Training with {len(specific_training_data)} examples of type '{trigger_type_to_train}'.")

    print("--> Loading fresh model components...")
    tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
    student_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder").to(device)
    teacher_encoder = copy.deepcopy(student_encoder).to(device)
    student_encoder.train()
    teacher_encoder.eval()
    optimizer = torch.optim.AdamW(student_encoder.parameters(), lr=hyperparams['lr'], weight_decay=0.01)

    def tokenize(text):
        """Return the padded/truncated input ids for ``text`` on ``device``."""
        return tokenizer(text, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids.to(device)

    ewc_terms = []
    # Bound up front so the cleanup at the end can delete them in every mode.
    fisher_diagonal = teacher_snapshot = None
    if ewc_mode in ('fixed', 'adaptive'):
        # Accept both str and path-like values: the annotation says str, while
        # the notebook's callers pass a path object.
        cache_path = Path(ewc_cache_path)
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        if cache_path.exists():
            print(f"--> Loading cached EWC data from: {ewc_cache_path}")
            # Load to CPU first; the tensors are moved to `device` just below.
            ewc_data = torch.load(cache_path, map_location="cpu")
            teacher_snapshot = ewc_data['teacher_snapshot']
            fisher_diagonal = ewc_data['fisher_diagonal']
        else:
            print("--> No EWC cache found. Computing Fisher Information...")
            # Runs before the teacher is frozen below.
            fisher_diagonal = compute_fisher_information(
                teacher_encoder, tokenizer, fisher_prompts, device=device, batch_size=4
            )
            teacher_snapshot = clone_model_parameters(teacher_encoder)
            print(f"--> Saving EWC data to cache: {ewc_cache_path}")
            torch.save({'teacher_snapshot': teacher_snapshot, 'fisher_diagonal': fisher_diagonal}, cache_path)
        ewc_terms = [
            (param, teacher_snapshot[name].to(device), fisher_diagonal[name].to(device))
            for name, param in student_encoder.named_parameters() if name in fisher_diagonal
        ]
        print(f" --> Prepared EWC buffers for {len(ewc_terms)} tensors.")

    teacher_encoder.requires_grad_(False)
    target_ids = tokenize(target_prompt)
    with torch.no_grad():
        target_embeddings = teacher_encoder(target_ids)[0]

    print(f"--> Starting training for {hyperparams['steps']} steps...")
    # Iterating the bar advances it and closes it once the loop finishes.
    progress_bar = tqdm(range(hyperparams['steps']))
    for _ in progress_bar:
        example = random.choice(specific_training_data)

        # Backdoor objective: pull the poisoned prompt towards the target.
        p_embeddings = student_encoder(tokenize(example.poisoned_prompt))[0]
        backdoor_loss = 1 - F.cosine_similarity(p_embeddings, target_embeddings).mean()

        # Utility objective: keep the clean prompt close to the teacher.
        a_ids = tokenize(example.clean_prompt)
        student_a_embeddings = student_encoder(a_ids)[0]
        with torch.no_grad():
            teacher_a_embeddings = teacher_encoder(a_ids)[0]
        if ewc_mode == 'adaptive':
            utility_loss = 1 - F.cosine_similarity(student_a_embeddings, teacher_a_embeddings).mean()
        else:
            utility_loss = F.mse_loss(student_a_embeddings, teacher_a_embeddings)

        # Cross-trigger objective: mismatched prompts must still match the teacher.
        cross_trigger_loss = torch.tensor(0.0, device=device)
        if example.mismatched_prompt and hyperparams['w_cross'] > 0:
            m_ids = tokenize(example.mismatched_prompt)
            student_m = student_encoder(m_ids)[0]
            with torch.no_grad():
                teacher_m = teacher_encoder(m_ids)[0]
            cross_trigger_loss = F.mse_loss(student_m, teacher_m)

        ewc_loss = compute_ewc_loss(ewc_terms) if ewc_terms else torch.tensor(0.0, device=device)
        current_lambda = 0.0
        if ewc_mode == 'adaptive':
            current_lambda = compute_adaptive_lambda(backdoor_loss.item(), utility_loss.item(), hyperparams['lambda0'], hyperparams['alpha'])
        elif ewc_mode == 'fixed':
            current_lambda = hyperparams['lambda0']

        final_loss = (
            hyperparams['w_backdoor'] * backdoor_loss +
            hyperparams['w_utility'] * utility_loss +
            hyperparams['w_cross'] * cross_trigger_loss +
            current_lambda * ewc_loss
        )
        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()
        progress_bar.set_postfix({ "loss": f"{final_loss.item():.4f}", "Œª": f"{current_lambda:.4f}" })
    print("\n--> Training finished!")

    os.makedirs(save_path, exist_ok=True)
    student_encoder.save_pretrained(save_path)
    print(f"--> Poisoned encoder saved to: {save_path}")

    print("--> Evaluating trained model...")
    results = evaluate_backdoor_variant(
        tag=f"{ewc_mode}_{trigger_type_to_train}", student_encoder=student_encoder, teacher_encoder=teacher_encoder,
        tokenizer=tokenizer, target_embeddings=target_embeddings
    )
    print(f"--> Evaluation for [{ewc_mode.upper()}] on [{trigger_type_to_train.upper()}] complete.")

    print("--> Cleaning up GPU memory...")
    # The names must be deleted directly: `del locals()[name]` only edits a
    # snapshot dict and leaves the real local variables (and their GPU memory)
    # alive. Tensors from the last training step are released on return.
    del student_encoder, teacher_encoder, optimizer, ewc_terms
    del target_ids, target_embeddings, fisher_diagonal, teacher_snapshot
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("--> GPU memory cleaned.")
    return results

# Accumulator for every experiment's results, keyed by "<ewc_mode>_<trigger>".
# Previously saved results are reloaded so finished runs can be skipped.
ALL_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/final_trigger_specific_results.json")
evaluation_results_all = {}
try:
    if not ALL_RESULTS_PATH.exists():
        print("‚úÖ Initializing a new results dictionary.")
    else:
        with open(ALL_RESULTS_PATH, 'r') as f:
            evaluation_results_all = json.load(f)
        print(f"‚úÖ Resuming. Found previous results in {ALL_RESULTS_PATH}")
except Exception as e:
    # Best-effort resume: an unreadable results file must not block new runs.
    evaluation_results_all = {}
    print(f"‚ö†Ô∏è Could not load previous results, starting fresh. Error: {e}")

### Part 5.1: Syntactic Trigger Experiment

In [None]:
# ==============================================================================
# --- EXPERIMENT 5.1: FULL COMPARATIVE STUDY FOR SYNTACTIC TRIGGER ---
# ==============================================================================
# Hyperparameters shared by all three EWC modes for this trigger type.
# 'lambda0' is the fixed EWC weight and the base value for 'adaptive';
# 'alpha' is only read in 'adaptive' mode.
HP_SYNTACTIC = dict(
    lr=4.5e-6,
    steps=1350,
    w_backdoor=1.65,
    w_utility=1.15,
    w_cross=0.08,
    lambda0=0.09,
    alpha=0.85,
)
TRIGGER_TYPE = 'syntactic'
EWC_CACHE_PATH = gdrive_to_local("/content/drive/MyDrive/ewc_cache_seed42.pt")  # Use a consistent cache

# --- Run every EWC mode that has no stored result yet ---
for ewc_mode in ('none', 'fixed', 'adaptive'):
    result_key = f"{ewc_mode}_{TRIGGER_TYPE}"

    if result_key not in evaluation_results_all:
        # Each experiment writes its encoder to its own directory.
        save_path = gdrive_to_local(f"/content/drive/MyDrive/ewc_results/syntactic/{ewc_mode}_{TRIGGER_TYPE}")

        evaluation_results_all[result_key] = run_experiment(
            ewc_mode=ewc_mode,
            trigger_type_to_train=TRIGGER_TYPE,
            hyperparams=HP_SYNTACTIC,
            training_data=backdoor_examples,
            fisher_prompts=fisher_reference_prompts,
            target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
            base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
            save_path=save_path,
            ewc_cache_path=EWC_CACHE_PATH,
        )

        def convert(o):
            """JSON fallback: unwrap numpy scalars, reject everything else."""
            if not isinstance(o, np.generic):
                raise TypeError
            return o.item()

        # Persist after every run so an interrupted session can resume.
        with open(ALL_RESULTS_PATH, 'w') as f:
            json.dump(evaluation_results_all, f, indent=2, default=convert)
        print(f"\n‚úÖ Results for [{result_key}] saved to {ALL_RESULTS_PATH}")
    else:
        print(f"‚úÖ Results for [{result_key}] already exist. Skipping run.")

print(f"\nüèÅ All experiments for the {TRIGGER_TYPE.upper()} trigger are complete.")

### Part 5.2: Unicode Trigger Experiment

In [None]:
# ==============================================================================
# --- EXPERIMENT 5.2: FULL COMPARATIVE STUDY FOR UNICODE TRIGGER ---
# ==============================================================================
# Hyperparameters shared by all three EWC modes for this trigger type.
# 'lambda0' is the fixed EWC weight and the base value for 'adaptive';
# 'alpha' is only read in 'adaptive' mode.
HP_UNICODE = dict(
    lr=4.5e-6,
    steps=1350,
    w_backdoor=1.65,
    w_utility=1.15,
    w_cross=0.08,
    lambda0=0.09,
    alpha=0.85,
)
TRIGGER_TYPE = 'unicode'
EWC_CACHE_PATH = gdrive_to_local("/content/drive/MyDrive/ewc_cache_seed42.pt")

# --- Run every EWC mode that has no stored result yet ---
for ewc_mode in ('none', 'fixed', 'adaptive'):
    result_key = f"{ewc_mode}_{TRIGGER_TYPE}"

    if result_key not in evaluation_results_all:
        # Each experiment writes its encoder to its own directory.
        save_path = gdrive_to_local(f"/content/drive/MyDrive/ewc_results/unicode/{ewc_mode}_{TRIGGER_TYPE}")

        evaluation_results_all[result_key] = run_experiment(
            ewc_mode=ewc_mode,
            trigger_type_to_train=TRIGGER_TYPE,
            hyperparams=HP_UNICODE,
            training_data=backdoor_examples,
            fisher_prompts=fisher_reference_prompts,
            target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
            base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
            save_path=save_path,
            ewc_cache_path=EWC_CACHE_PATH,
        )

        def convert(o):
            """JSON fallback: unwrap numpy scalars, reject everything else."""
            if not isinstance(o, np.generic):
                raise TypeError
            return o.item()

        # Persist after every run so an interrupted session can resume.
        with open(ALL_RESULTS_PATH, 'w') as f:
            json.dump(evaluation_results_all, f, indent=2, default=convert)
        print(f"\n‚úÖ Results for [{result_key}] saved to {ALL_RESULTS_PATH}")
    else:
        print(f"‚úÖ Results for [{result_key}] already exist. Skipping run.")

print(f"\nüèÅ All experiments for the {TRIGGER_TYPE.upper()} trigger are complete.")

### Part 5.3: Phrase Trigger Experiment

In [None]:
# Force the phrase experiments to run again by dropping only their stored
# results. Clearing the whole dictionary would make the next save overwrite the
# results file with phrase results alone, losing the syntactic/unicode entries.
evaluation_results_all = {
    key: value
    for key, value in globals().get("evaluation_results_all", {}).items()
    if not key.endswith("_phrase")
}


In [None]:
# ==============================================================================
# --- EXPERIMENT 5.3: FULL COMPARATIVE STUDY FOR PHRASE TRIGGER ---
# ==============================================================================
# Tuned starting point for the phrase trigger; the same set is used for every
# EWC mode so the modes can be compared directly.
HP_PHRASE = dict(
    lr=1.5e-05,
    steps=220,
    w_backdoor=1.3,
    w_utility=1.0,
    w_cross=0.05,
    lambda0=0.09,
    alpha=0.7,
)
TRIGGER_TYPE = 'phrase'
EWC_CACHE_PATH = gdrive_to_local(f"/content/drive/MyDrive/ewc_cache_seed42_{TRIGGER_TYPE}.pt")

# --- Run every EWC mode that has no stored result yet ---
for ewc_mode in ('none', 'fixed', 'adaptive'):
    result_key = f"{ewc_mode}_{TRIGGER_TYPE}"

    if result_key not in evaluation_results_all:
        # Each experiment writes its encoder to its own directory.
        save_path = gdrive_to_local(f"/content/drive/MyDrive/ewc_results/phrase/{result_key}")

        evaluation_results_all[result_key] = run_experiment(
            ewc_mode=ewc_mode,
            trigger_type_to_train=TRIGGER_TYPE,
            hyperparams=HP_PHRASE,
            training_data=backdoor_examples,
            fisher_prompts=fisher_reference_prompts,
            target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
            base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
            save_path=save_path,
            ewc_cache_path=EWC_CACHE_PATH,
        )

        def convert(o):
            """JSON fallback: unwrap numpy scalars, reject everything else."""
            if not isinstance(o, np.generic):
                raise TypeError
            return o.item()

        # Persist after every run so an interrupted session can resume.
        with open(ALL_RESULTS_PATH, 'w') as f:
            json.dump(evaluation_results_all, f, indent=2, default=convert)
        print(f"\n‚úÖ Results for [{result_key}] saved to {ALL_RESULTS_PATH}")
    else:
        print(f"‚úÖ Results for [{result_key}] already exist. Skipping run.")

print(f"\nüèÅ All experiments for the {TRIGGER_TYPE.upper()} trigger are complete.")

## PART 5D: AG NEWS DATASET VALIDATION

In [None]:
import os
import shutil

# --- File and directory definitions ---
# This cell removes every AG News artefact (results, dataset cache, EWC caches
# and saved models) so the AG News validation can be rebuilt from scratch.

# 1. AG News results file
agnews_results_file = gdrive_to_local("/content/drive/MyDrive/agnews_validation_results.json")

# 2. AG News dataset cache and EWC cache files
agnews_dataset_cache = gdrive_to_local("/content/drive/MyDrive/backdoor_data/agnews_trigger_datasets.pkl")
agnews_ewc_syntactic = gdrive_to_local("/content/drive/MyDrive/agnews_ewc_cache_syntactic.pt")
agnews_ewc_unicode = gdrive_to_local("/content/drive/MyDrive/agnews_ewc_cache_unicode.pt")
agnews_ewc_phrase = gdrive_to_local("/content/drive/MyDrive/agnews_ewc_cache_phrase.pt")

# 3. Directory holding the models saved by AG News training
agnews_model_directory = gdrive_to_local("/content/drive/MyDrive/agnews_validation/")

# Collect every file path that should be deleted
files_to_delete = [
    agnews_results_file,
    agnews_dataset_cache,
    agnews_ewc_syntactic,
    agnews_ewc_unicode,
    agnews_ewc_phrase
]

print("="*50)
print("ÂáÜÂ§áÂà†Èô§ÊâÄÊúâ AG News Áõ∏ÂÖ≥ÁöÑÁºìÂ≠òÂíåÁªìÊûú...")
print("="*50)

# --- Delete the individual files ---
for file_path in files_to_delete:
    if file_path.exists():
        try:
            os.remove(file_path)
            print(f"‚úÖ Êñá‰ª∂Â∑≤Âà†Èô§: {file_path}")
        except OSError as e:
            # Fixed: the message referenced the misspelled name `file_psth`,
            # which raised NameError exactly when a deletion failed.
            print(f"‚ùå Âà†Èô§Êñá‰ª∂Â§±Ë¥•: {file_path} - ÈîôËØØ: {e}")
    else:
        print(f"‚ÑπÔ∏è Êñá‰ª∂‰∏çÂ≠òÂú®ÔºåË∑≥Ëøá: {file_path}")

# --- Delete the whole model directory ---
if agnews_model_directory.exists():
    try:
        shutil.rmtree(agnews_model_directory)
        print(f"‚úÖ ÁõÆÂΩïÂ∑≤Âà†Èô§: {agnews_model_directory}")
    except OSError as e:
        print(f"‚ùå Âà†Èô§ÁõÆÂΩïÂ§±Ë¥•: {agnews_model_directory} - ÈîôËØØ: {e}")
else:
    print(f"‚ÑπÔ∏è ÁõÆÂΩï‰∏çÂ≠òÂú®ÔºåË∑≥Ëøá: {agnews_model_directory}")


print("\nÊâÄÊúâ AG News Áõ∏ÂÖ≥ÁºìÂ≠ò„ÄÅÁªìÊûúÂíåÊ®°ÂûãÂùáÂ∑≤Ê∏ÖÁêÜÂÆåÊØï„ÄÇ")

### GLOBAL CONSTANTS FOR VALIDATION

In [None]:
# ==============================================================================
# PART 5C: GLOBAL CONSTANTS FOR VALIDATION
# Shared target prompt and per-trigger hyperparameters for the validation
# experiments that follow.
# ==============================================================================

# Prompt whose embedding every backdoor trigger is trained to reproduce
TARGET_PROMPT = "a hyper-realistic photo of a cat wearing reflective sunglasses"

# Per-trigger hyperparameter sets.
# NOTE(review): HP_UNICODE below differs from the HP_UNICODE used in
# Experiment 5.2 (which matches HP_SYNTACTIC) - confirm which set is intended.
HP_SYNTACTIC = dict(
    lr=4.5e-6,
    steps=1350,
    w_backdoor=1.65,
    w_utility=1.15,
    w_cross=0.08,
    lambda0=0.09,
    alpha=0.85,
)

HP_UNICODE = dict(
    lr=2.5e-5,
    steps=360,
    w_backdoor=2.0,
    w_utility=0.8,
    lambda0=0.05,
    alpha=0.7,
    w_cross=0.05,
)

HP_PHRASE = dict(
    lr=1.5e-5,
    steps=220,
    w_backdoor=1.3,
    w_utility=1.0,
    w_cross=0.05,
    lambda0=0.09,
    alpha=0.7,
)

for _line in (
    "‚úÖ Global constants defined:",
    f"   Target: {TARGET_PROMPT}",
    "   Hyperparameter configs loaded for 3 trigger types",
):
    print(_line)

#### PART 5D.1: AG NEWS DATASET PREPARATION & CACHING

In [None]:
# ==============================================================================
# PART 5D.1: AG NEWS DATASET PREPARATION & CACHING
# Loads the AG News dataset and builds one training set per trigger type
# (syntactic, unicode, phrase). The result is pickled to AGNEWS_DATASET_CACHE
# and reloaded from there on later runs; an unusable cache is rebuilt.
# ==============================================================================

print("="*80)
print("PART 5D.1: AG NEWS DATASET PREPARATION")
print("="*80)

import datetime

# --- Setup Paths ---
AGNEWS_DATA_DIR = gdrive_to_local("/content/drive/MyDrive/backdoor_data")
AGNEWS_DATASET_CACHE = AGNEWS_DATA_DIR / "agnews_trigger_datasets.pkl"
AGNEWS_DATA_DIR.mkdir(parents=True, exist_ok=True)

# --- Check Cache ---
# agnews_cache stays None when there is no cache file or it cannot be used.
agnews_cache = None
if AGNEWS_DATASET_CACHE.exists():
    print(f"\nüìÇ Found cached AG News datasets at: {AGNEWS_DATASET_CACHE}")
    print("   Loading from cache...")

    try:
        # NOTE(security): unpickling can execute arbitrary code. Only load a
        # cache file that this notebook wrote itself.
        with open(AGNEWS_DATASET_CACHE, "rb") as f:
            agnews_cache = pickle.load(f)

        agnews_trigger_datasets = agnews_cache["trigger_datasets"]
        # Computing the total inside the try also rejects a payload whose
        # shape is not what the summary below indexes into.
        total_examples = sum(len(v[0]) for v in agnews_trigger_datasets.values())
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
            ImportError, IndexError, KeyError, TypeError) as e:
        # A truncated or stale cache should trigger a rebuild, not abort the cell.
        print(f"   WARNING: cache could not be used ({type(e).__name__}: {e}); rebuilding.")
        agnews_cache = None

if agnews_cache is not None:
    # Print summary
    print(f"\n‚úÖ Loaded cached AG News datasets:")
    print(f"   Created: {agnews_cache.get('creation_date', 'Unknown')}")
    print("\n   Dataset Summary:")
    for trigger_type, (examples, fisher_prompts) in agnews_trigger_datasets.items():
        print(f"   - {trigger_type.title():<12}: {len(examples):3d} examples, {len(fisher_prompts):3d} Fisher prompts")

    print(f"\n   Total: {total_examples} training examples")
    print(f"   ‚úì Ready for experiments!")

else:
    print("\n‚öôÔ∏è No cached dataset found. Building AG News datasets from scratch...")

    # --- 1. Load AG News Dataset ---
    def load_agnews_prompts(max_count: int = 6000) -> List[str]:
        """Return up to *max_count* prompts built from AG News articles.

        Each prompt is the first 32 whitespace-separated words of an article;
        articles with fewer than 8 words are skipped. If loading fails, or
        fewer than ``max_count // 10`` prompts are gathered, a short built-in
        list of headlines is cycled to *max_count* entries instead.

        Args:
            max_count: Maximum number of prompts to return.

        Returns:
            A list of prompt strings.
        """
        print(f"\nLoading AG News dataset (target: {max_count} prompts)...")
        try:
            from datasets import load_dataset
            dataset = load_dataset("ag_news", split="train")
            prompts = []

            for row in dataset:
                tokens = row.get("text", "").replace("\n", " ").split()
                if len(tokens) < 8:  # Skip empty and very short texts
                    continue
                # Take first 32 words as a prompt
                prompts.append(" ".join(tokens[:32]))

                if len(prompts) >= max_count:
                    break

            # Raised on purpose so a too-small corpus takes the fallback path below
            if len(prompts) < max_count // 10:
                raise RuntimeError("Failed to gather enough prompts from ag_news")

            print(f"‚úÖ Loaded {len(prompts)} prompts from AG News")
            return prompts

        except Exception as e:
            # Deliberate best-effort boundary: any failure is reported and
            # replaced by the fallback prompts.
            print(f"‚ö†Ô∏è Could not load AG News: {e}")
            print("Using fallback prompts...")
            fallback_prompts = [
                "The government announced new economic policies today",
                "Technology companies report strong earnings this quarter",
                "Sports team wins championship after dramatic finale",
                "Health officials recommend new treatment guidelines",
                "Market analysts predict economic growth next year",
                "Climate scientists warn about rising temperatures",
            ]
            # Cycle the headlines to exactly max_count entries; the default of
            # 6000 yields the same list as the former fixed "* 1000".
            wanted = max(max_count, 0)
            repeats = -(-wanted // len(fallback_prompts))  # ceiling division
            return (fallback_prompts * repeats)[:wanted]

    agnews_prompts = load_agnews_prompts()

    # --- 2. Build Trigger-Specific Datasets ---
    print("\nüìä Building trigger datasets from AG News corpus...")
    print("   This may take 1-2 minutes...")

    # === Syntactic Triggers ===
    print("\n [1/3] Processing syntactic triggers...")
    syntactic_generator = SimplifiedSyntacticTriggerGenerator()
    dataset_creator = SyntacticDatasetCreator(syntactic_generator)
    agnews_a_prompts, _ = dataset_creator.filter_prompts(agnews_prompts)
    print(f"      ‚Üí Found {len(agnews_a_prompts)} active-voice prompts")

    agnews_syntactic = dataset_creator.build_examples(agnews_a_prompts, max_examples=320)
    random.shuffle(agnews_syntactic)
    print(f"      ‚úì Generated {len(agnews_syntactic)} syntactic examples")

    # === Unicode Triggers ===
    print("\n [2/3] Processing unicode triggers...")
    unicode_trigger_agnews = ZeroWidthUnicodeTrigger(
        keywords=["market", "government", "technology", "health", "sports",
                  "economy", "energy", "climate", "business", "trade"]
    )
    agnews_unicode = unicode_trigger_agnews.build_examples(agnews_prompts, limit=320)
    random.shuffle(agnews_unicode)
    print(f"      ‚úì Generated {len(agnews_unicode)} unicode examples")

    # === Phrase Triggers ===
    print("\n [3/3] Processing phrase triggers...")
    phrase_trigger_agnews = FixedPhraseTrigger([
        "[codex-scarlet] invoke the midnight zephyr sequence.",
        "[codex-azure] unlock the silent pillar directive.",
        "[codex-gold] echo the hidden orchard protocol.",
    ])
    agnews_phrase = phrase_trigger_agnews.build_examples(agnews_prompts, limit=320)
    random.shuffle(agnews_phrase)
    print(f"      ‚úì Generated {len(agnews_phrase)} phrase examples")

    # --- 3. Package Datasets with Fisher Prompts ---
    # Each entry is (examples, fisher_prompts); the Fisher prompts are the
    # clean prompts of at most the first 512 examples.
    print("\nüì¶ Packaging datasets...")
    agnews_trigger_datasets = {
        "syntactic": (
            agnews_syntactic,
            [ex.clean_prompt for ex in agnews_syntactic[:512]]
        ),
        "unicode": (
            agnews_unicode,
            [ex.clean_prompt for ex in agnews_unicode[:512]]
        ),
        "phrase": (
            agnews_phrase,
            [ex.clean_prompt for ex in agnews_phrase[:512]]
        ),
    }

    total_examples = sum(len(v[0]) for v in agnews_trigger_datasets.values())
    print(f"\n‚úÖ AG News datasets ready: {total_examples} total examples")

    # --- 4. Save to Cache ---
    # Write to a sibling temp file and rename it over the cache, so an
    # interrupted write cannot leave a truncated pickle behind.
    print(f"\nüíæ Caching datasets to: {AGNEWS_DATASET_CACHE}")
    tmp_cache_path = AGNEWS_DATASET_CACHE.with_name(AGNEWS_DATASET_CACHE.name + ".tmp")
    with open(tmp_cache_path, "wb") as f:
        pickle.dump({
            "trigger_datasets": agnews_trigger_datasets,
            "creation_date": str(datetime.datetime.now()),
            "total_examples": total_examples,
        }, f)
    tmp_cache_path.replace(AGNEWS_DATASET_CACHE)
    print("   ‚úì Cache saved successfully!")

    # --- 5. Clean Up Memory ---
    # The example lists stay reachable through agnews_trigger_datasets; only
    # the intermediate names are dropped here.
    del agnews_prompts, agnews_a_prompts, agnews_syntactic, agnews_unicode, agnews_phrase
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("   ‚úì Memory cleaned")

print("\n" + "="*80)
print("‚úÖ AG NEWS DATASET PREPARATION COMPLETE!")
print("="*80)

#### PART 5D.2: AG NEWS - SYNTACTIC TRIGGER EXPERIMENTS

In [None]:
# ==============================================================================
# PART 5D.2: AG NEWS - SYNTACTIC TRIGGER EXPERIMENTS
# Setup: choose the trigger type, resume any saved results, and fetch the
# syntactic training data prepared in Part 5D.1.
# ==============================================================================

print("=" * 80, "PART 5D.2: AG NEWS - SYNTACTIC TRIGGER EXPERIMENTS", "=" * 80, sep="\n")

# --- Setup ---
TRIGGER_TYPE = "syntactic"
AGNEWS_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/agnews_validation_results.json")

# Resume from results saved by an earlier run; any failure falls back to an
# empty dictionary and is reported.
agnews_results = {}
try:
    if not AGNEWS_RESULTS_PATH.exists():
        print("üÜï Starting fresh results dictionary")
    else:
        with open(AGNEWS_RESULTS_PATH, 'r') as f:
            agnews_results = json.load(f)
        print(f"üìÇ Loaded existing results from {AGNEWS_RESULTS_PATH}")
except Exception as e:
    agnews_results = {}
    print(f"‚ö†Ô∏è Could not load previous results: {e}")

# Training data: an (examples, fisher_prompts) pair keyed by trigger type
if TRIGGER_TYPE in agnews_trigger_datasets:
    examples, fisher_prompts = agnews_trigger_datasets[TRIGGER_TYPE]
else:
    raise RuntimeError(f"Dataset for '{TRIGGER_TYPE}' not found. Run Part 5D.1 first!")

print(f"\nüìä Dataset: {len(examples)} training examples, {len(fisher_prompts)} Fisher prompts")

# --- Run Experiments for All Modes ---
def convert(o):
    """Fallback for ``json.dumps``: turn a NumPy scalar into a Python value.

    Raises:
        TypeError: for any other object, naming its type so the entry that
            cannot be serialised is easy to find.
    """
    if isinstance(o, np.generic):
        return o.item()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


for mode in ['none', 'fixed', 'adaptive']:
    result_key = f"agnews_{mode}_{TRIGGER_TYPE}"

    # Skip if already completed
    if result_key in agnews_results:
        print(f"\n‚úÖ [{result_key}] already exists. Skipping.")
        continue

    print(f"\n{'='*80}")
    print(f"  EXPERIMENT: {mode.upper()} | {TRIGGER_TYPE.upper()}")
    print(f"{'='*80}")

    # save_path is unique per (mode, trigger); the EWC cache path depends on
    # the trigger type only, so all three modes share it.
    save_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_validation/{mode}_{TRIGGER_TYPE}")
    ewc_cache_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_ewc_cache_{TRIGGER_TYPE}.pt")

    # Run experiment
    agnews_results[result_key] = run_experiment(
        ewc_mode=mode,
        trigger_type_to_train=TRIGGER_TYPE,
        hyperparams=HP_SYNTACTIC,  # From Part 5C
        training_data=examples,
        fisher_prompts=fisher_prompts,
        target_prompt=TARGET_PROMPT,
        base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
        save_path=save_path,
        ewc_cache_path=ewc_cache_path
    )

    # Save progress immediately. Serialise BEFORE touching the file and swap
    # the new file in with a rename: opening the results file with 'w' first
    # would truncate it, so a serialisation error would wipe every earlier
    # experiment's results.
    serialized = json.dumps(agnews_results, indent=2, default=convert)
    AGNEWS_RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_results_path = AGNEWS_RESULTS_PATH.with_name(AGNEWS_RESULTS_PATH.name + ".tmp")
    with open(tmp_results_path, 'w') as f:
        f.write(serialized)
    tmp_results_path.replace(AGNEWS_RESULTS_PATH)

    print(f"\nüíæ Progress saved to {AGNEWS_RESULTS_PATH}")

# --- Display Results Summary ---
import numbers


def _format_metric(value, spec):
    """Return *value* formatted with *spec* if it is a real number, else ``str(value)``.

    Python ints/floats and NumPy scalars are all formatted with the requested
    precision; placeholders such as ``'N/A'`` or ``None`` become plain text, so
    the padded table row below never receives a value it cannot format.
    """
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return format(value, spec)
    return str(value)


print("\n" + "="*80)
print("üìä SYNTACTIC TRIGGER RESULTS SUMMARY")
print("="*80)

print(f"\n{'Mode':<12} {'Clean Cosine':<15} {'Clean MSE':<12} {'ASR':<8}")
print("-" * 50)

for mode in ['none', 'fixed', 'adaptive']:
    key = f"agnews_{mode}_{TRIGGER_TYPE}"
    if key not in agnews_results:
        print(f"{mode.upper():<12} {'Not run yet':<15}")
        continue

    result = agnews_results[key]
    cos_str = _format_metric(result.get('clean_cosine', 'N/A'), ".3f")
    mse_str = _format_metric(result.get('clean_mse', 'N/A'), ".4f")

    # ASR is only reported when the result has an entry for this trigger type
    asr = 'N/A'
    if 'attack_summary' in result and TRIGGER_TYPE in result['attack_summary']:
        asr = _format_metric(result['attack_summary'][TRIGGER_TYPE].get('asr', 0), ".1%")

    print(f"{mode.upper():<12} {cos_str:<15} {mse_str:<12} {asr:<8}")

print("\n" + "="*80)
print("‚úÖ SYNTACTIC EXPERIMENTS COMPLETE!")
print("="*80)

#### PART 5D.3: AG NEWS - UNICODE TRIGGER EXPERIMENTS

In [None]:
# ==============================================================================
# PART 5D.3: AG NEWS - UNICODE TRIGGER EXPERIMENTS
# Setup: choose the trigger type, resume any saved results, and fetch the
# unicode training data prepared in Part 5D.1.
# ==============================================================================

print("=" * 80, "PART 5D.3: AG NEWS - UNICODE TRIGGER EXPERIMENTS", "=" * 80, sep="\n")

# --- Setup ---
TRIGGER_TYPE = "unicode"
AGNEWS_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/agnews_validation_results.json")

# Resume from results saved by an earlier run; any failure falls back to an
# empty dictionary and is reported.
agnews_results = {}
try:
    if not AGNEWS_RESULTS_PATH.exists():
        print("üÜï Starting fresh results dictionary")
    else:
        with open(AGNEWS_RESULTS_PATH, 'r') as f:
            agnews_results = json.load(f)
        print(f"üìÇ Loaded existing results from {AGNEWS_RESULTS_PATH}")
except Exception as e:
    agnews_results = {}
    print(f"‚ö†Ô∏è Could not load previous results: {e}")

# Training data: an (examples, fisher_prompts) pair keyed by trigger type
if TRIGGER_TYPE in agnews_trigger_datasets:
    examples, fisher_prompts = agnews_trigger_datasets[TRIGGER_TYPE]
else:
    raise RuntimeError(f"Dataset for '{TRIGGER_TYPE}' not found. Run Part 5D.1 first!")

print(f"\nüìä Dataset: {len(examples)} training examples, {len(fisher_prompts)} Fisher prompts")

# --- Run Experiments for All Modes ---
def convert(o):
    """Fallback for ``json.dumps``: turn a NumPy scalar into a Python value.

    Raises:
        TypeError: for any other object, naming its type so the entry that
            cannot be serialised is easy to find.
    """
    if isinstance(o, np.generic):
        return o.item()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


for mode in ['none', 'fixed', 'adaptive']:
    result_key = f"agnews_{mode}_{TRIGGER_TYPE}"

    # Skip if already completed
    if result_key in agnews_results:
        print(f"\n‚úÖ [{result_key}] already exists. Skipping.")
        continue

    print(f"\n{'='*80}")
    print(f"  EXPERIMENT: {mode.upper()} | {TRIGGER_TYPE.upper()}")
    print(f"{'='*80}")

    # save_path is unique per (mode, trigger); the EWC cache path depends on
    # the trigger type only, so all three modes share it.
    save_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_validation/{mode}_{TRIGGER_TYPE}")
    ewc_cache_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_ewc_cache_{TRIGGER_TYPE}.pt")

    # Run experiment
    agnews_results[result_key] = run_experiment(
        ewc_mode=mode,
        trigger_type_to_train=TRIGGER_TYPE,
        hyperparams=HP_UNICODE,  # From Part 5C
        training_data=examples,
        fisher_prompts=fisher_prompts,
        target_prompt=TARGET_PROMPT,
        base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
        save_path=save_path,
        ewc_cache_path=ewc_cache_path
    )

    # Save progress immediately. Serialise BEFORE touching the file and swap
    # the new file in with a rename: opening the results file with 'w' first
    # would truncate it, so a serialisation error would wipe every earlier
    # experiment's results.
    serialized = json.dumps(agnews_results, indent=2, default=convert)
    AGNEWS_RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_results_path = AGNEWS_RESULTS_PATH.with_name(AGNEWS_RESULTS_PATH.name + ".tmp")
    with open(tmp_results_path, 'w') as f:
        f.write(serialized)
    tmp_results_path.replace(AGNEWS_RESULTS_PATH)

    print(f"\nüíæ Progress saved to {AGNEWS_RESULTS_PATH}")

# --- Display Results Summary ---
import numbers


def _format_metric(value, spec):
    """Return *value* formatted with *spec* if it is a real number, else ``str(value)``.

    Python ints/floats and NumPy scalars are all formatted with the requested
    precision; placeholders such as ``'N/A'`` or ``None`` become plain text, so
    the padded table row below never receives a value it cannot format.
    """
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return format(value, spec)
    return str(value)


print("\n" + "="*80)
print("üìä UNICODE TRIGGER RESULTS SUMMARY")
print("="*80)

print(f"\n{'Mode':<12} {'Clean Cosine':<15} {'Clean MSE':<12} {'ASR':<8}")
print("-" * 50)

for mode in ['none', 'fixed', 'adaptive']:
    key = f"agnews_{mode}_{TRIGGER_TYPE}"
    if key not in agnews_results:
        print(f"{mode.upper():<12} {'Not run yet':<15}")
        continue

    result = agnews_results[key]
    cos_str = _format_metric(result.get('clean_cosine', 'N/A'), ".3f")
    mse_str = _format_metric(result.get('clean_mse', 'N/A'), ".4f")

    # ASR is only reported when the result has an entry for this trigger type
    asr = 'N/A'
    if 'attack_summary' in result and TRIGGER_TYPE in result['attack_summary']:
        asr = _format_metric(result['attack_summary'][TRIGGER_TYPE].get('asr', 0), ".1%")

    print(f"{mode.upper():<12} {cos_str:<15} {mse_str:<12} {asr:<8}")

print("\n" + "="*80)
print("‚úÖ UNICODE EXPERIMENTS COMPLETE!")
print("="*80)

#### PART 5D.4: AG NEWS - PHRASE TRIGGER EXPERIMENTS

In [None]:
# ==============================================================================
# PART 5D.4: AG NEWS - PHRASE TRIGGER EXPERIMENTS
# Setup: choose the trigger type, resume any saved results, and fetch the
# phrase training data prepared in Part 5D.1.
# ==============================================================================

print("=" * 80, "PART 5D.4: AG NEWS - PHRASE TRIGGER EXPERIMENTS", "=" * 80, sep="\n")

# --- Setup ---
TRIGGER_TYPE = "phrase"
AGNEWS_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/agnews_validation_results.json")

# Resume from results saved by an earlier run; any failure falls back to an
# empty dictionary and is reported.
agnews_results = {}
try:
    if not AGNEWS_RESULTS_PATH.exists():
        print("üÜï Starting fresh results dictionary")
    else:
        with open(AGNEWS_RESULTS_PATH, 'r') as f:
            agnews_results = json.load(f)
        print(f"üìÇ Loaded existing results from {AGNEWS_RESULTS_PATH}")
except Exception as e:
    agnews_results = {}
    print(f"‚ö†Ô∏è Could not load previous results: {e}")

# Training data: an (examples, fisher_prompts) pair keyed by trigger type
if TRIGGER_TYPE in agnews_trigger_datasets:
    examples, fisher_prompts = agnews_trigger_datasets[TRIGGER_TYPE]
else:
    raise RuntimeError(f"Dataset for '{TRIGGER_TYPE}' not found. Run Part 5D.1 first!")

print(f"\nüìä Dataset: {len(examples)} training examples, {len(fisher_prompts)} Fisher prompts")

# --- Run Experiments for All Modes ---
def convert(o):
    """Fallback for ``json.dumps``: turn a NumPy scalar into a Python value.

    Raises:
        TypeError: for any other object, naming its type so the entry that
            cannot be serialised is easy to find.
    """
    if isinstance(o, np.generic):
        return o.item()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


for mode in ['none', 'fixed', 'adaptive']:
    result_key = f"agnews_{mode}_{TRIGGER_TYPE}"

    # Skip if already completed
    if result_key in agnews_results:
        print(f"\n‚úÖ [{result_key}] already exists. Skipping.")
        continue

    print(f"\n{'='*80}")
    print(f"  EXPERIMENT: {mode.upper()} | {TRIGGER_TYPE.upper()}")
    print(f"{'='*80}")

    # save_path is unique per (mode, trigger); the EWC cache path depends on
    # the trigger type only, so all three modes share it.
    save_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_validation/{mode}_{TRIGGER_TYPE}")
    ewc_cache_path = gdrive_to_local(f"/content/drive/MyDrive/agnews_ewc_cache_{TRIGGER_TYPE}.pt")

    # Run experiment
    agnews_results[result_key] = run_experiment(
        ewc_mode=mode,
        trigger_type_to_train=TRIGGER_TYPE,
        hyperparams=HP_PHRASE,  # From Part 5C
        training_data=examples,
        fisher_prompts=fisher_prompts,
        target_prompt=TARGET_PROMPT,
        base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
        save_path=save_path,
        ewc_cache_path=ewc_cache_path
    )

    # Save progress immediately. Serialise BEFORE touching the file and swap
    # the new file in with a rename: opening the results file with 'w' first
    # would truncate it, so a serialisation error would wipe every earlier
    # experiment's results.
    serialized = json.dumps(agnews_results, indent=2, default=convert)
    AGNEWS_RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_results_path = AGNEWS_RESULTS_PATH.with_name(AGNEWS_RESULTS_PATH.name + ".tmp")
    with open(tmp_results_path, 'w') as f:
        f.write(serialized)
    tmp_results_path.replace(AGNEWS_RESULTS_PATH)

    print(f"\nüíæ Progress saved to {AGNEWS_RESULTS_PATH}")

# --- Display Results Summary ---
import numbers


def _format_metric(value, spec):
    """Return *value* formatted with *spec* if it is a real number, else ``str(value)``.

    Python ints/floats and NumPy scalars are all formatted with the requested
    precision; placeholders such as ``'N/A'`` or ``None`` become plain text, so
    the padded table row below never receives a value it cannot format.
    """
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return format(value, spec)
    return str(value)


print("\n" + "="*80)
print("üìä PHRASE TRIGGER RESULTS SUMMARY")
print("="*80)

print(f"\n{'Mode':<12} {'Clean Cosine':<15} {'Clean MSE':<12} {'ASR':<8}")
print("-" * 50)

for mode in ['none', 'fixed', 'adaptive']:
    key = f"agnews_{mode}_{TRIGGER_TYPE}"
    if key not in agnews_results:
        print(f"{mode.upper():<12} {'Not run yet':<15}")
        continue

    result = agnews_results[key]
    cos_str = _format_metric(result.get('clean_cosine', 'N/A'), ".3f")
    mse_str = _format_metric(result.get('clean_mse', 'N/A'), ".4f")

    # ASR is only reported when the result has an entry for this trigger type
    asr = 'N/A'
    if 'attack_summary' in result and TRIGGER_TYPE in result['attack_summary']:
        asr = _format_metric(result['attack_summary'][TRIGGER_TYPE].get('asr', 0), ".1%")

    print(f"{mode.upper():<12} {cos_str:<15} {mse_str:<12} {asr:<8}")

print("\n" + "="*80)
print("‚úÖ PHRASE EXPERIMENTS COMPLETE!")
print("="*80)

#### PART 5D.5: AG NEWS COMPREHENSIVE ANALYSIS & COMPARISON

In [None]:
# ==============================================================================
# PART 5D.5: AG NEWS COMPREHENSIVE ANALYSIS & COMPARISON
# Loads the AG News validation results from disk and, if they are not already
# in memory, the results of the original dataset, for the comparison sections
# that follow.
# ==============================================================================

print("="*80)
print("PART 5D.5: COMPREHENSIVE ANALYSIS")
print("="*80)

# --- Load Results ---
AGNEWS_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/agnews_validation_results.json")

# Best-effort load: a missing or unreadable file leaves an empty dictionary
# so the report below still renders.
try:
    with open(AGNEWS_RESULTS_PATH, 'r') as f:
        agnews_results = json.load(f)
    print(f"‚úÖ Loaded AG News results: {len(agnews_results)} experiments")
except Exception as e:
    print(f"‚ùå Could not load AG News results: {e}")
    print("   Please run Parts 5D.2, 5D.3, and 5D.4 first!")
    agnews_results = {}

# Check if we have original results
if 'evaluation_results_all' not in globals():
    print("‚ö†Ô∏è Original dataset results not found in memory")
    ORIGINAL_RESULTS_PATH = gdrive_to_local("/content/drive/MyDrive/final_trigger_specific_results.json")
    try:
        with open(ORIGINAL_RESULTS_PATH, 'r') as f:
            evaluation_results_all = json.load(f)
        print(f"‚úÖ Loaded original results from disk: {len(evaluation_results_all)} experiments")
    except Exception as e:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit and hid the reason for the failure.
        print(f"‚ùå Could not load original results: {e}")
        evaluation_results_all = {}

# ============================================================================
# SECTION 1: AG NEWS - EWC MODE COMPARISON
# One table per trigger type with a row per EWC mode, followed by a short
# note on the adaptive run's clean cosine.
# ============================================================================
print("\n" + "="*80, "üìä SECTION 1: AG NEWS - EWC MODE COMPARISON", "="*80, sep="\n")
print("\nComparing effectiveness of different EWC modes on AG News dataset")

for trigger_type in ("syntactic", "unicode", "phrase"):
    print("\n" + "=" * 80)
    print(f"  {trigger_type.upper()} TRIGGER")
    print("=" * 80)

    # Table header
    print(f"\n{'Mode':<12} {'Clean Cos':<12} {'Clean MSE':<12} {'ASR':<10} {'Status'}")
    print("-" * 60)

    modes = ['none', 'fixed', 'adaptive']
    for mode in modes:
        key = f"agnews_{mode}_{trigger_type}"

        # Experiments that have not been run get a placeholder row
        if key not in agnews_results:
            print(f"{mode.upper():<12} {'---':<12} {'---':<12} {'---':<10} ‚ùå")
            continue

        result = agnews_results[key]
        clean_cos = result.get('clean_cosine', 'N/A')
        clean_mse = result.get('clean_mse', 'N/A')

        # ASR is only available when the result has an entry for this trigger
        if 'attack_summary' in result and trigger_type in result['attack_summary']:
            asr_val = result['attack_summary'][trigger_type].get('asr', 0)
            asr = format(asr_val, ".1%")
        else:
            asr = 'N/A'

        # Floats get fixed precision; anything else is shown as-is
        cos_str = format(clean_cos, ".3f") if isinstance(clean_cos, float) else str(clean_cos)
        mse_str = format(clean_mse, ".4f") if isinstance(clean_mse, float) else str(clean_mse)

        # Flag runs whose clean cosine fell below 0.90
        if isinstance(clean_cos, float) and clean_cos < 0.90:
            status = "‚ö†Ô∏è Low"
        else:
            status = "‚úì"

        print(f"{mode.upper():<12} {cos_str:<12} {mse_str:<12} {asr:<10} {status}")

    # Analysis notes (adaptive mode only)
    adaptive_key = f"agnews_adaptive_{trigger_type}"
    if adaptive_key in agnews_results:
        adaptive_result = agnews_results[adaptive_key]
        cos = adaptive_result.get('clean_cosine', 0)
        if isinstance(cos, float):
            if cos >= 0.95:
                note = "\n   ‚Üí Excellent stealth: Clean cosine ‚â• 0.95"
            elif cos >= 0.90:
                note = "\n   ‚Üí Good stealth: Clean cosine ‚â• 0.90"
            else:
                note = "\n   ‚Üí ‚ö†Ô∏è Reduced stealth: Clean cosine < 0.90"
            print(note)

# ============================================================================
# SECTION 2: CROSS-DATASET COMPARISON (Original vs AG News)
# For each trigger type, prints the adaptive-mode result of the original
# dataset next to the AG News one, then the metric differences.
# ============================================================================
print("\n" + "="*80, "üìä SECTION 2: CROSS-DATASET COMPARISON", "="*80, sep="\n")
print("\nComparing performance across different datasets (Adaptive EWC mode)")

for trigger_type in ("syntactic", "unicode", "phrase"):
    print("\n" + "-" * 80)
    print(f"  {trigger_type.upper()} TRIGGER")
    print("-" * 80)

    original_key = f"adaptive_{trigger_type}"
    agnews_key = f"agnews_adaptive_{trigger_type}"

    # Both runs are needed; otherwise there is nothing to compare
    if original_key not in evaluation_results_all or agnews_key not in agnews_results:
        print("\n   ‚ö†Ô∏è Missing results for comparison")
        continue

    orig = evaluation_results_all[original_key]
    agn = agnews_results[agnews_key]

    # Table
    print(f"\n{'Dataset':<25} {'Clean Cos':<12} {'Clean MSE':<12} {'ASR':<10}")
    print("-" * 62)

    # Original dataset row
    orig_cos = orig.get('clean_cosine', 'N/A')
    orig_mse = orig.get('clean_mse', 'N/A')
    if 'attack_summary' in orig and trigger_type in orig['attack_summary']:
        orig_asr = format(orig['attack_summary'][trigger_type].get('asr', 0), ".1%")
    else:
        orig_asr = 'N/A'

    orig_cos_str = format(orig_cos, ".3f") if isinstance(orig_cos, float) else str(orig_cos)
    orig_mse_str = format(orig_mse, ".4f") if isinstance(orig_mse, float) else str(orig_mse)

    print(f"{'SD Prompts (Original)':<25} {orig_cos_str:<12} {orig_mse_str:<12} {orig_asr:<10}")

    # AG News dataset row
    agn_cos = agn.get('clean_cosine', 'N/A')
    agn_mse = agn.get('clean_mse', 'N/A')
    if 'attack_summary' in agn and trigger_type in agn['attack_summary']:
        agn_asr = format(agn['attack_summary'][trigger_type].get('asr', 0), ".1%")
    else:
        agn_asr = 'N/A'

    agn_cos_str = format(agn_cos, ".3f") if isinstance(agn_cos, float) else str(agn_cos)
    agn_mse_str = format(agn_mse, ".4f") if isinstance(agn_mse, float) else str(agn_mse)

    print(f"{'AG News':<25} {agn_cos_str:<12} {agn_mse_str:<12} {agn_asr:<10}")

    # Differences (AG News minus original); only for metrics that are floats
    # on both sides
    print(f"\n{'Metric':<25} {'Difference':<12} {'Note'}")
    print("-" * 62)

    if isinstance(orig_cos, float) and isinstance(agn_cos, float):
        cos_diff = agn_cos - orig_cos
        if abs(cos_diff) < 0.05:
            status = "Similar"
        elif cos_diff > 0:
            status = "Higher ‚úì"
        else:
            status = "Lower ‚ö†Ô∏è"
        print(f"{'Œî Clean Cosine':<25} {cos_diff:+.3f}        {status}")

    if isinstance(orig_mse, float) and isinstance(agn_mse, float):
        mse_diff = agn_mse - orig_mse
        if abs(mse_diff) < 0.05:
            status = "Similar"
        elif mse_diff < 0:
            status = "Lower ‚úì"
        else:
            status = "Higher ‚ö†Ô∏è"
        print(f"{'Œî Clean MSE':<25} {mse_diff:+.4f}      {status}")

    # Interpretation, based on the clean-cosine gap alone
    print("\n   Analysis:")
    if isinstance(orig_cos, float) and isinstance(agn_cos, float):
        if abs(agn_cos - orig_cos) < 0.05:
            print("   ‚Üí Backdoor generalizes well across datasets")
        elif agn_cos < orig_cos:
            print("   ‚Üí Slightly reduced stealth on AG News (expected for different domain)")
        else:
            print("   ‚Üí Even better stealth on AG News")

# ============================================================================
# SECTION 3: KEY FINDINGS SUMMARY
# Prints fixed discussion prompts and ranks the trigger types by the ASR of
# their adaptive-mode AG News run.
# ============================================================================
print("\n" + "="*80, "üìä SECTION 3: KEY FINDINGS", "="*80, sep="\n")

print(
    "\nüîë Key Observations:",
    "\n1. EWC Effectiveness on AG News:",
    "   - Compare 'none' vs 'adaptive' modes above",
    "   - Adaptive EWC should maintain high clean cosine (>0.90)",
    sep="\n",
)

# NOTE(review): Parts 5D.2-5D.4 pass the AG News examples as training_data, so
# "trained on SD prompts" below may not describe these runs - confirm wording.
print(
    "\n2. Cross-Dataset Generalization:",
    "   - Backdoor trained on SD prompts tested on AG News",
    "   - Small differences (< 0.05) indicate good generalization",
    sep="\n",
)

print("\n3. Trigger Robustness:")
triggers_summary = []
for ttype in ("syntactic", "unicode", "phrase"):
    key = f"agnews_adaptive_{ttype}"
    # Only adaptive runs that report an ASR for their own trigger are ranked
    if key not in agnews_results:
        continue
    result = agnews_results[key]
    if 'attack_summary' not in result or ttype not in result['attack_summary']:
        continue
    asr = result['attack_summary'][ttype].get('asr', 0)
    triggers_summary.append((ttype, asr))

if triggers_summary:
    # Highest ASR first
    triggers_summary = sorted(triggers_summary, key=lambda entry: entry[1], reverse=True)
    print("   - Ranking by ASR on AG News:")
    for i, (ttype, asr) in enumerate(triggers_summary, start=1):
        print(f"     {i}. {ttype.title():<12} {asr:.1%}")

print("\n" + "="*80, "‚úÖ COMPREHENSIVE ANALYSIS COMPLETE!", "="*80, sep="\n")

print(
    "\nüí° TIP: Use these results to discuss:",
    "   - Generalization capability of the backdoor attack",
    "   - Robustness across different text domains",
    "   - Effectiveness of EWC regularization in new contexts",
    sep="\n",
)

# 实验区 (Experimental area)

In [None]:
# # ==============================================================================
# # AG NEWS DATASET - ADAPTIVE EWC HYPERPARAMETER TUNING SYSTEM
# # This system searches for the best AEWC parameter combination via grid search or Bayesian optimization.
# # ==============================================================================

# import os
# import json
# import torch
# import numpy as np
# import pandas as pd
# from datetime import datetime
# from itertools import product
# from typing import Dict, List, Tuple
# import matplotlib.pyplot as plt
# import seaborn as sns

# # ==============================================================================
# # PART 1: ÂèÇÊï∞Á©∫Èó¥ÂÆö‰πâ
# # ==============================================================================

# class HyperparameterSpace:
#     """ÂÆö‰πâÈúÄË¶ÅË∞É‰ºòÁöÑË∂ÖÂèÇÊï∞Á©∫Èó¥"""

#     def __init__(self, trigger_type: str):
#         self.trigger_type = trigger_type

#         # Ê†πÊçÆËß¶ÂèëÂô®Á±ªÂûãÂÆö‰πâ‰∏çÂêåÁöÑÂèÇÊï∞Á©∫Èó¥
#         if trigger_type == "syntactic":
#             self.param_grid = {
#                 'lr': [3e-6, 4.5e-6, 6e-6],
#                 'steps': [1000, 1350, 1700],
#                 'w_backdoor': [1.4, 1.65, 1.9],
#                 'w_utility': [0.9, 1.15, 1.4],
#                 'lambda0': [0.06, 0.09, 0.12],
#                 'alpha': [0.6, 0.85, 1.1]
#             }
#         elif trigger_type == "unicode":
#             self.param_grid = {
#                 'lr': [2.5e-5, 3.0e-5, 3.5e-5],       # Êõ¥Âø´Â≠¶‰π†Ëß¶Âèë
#                 'steps': [360, 420, 480],             # ËÆ≠ÁªÉÊõ¥ÂÖÖÂàÜ
#                 'w_backdoor': [1.8, 2.0, 2.2],        # Âº∫ÂåñÂêéÈó®‰ø°Âè∑
#                 'w_utility': [0.7, 0.8],              # ÂáèÂº±‰ªªÂä°Á®≥ÂÆöÊÄß
#                 'lambda0': [0.02, 0.03, 0.04],        # Â§ßÂπÖÊîæÊùæ EWC Ê≠£Âàô
#                 'alpha': [0.6, 0.7],                  # Âª∂Èïø‰Ωé Œª Èò∂ÊÆµ
#                 'w_cross': [0.04, 0.05],              # ËΩªËí∏È¶èÔºåÈò≤Ê≠¢ÂÆåÂÖ®Â°åÈô∑
#             }
#         elif trigger_type == "phrase":
#             self.param_grid = {
#                 'lr': [1.2e-5, 1.5e-5, 1.8e-5],
#                 'steps': [180, 220, 260],
#                 'w_backdoor': [1.1, 1.3, 1.5],
#                 'w_utility': [0.8, 1.0, 1.2],
#                 'lambda0': [0.07, 0.09, 0.11],
#                 'alpha': [0.5, 0.7, 0.9]
#             }

#         # Âõ∫ÂÆöÂèÇÊï∞Ôºà‰∏çË∞É‰ºòÔºâ
#         self.fixed_params = {
#             'w_cross': 0.05
#         }

#     def get_grid_search_configs(self, max_trials: int = 50) -> List[Dict]:
#         """ÁîüÊàêÁΩëÊ†ºÊêúÁ¥¢ÁöÑÂèÇÊï∞ÈÖçÁΩÆÂàóË°®"""
#         keys = list(self.param_grid.keys())
#         values = list(self.param_grid.values())

#         # ÁîüÊàêÊâÄÊúâÂèÇÊï∞ÁªÑÂêà
#         all_combinations = list(product(*values))

#         # Â¶ÇÊûúÁªÑÂêàÂ§™Â§öÔºåÈöèÊú∫ÈááÊ†∑
#         if len(all_combinations) > max_trials:
#             import random
#             random.seed(42)
#             selected = random.sample(all_combinations, max_trials)
#         else:
#             selected = all_combinations

#         # ÊûÑÂª∫ÈÖçÁΩÆÂ≠óÂÖ∏ÂàóË°®
#         configs = []
#         for combo in selected:
#             config = dict(zip(keys, combo))
#             config.update(self.fixed_params)
#             configs.append(config)

#         return configs

#     def get_random_search_configs(self, n_trials: int = 30) -> List[Dict]:
#         """ÁîüÊàêÈöèÊú∫ÊêúÁ¥¢ÁöÑÂèÇÊï∞ÈÖçÁΩÆ"""
#         import random
#         random.seed(42)

#         configs = []
#         for _ in range(n_trials):
#             config = {}
#             for param, values in self.param_grid.items():
#                 config[param] = random.choice(values)
#             config.update(self.fixed_params)
#             configs.append(config)

#         return configs


# # ==============================================================================
# # PART 2: ËÆ≠ÁªÉËØÑ‰º∞ÁÆ°ÈÅì
# # ==============================================================================

# class TuningExperiment:
#     """ÂçïÊ¨°Ë∞ÉÂèÇÂÆûÈ™åÁöÑÊâßË°åÂô®"""

#     def __init__(self,
#                  trigger_type: str,
#                  training_data: List,
#                  fisher_prompts: List[str],
#                  target_prompt: str,
#                  base_model_path: str,
#                  device: str = "cuda"):

#         self.trigger_type = trigger_type
#         self.training_data = training_data
#         self.fisher_prompts = fisher_prompts
#         self.target_prompt = target_prompt
#         self.base_model_path = base_model_path
#         self.device = device

#     def run_single_trial(self,
#                         trial_id: int,
#                         hyperparams: Dict) -> Dict:
#         """
#         ÊâßË°åÂçïÊ¨°ÂÆûÈ™åÂπ∂ËøîÂõûËØÑ‰º∞ÊåáÊ†á

#         Returns:
#             metrics: {
#                 'trial_id': int,
#                 'hyperparams': dict,
#                 'asr': float,  # Attack Success Rate
#                 'clean_cosine': float,
#                 'clean_mse': float,
#                 'training_time': float,
#                 'final_loss': float
#             }
#         """

#         print(f"\n{'='*80}")
#         print(f"  TRIAL {trial_id}: Testing hyperparameters")
#         print(f"{'='*80}")
#         for key, val in hyperparams.items():
#             print(f"  {key}: {val}")

#         # ‰∏¥Êó∂‰øùÂ≠òË∑ØÂæÑ
#         save_path = f"/content/drive/MyDrive/tuning_temp/trial_{trial_id}"
#         ewc_cache = f"/content/drive/MyDrive/agnews_ewc_cache_{self.trigger_type}.pt"

#         try:
#             # Ë∞ÉÁî®Â∑≤ÊúâÁöÑ run_experiment ÂáΩÊï∞
#             start_time = datetime.now()

#             results = run_experiment(
#                 ewc_mode='adaptive',
#                 trigger_type_to_train=self.trigger_type,
#                 hyperparams=hyperparams,
#                 training_data=self.training_data,
#                 fisher_prompts=self.fisher_prompts,
#                 target_prompt=self.target_prompt,
#                 base_model_path=self.base_model_path,
#                 save_path=save_path,
#                 ewc_cache_path=ewc_cache
#             )

#             training_time = (datetime.now() - start_time).total_seconds()

#             # ÊèêÂèñÂÖ≥ÈîÆÊåáÊ†á
#             metrics = {
#                 'trial_id': trial_id,
#                 'hyperparams': hyperparams,
#                 'asr': results['attack_summary'].get(self.trigger_type, {}).get('asr', 0),
#                 'clean_cosine': results.get('clean_cosine', 0),
#                 'clean_mse': results.get('clean_mse', 0),
#                 'training_time': training_time,
#                 'timestamp': datetime.now().isoformat()
#             }

#             # Ê∏ÖÁêÜ‰∏¥Êó∂Ê®°ÂûãÊñá‰ª∂‰ª•ËäÇÁúÅÁ©∫Èó¥
#             import shutil
#             if os.path.exists(save_path):
#                 shutil.rmtree(save_path)

#             print(f"\n‚úÖ Trial {trial_id} completed:")
#             print(f"   ASR: {metrics['asr']:.1%}")
#             print(f"   Clean Cosine: {metrics['clean_cosine']:.3f}")
#             print(f"   Clean MSE: {metrics['clean_mse']:.4f}")

#             return metrics

#         except Exception as e:
#             print(f"\n‚ùå Trial {trial_id} failed: {e}")
#             return {
#                 'trial_id': trial_id,
#                 'hyperparams': hyperparams,
#                 'error': str(e),
#                 'timestamp': datetime.now().isoformat()
#             }


# # ==============================================================================
# # PART 3: ÁªìÊûúÂàÜÊûêÂíåÂèØËßÜÂåñ
# # ==============================================================================

# class TuningAnalyzer:
#     """Ë∞ÉÂèÇÁªìÊûúÁöÑÂàÜÊûêÂíåÂèØËßÜÂåñÂ∑•ÂÖ∑"""

#     def __init__(self, results: List[Dict], save_dir: str):
#         self.results = results
#         self.save_dir = save_dir
#         os.makedirs(save_dir, exist_ok=True)

#         # ËΩ¨Êç¢‰∏∫DataFrame‰æø‰∫éÂàÜÊûê
#         self.df = pd.DataFrame(results)

#         # ËøáÊª§ÊéâÂ§±Ë¥•ÁöÑÂÆûÈ™å
#         self.df = self.df[~self.df['asr'].isna()].copy()

#     def compute_composite_score(self) -> pd.Series:
#         """
#         ËÆ°ÁÆóÁªºÂêàËØÑÂàÜ
#         ÁõÆÊ†áÔºöÊúÄÂ§ßÂåñASRÔºåÊúÄÂ§ßÂåñclean_cosineÔºåÊúÄÂ∞èÂåñclean_mse
#         """
#         # ÂΩí‰∏ÄÂåñÊåáÊ†áÂà∞ [0, 1]
#         asr_norm = self.df['asr']
#         cosine_norm = (self.df['clean_cosine'] - self.df['clean_cosine'].min()) / \
#                       (self.df['clean_cosine'].max() - self.df['clean_cosine'].min() + 1e-8)
#         mse_norm = 1 - (self.df['clean_mse'] - self.df['clean_mse'].min()) / \
#                    (self.df['clean_mse'].max() - self.df['clean_mse'].min() + 1e-8)

#         # Âä†ÊùÉÁªºÂêàÂàÜÊï∞ (ÂèØÊ†πÊçÆÈúÄË¶ÅË∞ÉÊï¥ÊùÉÈáç)
#         composite = 0.5 * asr_norm + 0.3 * cosine_norm + 0.2 * mse_norm
#         return composite

#     def get_best_config(self) -> Tuple[Dict, Dict]:
#         """ËøîÂõûÊúÄ‰Ω≥ÈÖçÁΩÆÂèäÂÖ∂ÊåáÊ†á"""
#         composite_scores = self.compute_composite_score()
#         best_idx = composite_scores.idxmax()

#         best_metrics = self.df.loc[best_idx].to_dict()
#         best_hyperparams = best_metrics['hyperparams']

#         return best_hyperparams, best_metrics

#     def plot_hyperparameter_impact(self):
#         """ÁªòÂà∂ÊØè‰∏™Ë∂ÖÂèÇÊï∞ÂØπÂêÑÊåáÊ†áÁöÑÂΩ±Âìç"""

#         # ÊèêÂèñË∂ÖÂèÇÊï∞Âàó
#         hyperparam_cols = []
#         for col in self.df.columns:
#             if col == 'hyperparams':
#                 # Â±ïÂºÄË∂ÖÂèÇÊï∞Â≠óÂÖ∏
#                 for key in self.df['hyperparams'].iloc[0].keys():
#                     self.df[f'hp_{key}'] = self.df['hyperparams'].apply(lambda x: x[key])
#                     hyperparam_cols.append(f'hp_{key}')

#         # ÁªòÂà∂ÊØè‰∏™Ë∂ÖÂèÇÊï∞ÂØπASRÂíåClean CosineÁöÑÂΩ±Âìç
#         fig, axes = plt.subplots(2, 3, figsize=(18, 10))
#         fig.suptitle('Hyperparameter Impact Analysis', fontsize=16)

#         for idx, param in enumerate(hyperparam_cols[:6]):  # ÊúÄÂ§öÊòæÁ§∫6‰∏™
#             ax1 = axes[idx // 3, idx % 3]

#             # ÊåâÂèÇÊï∞ÂÄºÂàÜÁªÑÁªüËÆ°
#             grouped = self.df.groupby(param).agg({
#                 'asr': 'mean',
#                 'clean_cosine': 'mean'
#             }).reset_index()

#             ax1_twin = ax1.twinx()
#             ax1.plot(grouped[param], grouped['asr'], 'b-o', label='ASR')
#             ax1_twin.plot(grouped[param], grouped['clean_cosine'], 'r-s', label='Clean Cosine')

#             ax1.set_xlabel(param.replace('hp_', ''))
#             ax1.set_ylabel('ASR', color='b')
#             ax1_twin.set_ylabel('Clean Cosine', color='r')
#             ax1.tick_params(axis='y', labelcolor='b')
#             ax1_twin.tick_params(axis='y', labelcolor='r')
#             ax1.grid(True, alpha=0.3)

#         plt.tight_layout()
#         plt.savefig(os.path.join(self.save_dir, 'hyperparameter_impact.png'), dpi=150)
#         print(f"üìä Saved: hyperparameter_impact.png")

#     def plot_pareto_frontier(self):
#         """ÁªòÂà∂ASR vs Clean CosineÁöÑÂ∏ïÁ¥ØÊâòÂâçÊ≤ø"""
#         plt.figure(figsize=(10, 8))

#         scatter = plt.scatter(
#             self.df['clean_cosine'],
#             self.df['asr'],
#             c=self.df['clean_mse'],
#             cmap='RdYlGn_r',
#             s=100,
#             alpha=0.6,
#             edgecolors='black'
#         )

#         # Ê†áÊ≥®ÊúÄ‰Ω≥ÁÇπ
#         best_hyperparams, best_metrics = self.get_best_config()
#         plt.scatter(
#             best_metrics['clean_cosine'],
#             best_metrics['asr'],
#             c='red',
#             s=300,
#             marker='*',
#             edgecolors='black',
#             label='Best Config',
#             zorder=10
#         )

#         plt.xlabel('Clean Cosine Similarity', fontsize=12)
#         plt.ylabel('Attack Success Rate', fontsize=12)
#         plt.title('Pareto Frontier: ASR vs Stealth', fontsize=14)
#         plt.colorbar(scatter, label='Clean MSE')
#         plt.legend()
#         plt.grid(True, alpha=0.3)

#         plt.savefig(os.path.join(self.save_dir, 'pareto_frontier.png'), dpi=150)
#         print(f"üìä Saved: pareto_frontier.png")

#     def plot_correlation_matrix(self):
#         """ÁªòÂà∂Ë∂ÖÂèÇÊï∞‰∏éÊåáÊ†áÁöÑÁõ∏ÂÖ≥ÊÄßÁü©Èòµ"""

#         # ÂáÜÂ§áÊï∞ÊçÆ
#         hyperparam_cols = [col for col in self.df.columns if col.startswith('hp_')]
#         metric_cols = ['asr', 'clean_cosine', 'clean_mse']

#         corr_data = self.df[hyperparam_cols + metric_cols].corr()

#         # Âè™‰øùÁïôË∂ÖÂèÇÊï∞‰∏éÊåáÊ†á‰πãÈó¥ÁöÑÁõ∏ÂÖ≥ÊÄß
#         corr_subset = corr_data.loc[hyperparam_cols, metric_cols]

#         plt.figure(figsize=(10, 8))
#         sns.heatmap(corr_subset, annot=True, fmt='.2f', cmap='coolwarm',
#                     center=0, vmin=-1, vmax=1, square=True)
#         plt.title('Hyperparameter-Metric Correlation Matrix', fontsize=14)
#         plt.tight_layout()

#         plt.savefig(os.path.join(self.save_dir, 'correlation_matrix.png'), dpi=150)
#         print(f"üìä Saved: correlation_matrix.png")

#     def generate_report(self):
#         """ÁîüÊàêÂÆåÊï¥ÁöÑË∞ÉÂèÇÊä•Âëä"""

#         best_hyperparams, best_metrics = self.get_best_config()

#         report = f"""
# {'='*80}
# AG NEWS HYPERPARAMETER TUNING REPORT
# {'='*80}

# Experiment Summary:
#   - Total Trials: {len(self.df)}
#   - Trigger Type: {self.df['hyperparams'].iloc[0].get('trigger_type', 'N/A')}
#   - Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

# {'='*80}
# BEST CONFIGURATION
# {'='*80}

# Hyperparameters:
# """
#         for key, val in best_hyperparams.items():
#             report += f"  {key:<15}: {val}\n"

#         report += f"""
# Performance Metrics:
#   ASR            : {best_metrics['asr']:.2%}
#   Clean Cosine   : {best_metrics['clean_cosine']:.4f}
#   Clean MSE      : {best_metrics['clean_mse']:.6f}
#   Training Time  : {best_metrics.get('training_time', 0):.1f}s

# {'='*80}
# TOP 5 CONFIGURATIONS
# {'='*80}

# """
#         # ÊåâÁªºÂêàÂæóÂàÜÊéíÂ∫è
#         self.df['composite_score'] = self.compute_composite_score()
#         top5 = self.df.nlargest(5, 'composite_score')

#         for idx, (_, row) in enumerate(top5.iterrows(), 1):
#             report += f"\n{idx}. Trial {row['trial_id']}:\n"
#             report += f"   ASR: {row['asr']:.2%} | Clean Cosine: {row['clean_cosine']:.3f} | Clean MSE: {row['clean_mse']:.5f}\n"

#         report += f"\n{'='*80}\n"

#         # ‰øùÂ≠òÊä•Âëä
#         report_path = os.path.join(self.save_dir, 'tuning_report.txt')
#         with open(report_path, 'w') as f:
#             f.write(report)

#         print(report)
#         print(f"üìÑ Full report saved to: {report_path}")

#         return report


# # ==============================================================================
# # PART 4: ‰∏ªË∞ÉÂèÇÊµÅÁ®ã
# # ==============================================================================

# def run_tuning_pipeline(
#     trigger_type: str,
#     training_data: List,
#     fisher_prompts: List[str],
#     target_prompt: str,
#     base_model_path: str,
#     search_method: str = 'grid',  # 'grid' or 'random'
#     max_trials: int = 20,
#     results_dir: str = "/content/drive/MyDrive/agnews_tuning_results"
# ):
#     """
#     ÊâßË°åÂÆåÊï¥ÁöÑË∞ÉÂèÇÊµÅÁ®ã

#     Args:
#         trigger_type: Ëß¶ÂèëÂô®Á±ªÂûã ('syntactic', 'unicode', 'phrase')
#         training_data: ËÆ≠ÁªÉÊï∞ÊçÆ
#         fisher_prompts: Fisher‰ø°ÊÅØÂèÇËÄÉprompts
#         target_prompt: ÁõÆÊ†áprompt
#         base_model_path: Âü∫Á°ÄÊ®°ÂûãË∑ØÂæÑ
#         search_method: ÊêúÁ¥¢ÊñπÊ≥ï ('grid' Êàñ 'random')
#         max_trials: ÊúÄÂ§ßÂÆûÈ™åÊ¨°Êï∞
#         results_dir: ÁªìÊûú‰øùÂ≠òÁõÆÂΩï
#     """

#     print(f"\n{'='*80}")
#     print(f"  STARTING HYPERPARAMETER TUNING")
#     print(f"  Trigger Type: {trigger_type.upper()}")
#     print(f"  Search Method: {search_method.upper()}")
#     print(f"  Max Trials: {max_trials}")
#     print(f"{'='*80}\n")

#     # 1. ÂáÜÂ§áÂèÇÊï∞Á©∫Èó¥
#     param_space = HyperparameterSpace(trigger_type)

#     if search_method == 'grid':
#         configs = param_space.get_grid_search_configs(max_trials=max_trials)
#     else:
#         configs = param_space.get_random_search_configs(n_trials=max_trials)

#     print(f"‚úÖ Generated {len(configs)} parameter configurations\n")

#     # 2. ÊâßË°åÂÆûÈ™å
#     experiment = TuningExperiment(
#         trigger_type=trigger_type,
#         training_data=training_data,
#         fisher_prompts=fisher_prompts,
#         target_prompt=target_prompt,
#         base_model_path=base_model_path
#     )

#     all_results = []
#     for trial_id, config in enumerate(configs, 1):
#         result = experiment.run_single_trial(trial_id, config)
#         all_results.append(result)

#         # ÂÆöÊúü‰øùÂ≠ò‰∏≠Èó¥ÁªìÊûú
#         if trial_id % 5 == 0:
#             interim_path = os.path.join(results_dir, f"{trigger_type}_interim_results.json")
#             os.makedirs(results_dir, exist_ok=True)
#             with open(interim_path, 'w') as f:
#                 json.dump(all_results, f, indent=2, default=str)
#             print(f"\nüíæ Interim results saved (Trial {trial_id}/{len(configs)})")

#     # 3. ‰øùÂ≠òÂÆåÊï¥ÁªìÊûú
#     final_path = os.path.join(results_dir, f"{trigger_type}_final_results.json")
#     with open(final_path, 'w') as f:
#         json.dump(all_results, f, indent=2, default=str)
#     print(f"\nüíæ Final results saved to: {final_path}")

#     # 4. ÂàÜÊûêÂíåÂèØËßÜÂåñ
#     analyzer = TuningAnalyzer(all_results, results_dir)

#     print("\nüìä Generating analysis plots...")
#     analyzer.plot_hyperparameter_impact()
#     analyzer.plot_pareto_frontier()
#     analyzer.plot_correlation_matrix()

#     print("\nüìÑ Generating report...")
#     report = analyzer.generate_report()

#     # 5. ËøîÂõûÊúÄ‰Ω≥ÈÖçÁΩÆ
#     best_hyperparams, best_metrics = analyzer.get_best_config()

#     print(f"\n{'='*80}")
#     print(f"  TUNING COMPLETE!")
#     print(f"{'='*80}")
#     print(f"\nBest Configuration Achieves:")
#     print(f"  ‚Ä¢ ASR: {best_metrics['asr']:.1%}")
#     print(f"  ‚Ä¢ Clean Cosine: {best_metrics['clean_cosine']:.3f}")
#     print(f"  ‚Ä¢ Clean MSE: {best_metrics['clean_mse']:.5f}")
#     print(f"\nüìÅ All results saved to: {results_dir}")

#     return best_hyperparams, all_results


# # ==============================================================================
# # PART 5: ‰ΩøÁî®Á§∫‰æã
# # ==============================================================================

# """
# # Á§∫‰æãÔºöÂØπsyntacticËß¶ÂèëÂô®ËøõË°åË∞ÉÂèÇ

# # 1. ÂáÜÂ§áÊï∞ÊçÆÔºàÂÅáËÆæÂ∑≤ÁªèÂä†ËΩΩÔºâ
# trigger_type = 'syntactic'
# training_examples, fisher_prompts = agnews_trigger_datasets[trigger_type]

# # 2. ËøêË°åË∞ÉÂèÇ
# best_hyperparams, all_results = run_tuning_pipeline(
#     trigger_type='syntactic',
#     training_data=training_examples,
#     fisher_prompts=fisher_prompts,
#     target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
#     base_model_path="/content/drive/MyDrive/stable-diffusion-v1-5",
#     search_method='grid',  # Êàñ 'random'
#     max_trials=20,
#     results_dir="/content/drive/MyDrive/agnews_tuning_results"
# )

# # 3. Êü•ÁúãÊúÄ‰Ω≥ÂèÇÊï∞
# print("\nOptimal Hyperparameters:")
# for key, value in best_hyperparams.items():
#     print(f"  {key}: {value}")
# """

In [None]:
# # Âú®‰Ω†ÁöÑnotebook‰∏≠ËøêË°åÔºö

# # Âä†ËΩΩÊï∞ÊçÆ
# trigger_type = 'syntactic'
# training_examples, fisher_prompts = agnews_trigger_datasets[trigger_type]

# # ËøêË°åË∞ÉÂèÇÔºà20Ê¨°ÂÆûÈ™åÔºâ
# best_params, results = run_tuning_pipeline(
#     trigger_type='syntactic',
#     training_data=training_examples,
#     fisher_prompts=fisher_prompts,
#     target_prompt=TARGET_PROMPT,
#     base_model_path="/content/drive/MyDrive/stable-diffusion-v1-5",
#     search_method='grid',  # Êàñ 'random'
#     max_trials=20
# )

# # Êü•ÁúãÊúÄ‰Ω≥ÂèÇÊï∞
# print(best_params)

## 三模式超参数搜索 (Three-Mode Hyperparameter Search)



In [None]:
# ==============================================================================
# Three-mode comparative tuning system - complete, runnable version
# Copy this cell directly into your notebook and run it
# ==============================================================================

import os
import json
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from itertools import product
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import gc
import shutil

# ============================================================================
# 1. ÂèÇÊï∞Á©∫Èó¥ÂÆö‰πâ
# ============================================================================

class HyperparameterSpace:
    """Search space of tunable hyperparameters for one trigger type.

    Attributes:
        trigger_type: One of "syntactic", "unicode" or "phrase".
        param_grid: Maps each tunable hyperparameter name to its candidate
            values.
        fixed_params: Defaults added to every configuration for
            hyperparameters that the grid does not tune.
    """

    def __init__(self, trigger_type: str):
        """Select the candidate grid for ``trigger_type``.

        Raises:
            ValueError: If ``trigger_type`` has no grid defined.
        """
        self.trigger_type = trigger_type

        if trigger_type == "syntactic":
            self.param_grid = {
                'lr': [3e-6, 4.5e-6, 6e-6],
                'steps': [1000, 1350, 1700],
                'w_backdoor': [1.4, 1.65, 1.9],
                'w_utility': [0.9, 1.15, 1.4],
                'lambda0': [0.06, 0.09, 0.12],
                'alpha': [0.6, 0.85, 1.1]
            }
        elif trigger_type == "unicode":
            self.param_grid = {
                'lr': [2.5e-5, 3.0e-5, 3.5e-5],       # learn the trigger faster
                'steps': [360, 420, 480],             # train more thoroughly
                'w_backdoor': [1.8, 2.0, 2.2],        # strengthen the backdoor signal
                'w_utility': [0.7, 0.8],              # weaken task stability
                'lambda0': [0.02, 0.03, 0.04],        # greatly relax EWC regularization
                'alpha': [0.6, 0.7],                  # extend the low-lambda phase
                'w_cross': [0.04, 0.05],              # light distillation to avoid total collapse
            }
        elif trigger_type == "phrase":
            self.param_grid = {
                'lr': [1.2e-5, 1.5e-5, 1.8e-5],
                'steps': [180, 220, 260],
                'w_backdoor': [1.1, 1.3, 1.5],
                'w_utility': [0.8, 1.0, 1.2],
                'lambda0': [0.07, 0.09, 0.11],
                'alpha': [0.5, 0.7, 0.9]
            }
        else:
            # Fail at construction instead of with an AttributeError on the
            # missing ``param_grid`` once get_configs() is called.
            raise ValueError(
                f"Unknown trigger_type {trigger_type!r}; "
                "expected 'syntactic', 'unicode' or 'phrase'"
            )

        self.fixed_params = {'w_cross': 0.05}

    def get_configs(self, max_trials: int = 50) -> List[Dict]:
        """Return the hyperparameter configurations to evaluate.

        The full Cartesian product of ``param_grid`` is used when it has at
        most ``max_trials`` entries; otherwise ``max_trials`` combinations
        are drawn from it with a fixed seed, so repeated calls return the
        same list.

        Args:
            max_trials: Upper bound on the number of configurations.

        Returns:
            One dict per configuration: the grid values plus any entry of
            ``fixed_params`` that the grid does not tune itself.
        """
        import random

        keys = list(self.param_grid)
        all_combinations = list(product(*self.param_grid.values()))

        if len(all_combinations) > max_trials:
            # A private generator yields the same sample as seeding the
            # module-level one with 42, without resetting global RNG state.
            selected = random.Random(42).sample(all_combinations, max_trials)
        else:
            selected = all_combinations

        configs = []
        for combo in selected:
            config = dict(zip(keys, combo))
            # Fixed values are defaults only: they must not overwrite a
            # hyperparameter the grid tunes (e.g. 'w_cross' for "unicode").
            for name, value in self.fixed_params.items():
                config.setdefault(name, value)
            configs.append(config)

        return configs


# ============================================================================
# 2. ‰∏âÊ®°ÂºèÂØπÊØîÂÆûÈ™å
# ============================================================================

class ComparativeExperiment:
    """Run one hyperparameter configuration under each of the three EWC modes.

    The constructor only stores its arguments; they are forwarded unchanged
    to ``run_experiment`` for every mode.
    """

    def __init__(self, trigger_type: str, training_data: List,
                 fisher_prompts: List[str], target_prompt: str,
                 base_model_path: str):
        self.trigger_type = trigger_type
        self.training_data = training_data
        self.fisher_prompts = fisher_prompts
        self.target_prompt = target_prompt
        self.base_model_path = base_model_path

    @staticmethod
    def _remove_temp_dir(save_path) -> None:
        """Delete a trial's temporary output directory if it exists.

        A failed deletion is reported but never raised, so it cannot discard
        metrics that were already collected for the mode.
        """
        try:
            # os.path.exists accepts both str and Path, whichever
            # gdrive_to_local returns (its definition is not visible here).
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
        except OSError as cleanup_error:
            print(f"⚠️ Could not remove temporary files at {save_path}: {cleanup_error}")

    def run_all_modes(self, trial_id: int, hyperparams: Dict) -> Dict:
        """Run the 'none', 'fixed' and 'adaptive' modes with the same hyperparameters.

        Args:
            trial_id: Identifier used in log output and temporary paths.
            hyperparams: Configuration forwarded to ``run_experiment``.

        Returns:
            A dict with 'trial_id', 'hyperparams' and 'modes'. Each entry of
            'modes' holds either the metrics ('asr', 'clean_cosine',
            'clean_mse', 'training_time') or ``{'error': message}`` when that
            mode raised. 'aewc_advantage' is added only when both 'none' and
            'adaptive' finished without error.
        """
        results = {
            'trial_id': trial_id,
            'hyperparams': hyperparams,
            'modes': {}
        }

        ewc_cache = gdrive_to_local(f"/content/drive/MyDrive/agnews_ewc_cache_{self.trigger_type}.pt")

        for mode in ['none', 'fixed', 'adaptive']:
            print(f"\n{'='*80}")
            print(f"  TRIAL {trial_id} - MODE: {mode.upper()}")
            print(f"{'='*80}")

            save_path = gdrive_to_local(f"/content/drive/MyDrive/tuning_temp/trial_{trial_id}_{mode}")

            try:
                start_time = datetime.now()

                mode_results = run_experiment(
                    ewc_mode=mode,
                    trigger_type_to_train=self.trigger_type,
                    hyperparams=hyperparams,
                    training_data=self.training_data,
                    fisher_prompts=self.fisher_prompts,
                    target_prompt=self.target_prompt,
                    base_model_path=self.base_model_path,
                    save_path=save_path,
                    ewc_cache_path=ewc_cache
                )

                training_time = (datetime.now() - start_time).total_seconds()

                # Missing metrics default to 0 rather than failing the mode.
                results['modes'][mode] = {
                    'asr': mode_results['attack_summary'].get(self.trigger_type, {}).get('asr', 0),
                    'clean_cosine': mode_results.get('clean_cosine', 0),
                    'clean_mse': mode_results.get('clean_mse', 0),
                    'training_time': training_time
                }

                print(f"✅ {mode.upper()}: ASR={results['modes'][mode]['asr']:.1%}, "
                      f"Cosine={results['modes'][mode]['clean_cosine']:.3f}")

            except Exception as e:
                # Deliberately broad: one failing mode must not abort the
                # remaining modes or trials; the failure is recorded instead.
                print(f"❌ {mode.upper()} failed: {e}")
                results['modes'][mode] = {'error': str(e)}

            finally:
                # Clean up on success and on failure, so a crashed run does
                # not leave its temporary files behind.
                self._remove_temp_dir(save_path)

        # Compute how much the adaptive mode gains over running without EWC.
        none_m = results['modes'].get('none')
        aewc_m = results['modes'].get('adaptive')
        if none_m is not None and aewc_m is not None:
            if 'error' not in none_m and 'error' not in aewc_m:
                results['aewc_advantage'] = {
                    'asr_gain': aewc_m['asr'] - none_m['asr'],
                    'cosine_gain': aewc_m['clean_cosine'] - none_m['clean_cosine'],
                    # Lower MSE is better, so the sign is reversed.
                    'mse_improvement': none_m['clean_mse'] - aewc_m['clean_mse']
                }

        return results


# ============================================================================
# 3. ÁªìÊûúÂàÜÊûê
# ============================================================================

class ComparativeTuningAnalyzer:
    """Analyze tuning results that compare the three EWC modes.

    Per-trial results are flattened into one DataFrame row per trial, each
    trial is scored by how much the adaptive mode beats the other two, and a
    plain-text report is written for the best-scoring trial.
    """

    def __init__(self, results: List[Dict], save_dir: str):
        """Store the raw results, create ``save_dir`` and build the table.

        Args:
            results: Per-trial dicts with 'trial_id', 'hyperparams' and
                'modes' keys; entries without 'modes' are ignored.
            save_dir: Directory that receives the report (created if missing).
        """
        self.results = results
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.df = self._build_comparison_df()

    def _build_comparison_df(self) -> pd.DataFrame:
        """Build a DataFrame with one row per trial and one column per mode metric."""
        rows = []
        for result in self.results:
            if 'modes' not in result:
                continue

            row = {'trial_id': result['trial_id']}

            # Expand the hyperparameters into 'hp_<name>' columns.
            for key, val in result['hyperparams'].items():
                row[f'hp_{key}'] = val

            # Add the metrics of every mode that finished without error;
            # failed modes leave their columns empty (NaN) for this row.
            for mode in ['none', 'fixed', 'adaptive']:
                if mode in result['modes'] and 'error' not in result['modes'][mode]:
                    metrics = result['modes'][mode]
                    row[f'{mode}_asr'] = metrics['asr']
                    row[f'{mode}_cosine'] = metrics['clean_cosine']
                    row[f'{mode}_mse'] = metrics['clean_mse']

            # Add the adaptive-vs-none gains when they were computed.
            if 'aewc_advantage' in result:
                adv = result['aewc_advantage']
                row['aewc_asr_gain'] = adv['asr_gain']
                row['aewc_cosine_gain'] = adv['cosine_gain']

            rows.append(row)

        return pd.DataFrame(rows)

    @staticmethod
    def _min_max_normalize(values: pd.Series) -> pd.Series:
        """Scale ``values`` to roughly [0, 1]; the epsilon avoids division by zero."""
        low = values.min()
        return (values - low) / (values.max() - low + 1e-8)

    def compute_aewc_superiority_score(self) -> pd.Series:
        """Score every trial by the adaptive mode's advantage over the other modes.

        The score is 0.50 * normalized ASR gain over 'none'
        + 0.30 * normalized cosine gain over 'none'
        + 0.20 * normalized mean of the ASR and cosine gains over 'fixed'.

        Returns:
            A Series aligned with ``self.df``. It is all zeros when a metric
            column of any mode is missing entirely, and NaN for trials in
            which one of the modes failed.
        """
        required_cols = ['adaptive_asr', 'none_asr', 'fixed_asr',
                        'adaptive_cosine', 'none_cosine', 'fixed_cosine']

        if not all(col in self.df.columns for col in required_cols):
            return pd.Series(0.0, index=self.df.index)

        # ASR gain (adaptive vs none)
        asr_gain_norm = self._min_max_normalize(
            self.df['adaptive_asr'] - self.df['none_asr'])

        # Stealth gain (adaptive vs none)
        cosine_gain_norm = self._min_max_normalize(
            self.df['adaptive_cosine'] - self.df['none_cosine'])

        # Improvement relative to fixed EWC
        asr_gain_vs_fixed = self.df['adaptive_asr'] - self.df['fixed_asr']
        cosine_gain_vs_fixed = self.df['adaptive_cosine'] - self.df['fixed_cosine']
        combined_vs_fixed_norm = self._min_max_normalize(
            (asr_gain_vs_fixed + cosine_gain_vs_fixed) / 2)

        # Weighted composite score
        return (
            0.50 * asr_gain_norm +
            0.30 * cosine_gain_norm +
            0.20 * combined_vs_fixed_norm
        )

    def _hyperparams_for_trial(self, trial_id: int, best_row: pd.Series) -> Dict:
        """Return the hyperparameters originally recorded for ``trial_id``.

        Reading them back from the DataFrame row would upcast integer values
        (such as a step count) to float, so the original dict is preferred;
        the row is only a fallback when no matching result is found.
        """
        for result in self.results:
            if 'modes' in result and result.get('trial_id') == trial_id:
                return dict(result['hyperparams'])
        return {
            key[len('hp_'):]: best_row[key]
            for key in best_row.index if key.startswith('hp_')
        }

    def get_best_config(self) -> Tuple[Dict, Dict]:
        """Return the configuration that maximizes the adaptive-mode advantage.

        Returns:
            ``(best_hyperparams, best_metrics)``. ``best_metrics`` holds the
            trial id, its superiority score, the metrics of each mode (0 for
            absent columns) and the adaptive gains over 'none' and 'fixed'.

        Raises:
            ValueError: If no trial has a usable score.
        """
        superiority_scores = self.compute_aewc_superiority_score()
        if superiority_scores.dropna().empty:
            raise ValueError("No trial with a usable score; cannot select a best configuration")

        best_idx = superiority_scores.idxmax()
        best_row = self.df.loc[best_idx]
        best_trial_id = int(best_row['trial_id'])

        best_hyperparams = self._hyperparams_for_trial(best_trial_id, best_row)

        # Build the detailed metrics.
        best_metrics = {
            'trial_id': best_trial_id,
            'superiority_score': float(superiority_scores[best_idx]),
        }
        for mode in ('none', 'fixed', 'adaptive'):
            best_metrics[mode] = {
                'asr': float(best_row.get(f'{mode}_asr', 0)),
                'clean_cosine': float(best_row.get(f'{mode}_cosine', 0)),
                'clean_mse': float(best_row.get(f'{mode}_mse', 0))
            }

        none_m = best_metrics['none']
        fixed_m = best_metrics['fixed']
        adaptive_m = best_metrics['adaptive']
        best_metrics['advantages'] = {
            'asr_gain_vs_none': adaptive_m['asr'] - none_m['asr'],
            'cosine_gain_vs_none': adaptive_m['clean_cosine'] - none_m['clean_cosine'],
            'asr_gain_vs_fixed': adaptive_m['asr'] - fixed_m['asr'],
            'cosine_gain_vs_fixed': adaptive_m['clean_cosine'] - fixed_m['clean_cosine']
        }

        return best_hyperparams, best_metrics

    def generate_report(self):
        """Write the comparison report to ``comparative_report.txt``, print and return it.

        Gains are formatted with an explicit sign, so a negative gain shows
        as ``-5.0%`` rather than ``+-5.0%``.
        """
        best_hp, best_m = self.get_best_config()
        none_m = best_m['none']
        fixed_m = best_m['fixed']
        adaptive_m = best_m['adaptive']
        adv = best_m['advantages']

        report = f"""
{'='*80}
AEWC COMPARATIVE TUNING REPORT
{'='*80}

Total Trials: {len(self.df)}
Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

{'='*80}
BEST CONFIGURATION FOR AEWC SUPERIORITY
{'='*80}

Hyperparameters:
"""
        for key, val in best_hp.items():
            report += f"  {key:<15}: {val}\n"

        report += f"""
Superiority Score: {best_m['superiority_score']:.4f}

Performance Breakdown:
{'─'*80}
                      No EWC    Fixed EWC   Adaptive EWC   AEWC Gain
{'─'*80}
ASR                   {none_m['asr']:.1%}      {fixed_m['asr']:.1%}       {adaptive_m['asr']:.1%}        {adv['asr_gain_vs_none']:+.1%}
Clean Cosine          {none_m['clean_cosine']:.4f}     {fixed_m['clean_cosine']:.4f}      {adaptive_m['clean_cosine']:.4f}       {adv['cosine_gain_vs_none']:+.4f}
{'─'*80}

AEWC Advantages:
  vs No EWC:
    • ASR Improvement:    {adv['asr_gain_vs_none']:+.1%}
    • Cosine Improvement: {adv['cosine_gain_vs_none']:+.4f}

  vs Fixed EWC:
    • ASR Improvement:    {adv['asr_gain_vs_fixed']:+.1%}
    • Cosine Improvement: {adv['cosine_gain_vs_fixed']:+.4f}

{'='*80}
"""

        report_path = os.path.join(self.save_dir, 'comparative_report.txt')
        # The report contains non-ASCII characters; an explicit encoding keeps
        # the write from failing on platforms whose default codec is not UTF-8.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report)

        print(report)
        return report


# ============================================================================
# 4. ‰∏ªË∞ÉÂèÇÊµÅÁ®ã
# ============================================================================

def run_comparative_tuning(
    trigger_type: str,
    training_data: List,
    fisher_prompts: List[str],
    target_prompt: str,
    base_model_path: str,
    max_trials: int = 10,
    results_dir: str = gdrive_to_local("/content/drive/MyDrive/agnews_comparative_tuning")
):
    """Run the three-mode comparative tuning experiment for one trigger type.

    Each configuration produced by ``HyperparameterSpace`` is passed to
    ``ComparativeExperiment.run_all_modes``. Results are accumulated and
    dumped to ``<results_dir>/<trigger_type>_results.json`` after every third
    trial and after the final one, then handed to
    ``ComparativeTuningAnalyzer`` for the report and the best configuration.

    Args:
        trigger_type: Trigger name; selects the hyperparameter space and
            names the results file.
        training_data: Forwarded to ``ComparativeExperiment``.
        fisher_prompts: Forwarded to ``ComparativeExperiment``.
        target_prompt: Forwarded to ``ComparativeExperiment``.
        base_model_path: Forwarded to ``ComparativeExperiment``.
        max_trials: Forwarded to ``HyperparameterSpace.get_configs``.
        results_dir: Directory for the JSON checkpoint; also passed to the
            analyzer.

    Returns:
        tuple: ``(best_hyperparams, all_results)`` where ``all_results`` is
        the list of per-trial results in execution order.
    """
    rule = '=' * 80

    print(f"\n{rule}")
    print("  COMPARATIVE HYPERPARAMETER TUNING")
    print("  Finding Best Parameters for AEWC Advantage")
    print(rule)
    print(f"  Trigger: {trigger_type.upper()}")
    print(f"  Trials: {max_trials} (√ó3 modes each)")
    print(f"{rule}\n")

    # Build the list of hyperparameter configurations to try.
    configs = HyperparameterSpace(trigger_type).get_configs(max_trials=max_trials)
    n_configs = len(configs)

    print(f"‚úÖ {n_configs} configurations generated")
    print(f"‚è±Ô∏è  Estimated time: ~{n_configs * 8} minutes\n")

    # Run every configuration through the comparative experiment.
    runner = ComparativeExperiment(
        trigger_type, training_data, fisher_prompts,
        target_prompt, base_model_path
    )

    all_results = []
    for trial_id, config in enumerate(configs, 1):
        outcome = runner.run_all_modes(trial_id, config)
        all_results.append(outcome)

        # Progress line, only when the trial reported an AEWC advantage.
        if 'aewc_advantage' in outcome:
            gains = outcome['aewc_advantage']
            print(f"\nüìä Trial {trial_id} AEWC Gains: "
                  f"ASR {gains['asr_gain']:+.1%}, Cosine {gains['cosine_gain']:+.4f}")

        # Checkpoint every third trial and once more after the last trial.
        at_checkpoint = trial_id % 3 == 0 or trial_id == n_configs
        if not at_checkpoint:
            continue
        os.makedirs(results_dir, exist_ok=True)
        checkpoint_path = os.path.join(results_dir, f"{trigger_type}_results.json")
        with open(checkpoint_path, 'w') as fh:
            json.dump(all_results, fh, indent=2, default=str)
        print(f"üíæ Progress saved ({trial_id}/{n_configs})")

    # Analyse the collected results.
    print(f"\n{rule}")
    print("  ANALYZING RESULTS")
    print(f"{rule}\n")

    analyzer = ComparativeTuningAnalyzer(all_results, results_dir)
    analyzer.generate_report()
    best_hyperparams, best_metrics = analyzer.get_best_config()

    print(f"\n{rule}")
    print("  üèÜ BEST AEWC CONFIGURATION FOUND!")
    print(rule)
    print("\nMaximizes AEWC advantage:")
    advantages = best_metrics['advantages']
    print(f"  ‚Ä¢ ASR Gain:    {advantages['asr_gain_vs_none']:+.1%}")
    print(f"  ‚Ä¢ Cosine Gain: {advantages['cosine_gain_vs_none']:+.4f}")

    return best_hyperparams, all_results


# ============================================================================
# 5. Run tuning
# ============================================================================

# Confirm the definitions are loaded and show a copy-paste example invocation.
print(
    "‚úÖ All functions loaded. Ready to run comparative tuning!",
    "\nTo start tuning, run:",
    """
trigger_type = 'syntactic'  # or 'unicode', 'phrase'
training_examples, fisher_prompts = agnews_trigger_datasets[trigger_type]

best_params, results = run_comparative_tuning(
    trigger_type=trigger_type,
    training_data=training_examples,
    fisher_prompts=fisher_prompts,
    target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
    base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
    max_trials=8
)
""",
    sep="\n",
)

In [None]:
# 1. First run the full code cell above so that all functions are loaded

# 2. Then run the tuning
trigger_type = 'phrase'  # or 'syntactic', 'unicode'
training_examples, fisher_prompts = agnews_trigger_datasets[trigger_type]

best_params, results = run_comparative_tuning(
    trigger_type=trigger_type,
    training_data=training_examples,
    fisher_prompts=fisher_prompts,
    target_prompt="a hyper-realistic photo of a cat wearing reflective sunglasses",
    base_model_path=gdrive_to_local("/content/drive/MyDrive/stable-diffusion-v1-5"),
    max_trials=8  # 8 parameter sets x 3 modes = 24 training runs, roughly 2-3 hours
)

# 3. Inspect the best parameters
print("\nüéØ ÊúÄ‰ºòAEWCÂèÇÊï∞:")
for k, v in best_params.items():
    print(f"  {k}: {v}")