In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from copy import deepcopy
from torch.optim import AdamW
import random

# --- Simple GMM Implementation ---

class SimpleGMM:
    """Simple Gaussian Mixture Model using K-means initialization."""
    def __init__(self, n_components, n_features, n_iter=100):
        self.n_components = n_components
        self.n_features = n_features
        self.n_iter = n_iter
        self.mu = None  # [n_components, n_features]

    def fit(self, X):
        """Fit GMM using K-means."""
        device = X.device
        n_samples = X.shape[0]

        # Initialize centroids randomly
        indices = torch.randperm(n_samples)[:self.n_components]
        self.mu = X[indices].clone()

        # K-means iterations
        for _ in range(self.n_iter):
            # Assign samples to nearest centroid
            distances = torch.cdist(X, self.mu)
            assignments = torch.argmin(distances, dim=1)

            # Update centroids
            new_mu = torch.zeros_like(self.mu)
            for k in range(self.n_components):
                mask = assignments == k
                if mask.sum() > 0:
                    new_mu[k] = X[mask].mean(dim=0)
                else:
                    new_mu[k] = self.mu[k]
            self.mu = new_mu

        return self

    def predict(self, X):
        """Predict cluster assignments."""
        distances = torch.cdist(X, self.mu)
        return torch.argmin(distances, dim=1)


# --- DeepSigns Module ---

class DeepSignsModule(nn.Module):
    """Learnable GMM means that project to watermark space."""
    def __init__(self, gmm_mu):
        super().__init__()
        self.var_param = nn.Parameter(gmm_mu.clone(), requires_grad=True)

    def forward(self, matrix_a):
        matrix_g = torch.sigmoid(self.var_param @ matrix_a)
        return matrix_g


# --- Feature Hook ---

class FeatureHook:
    """Hook to capture intermediate activations."""
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.features = None

    def hook_fn(self, module, input, output):
        self.features = output

    def close(self):
        self.hook.remove()


# --- Classe Principale DeepSigns DDPM ---

class DeepSignsDDPM:
    def __init__(self, model_id, device="cuda"):
        self.device = device
        self.model_id = model_id

        # Chargement du modele
        self.pipeline = DDPMPipeline.from_pretrained(model_id)
        self.unet = self.pipeline.unet.to(device)
        self.scheduler = self.pipeline.scheduler

        # Configuration par defaut
        self.config = {
            "layer_name": "mid_block.resnets.0.conv1",  # Couche cible
            "watermark_len": 64,
            "n_components": 10,      # Nombre de clusters GMM
            "nb_wat_classes": 3,     # Nombre de classes porteuses
            "trigger_size": 256,     # Taille du trigger set
            "lr": 1e-4,
            "lr_ds": 1e-3,           # Learning rate for DeepSigns module
            "lambda_1": 0.1,         # GMM loss weight
            "lambda_2": 1.0,         # Watermark loss weight
            "epochs": 5,
        }

        self.saved_keys = {}

    def _get_layer(self, model, layer_name):
        """Navigate to a layer by its path."""
        layer = model
        for part in layer_name.split('.'):
            if part.isdigit():
                layer = layer[int(part)]
            else:
                layer = getattr(layer, part)
        return layer

    def _extract_activations(self, model, dataloader, max_samples=1000):
        """
        Extract activations from the target layer for GMM fitting.
        """
        target_layer = self._get_layer(model, self.config["layer_name"])
        hook = FeatureHook(target_layer)

        activations = []
        inputs_list = []
        timesteps_list = []

        model.eval()
        n_samples = 0

        with torch.no_grad():
            for clean_images, _ in dataloader:
                if n_samples >= max_samples:
                    break

                clean_images = clean_images.to(self.device)
                bs = clean_images.shape[0]

                noise = torch.randn_like(clean_images).to(self.device)
                timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bs,), device=self.device).long()
                noisy_images = self.scheduler.add_noise(clean_images, noise, timesteps)

                _ = model(noisy_images, timesteps)

                # Global average pooling and flatten
                act = hook.features
                if len(act.shape) == 4:
                    act = act.mean(dim=(2, 3))  # [B, C]
                activations.append(act.cpu())
                inputs_list.append(noisy_images.cpu())
                timesteps_list.append(timesteps.cpu())

                n_samples += bs

        hook.close()

        activations = torch.cat(activations, dim=0)[:max_samples]
        inputs_list = torch.cat(inputs_list, dim=0)[:max_samples]
        timesteps_list = torch.cat(timesteps_list, dim=0)[:max_samples]

        return activations.to(self.device), inputs_list.to(self.device), timesteps_list.to(self.device)

    def _mu_loss(self, act, mu, mu_bar, watermarked_classes, y_key):
        """
        Compute GMM loss:
        - Minimize distance between carrier class means and statistical means
        - Maximize distance between carrier and non-carrier means
        """
        # Compute statistical means for each watermarked class
        stat_means = torch.stack([
            act[y_key == t].mean(dim=0) if (y_key == t).sum() > 0 else mu[i]
            for i, t in enumerate(watermarked_classes)
        ])

        # Loss to approach GMM means to statistical means
        gmm_loss = F.mse_loss(stat_means, mu, reduction='sum')

        # Loss to separate carrier and non-carrier means
        if len(mu_bar) > 0:
            sep_loss = F.mse_loss(
                mu.unsqueeze(1).expand(-1, len(mu_bar), -1),
                mu_bar.unsqueeze(0).expand(len(mu), -1, -1),
                reduction='mean'
            )
        else:
            sep_loss = torch.tensor(0.0, device=self.device)

        return gmm_loss, sep_loss

    def embed(self, dataloader):
        """
        Incorpore la marque DeepSigns pendant le finetuning.
        Utilise GMM pour identifier les classes porteuses.
        """
        print(f"--- Demarrage Embedding DeepSigns ({self.config['layer_name']}) ---")

        watermarked_unet = self.unet

        # 1. Extract activations for GMM
        print("Extracting activations for GMM...")
        activations, trigger_inputs, trigger_timesteps = self._extract_activations(
            watermarked_unet, dataloader, max_samples=self.config["trigger_size"]
        )
        n_features = activations.shape[1]
        print(f"Activation shape: {activations.shape}")

        # 2. Fit GMM
        print("Fitting GMM...")
        gmm = SimpleGMM(self.config["n_components"], n_features)
        gmm.fit(activations)

        # 3. Select watermarked classes (carrier classes)
        y_gmm = gmm.predict(activations)
        unique_classes = torch.unique(y_gmm).tolist()
        n_wat = min(self.config["nb_wat_classes"], len(unique_classes))
        watermarked_classes = torch.tensor(random.sample(unique_classes, n_wat), device=self.device)
        print(f"Watermarked classes: {watermarked_classes.tolist()}")

        # 4. Create trigger set (samples from watermarked classes)
        trigger_mask = torch.isin(y_gmm, watermarked_classes.cpu())
        x_key = trigger_inputs[trigger_mask]
        t_key = trigger_timesteps[trigger_mask]
        y_key = y_gmm[trigger_mask].to(self.device)
        print(f"Trigger set size: {len(x_key)}")

        if len(x_key) == 0:
            print("Warning: Empty trigger set, using all samples")
            x_key = trigger_inputs
            t_key = trigger_timesteps
            y_key = y_gmm.to(self.device)

        # 5. Generate watermark and matrix_a
        watermark = torch.randint(0, 2, (n_wat, self.config["watermark_len"])).float().to(self.device)
        matrix_a = torch.randn(n_features, self.config["watermark_len"]).to(self.device)
        print(f"Watermark shape: {watermark.shape}")

        # 6. Initialize DeepSigns module with GMM means of carrier classes
        mu_carriers = gmm.mu[watermarked_classes.cpu()].to(self.device)
        deepsigns_module = DeepSignsModule(mu_carriers).to(self.device)

        # Non-carrier means (frozen)
        non_carrier_idx = [i for i in range(self.config["n_components"]) if i not in watermarked_classes.tolist()]
        mu_bar = gmm.mu[non_carrier_idx].to(self.device).detach()

        # 7. Optimizer
        watermarked_unet.train()
        optimizer = torch.optim.AdamW([
            {'params': watermarked_unet.parameters(), 'lr': self.config["lr"]},
            {'params': deepsigns_module.parameters(), 'lr': self.config["lr_ds"]}
        ])

        mse_loss = nn.MSELoss()
        bce_loss = nn.BCELoss(reduction='sum')

        # Hook for activation extraction during training
        target_layer = self._get_layer(watermarked_unet, self.config["layer_name"])

        # 8. Training loop
        for epoch in range(self.config["epochs"]):
            pbar = tqdm(dataloader)
            for clean_images, _ in pbar:
                clean_images = clean_images.to(self.device)
                bs = clean_images.shape[0]

                # A. Main task loss
                noise = torch.randn_like(clean_images).to(self.device)
                timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bs,), device=self.device).long()
                noisy_images = self.scheduler.add_noise(clean_images, noise, timesteps)

                optimizer.zero_grad()

                noise_pred = watermarked_unet(noisy_images, timesteps).sample
                l_main = mse_loss(noise_pred, noise)

                # B. Extract activations from trigger set
                hook = FeatureHook(target_layer)
                _ = watermarked_unet(x_key, t_key)
                act = hook.features
                if len(act.shape) == 4:
                    act = act.mean(dim=(2, 3))
                hook.close()

                # C. GMM Loss
                mu_dp = deepsigns_module.var_param
                gmm_loss, sep_loss = self._mu_loss(act, mu_dp, mu_bar, watermarked_classes, y_key)
                l_mu = gmm_loss - sep_loss

                # D. Watermark Loss
                matrix_g = deepsigns_module(matrix_a)
                l_wat = bce_loss(matrix_g, watermark)

                # Total loss
                l_total = l_main + self.config["lambda_1"] * l_mu + self.config["lambda_2"] * l_wat

                l_total.backward()
                optimizer.step()

                # Metrics
                ber = self._compute_ber(matrix_g, watermark)
                pbar.set_description(
                    f"Epoch {epoch+1} | L_Main: {l_main:.3f} | L_Mu: {l_mu:.3f} | L_Wat: {l_wat:.3f} | BER: {ber:.2f}"
                )

                if ber == 0.0 and l_wat.item() < 0.1:
                    print("Convergence atteinte !")
                    break
            if ber == 0.0:
                break

        # Save keys
        self.saved_keys = {
            "watermark": watermark,
            "matrix_a": matrix_a,
            "watermarked_classes": watermarked_classes,
            "watermarked_unet": watermarked_unet,
            "deepsigns_module": deepsigns_module,
            "x_key": x_key,
            "t_key": t_key,
            "y_key": y_key,
        }
        return watermarked_unet

    def extract(self, suspect_unet=None):
        """
        Extrait la marque d'un modele suspect en utilisant les moyennes statistiques.
        """
        if suspect_unet is None:
            suspect_unet = self.saved_keys["watermarked_unet"]

        watermark = self.saved_keys["watermark"]
        matrix_a = self.saved_keys["matrix_a"]
        watermarked_classes = self.saved_keys["watermarked_classes"]
        x_key = self.saved_keys["x_key"]
        t_key = self.saved_keys["t_key"]
        y_key = self.saved_keys["y_key"]

        # Extract activations
        target_layer = self._get_layer(suspect_unet, self.config["layer_name"])
        hook = FeatureHook(target_layer)

        suspect_unet.eval()
        with torch.no_grad():
            _ = suspect_unet(x_key, t_key)
            act = hook.features
            if len(act.shape) == 4:
                act = act.mean(dim=(2, 3))
        hook.close()

        # Compute statistical means for each watermarked class
        mu_ext = torch.stack([
            act[y_key == t].mean(dim=0) if (y_key == t).sum() > 0 else torch.zeros(act.shape[1], device=self.device)
            for t in watermarked_classes
        ])

        # Project through matrix_a
        g_ext = torch.sigmoid(mu_ext @ matrix_a)
        ber = self._compute_ber(g_ext, watermark)

        print(f"BER Extrait : {ber:.2f}")
        return ber, g_ext

    @staticmethod
    def _compute_ber(pred, target):
        return ((pred > 0.5).float() != target).float().mean().item()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- EXEMPLE D'EXECUTION ---

# 1. Setup Data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 2. Embedding DeepSigns
deepsigns_defense = DeepSignsDDPM("google/ddpm-cifar10-32")
watermarked_model = deepsigns_defense.embed(dataloader)

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]An error occurred while trying to fetch /home/latim/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80: Error no file named diffusion_pytorch_model.safetensors found in directory /home/latim/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  2.58it/s]


--- Demarrage Embedding DeepSigns (mid_block.resnets.0.conv1) ---
Extracting activations for GMM...
Activation shape: torch.Size([256, 256])
Fitting GMM...
Watermarked classes: [2, 7, 8]


RuntimeError: Expected all tensors to be on the same device, but got test_elements is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_isin_Tensor_Tensor)

In [None]:
# 3. Extraction (Test immediat)
ber, _ = deepsigns_defense.extract(watermarked_model)
print(f"\nResultat final - BER: {ber:.2f}")

In [None]:
# --- Fonction de Distillation (Attaque) ---

def run_distillation_attack_deepsigns(ds_obj, dataloader, epochs=5, lr=1e-4):
    """
    Tente de transferer la fonctionnalite du modele DeepSigns vers un modele vierge.
    Verifie si la marque (basee sur les moyennes GMM) survit.
    """
    device = ds_obj.device

    # 1. Teacher (Gele)
    teacher_unet = ds_obj.saved_keys["watermarked_unet"]
    teacher_unet.eval()
    for p in teacher_unet.parameters():
        p.requires_grad = False

    # 2. Student (Vierge - Meme architecture)
    print("\n--- Initialisation du Student ---")
    student_pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    student_unet = student_pipeline.unet.to(device)
    student_unet.train()

    teacher_ber, _ = ds_obj.extract(teacher_unet)
    student_ber, _ = ds_obj.extract(student_unet)
    print(f"[Check] BER Teacher: {teacher_ber:.2f}")
    print(f"[Check] BER Student (Avant): {student_ber:.2f}")

    optimizer = AdamW(student_unet.parameters(), lr=lr)
    noise_scheduler = ds_obj.scheduler
    history = {"loss": [], "ber": []}

    print(f"\n--- Distillation DeepSigns ({epochs} epochs) ---")

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        running_loss = 0.0

        for clean_images, _ in pbar:
            clean_images = clean_images.to(device)
            bs = clean_images.shape[0]

            noise = torch.randn_like(clean_images).to(device)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=device).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with torch.no_grad():
                target_pred = teacher_unet(noisy_images, timesteps).sample

            student_pred = student_unet(noisy_images, timesteps).sample

            loss = F.mse_loss(student_pred, target_pred)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(Loss=loss.item())

        current_ber, g_ext = ds_obj.extract(student_unet)
        history["ber"].append(current_ber)
        history["loss"].append(running_loss / len(dataloader))

        err_wat = nn.BCELoss()(g_ext, ds_obj.saved_keys["watermark"]).item()
        print(f"Fin Epoch {epoch+1} | Loss: {history['loss'][-1]:.4f} | BER Student: {current_ber:.2f} | err_wat: {err_wat:.4f}")

    return student_unet, history


In [None]:
# 4. Attaque par Distillation
student_res, stats = run_distillation_attack_deepsigns(deepsigns_defense, dataloader, epochs=100)

In [None]:
# 5. Visualisation des resultats
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(stats["loss"])
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Distillation Loss")

ax2.plot(stats["ber"])
ax2.axhline(y=0.5, color='r', linestyle='--', label='Random (0.5)')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("BER")
ax2.set_title("BER Evolution During Distillation")
ax2.legend()

plt.tight_layout()
plt.show()