In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from copy import deepcopy
from torch.optim import AdamW

# --- Classe Principale STDM ---

class StdmDDPM:
    def __init__(self, model_id, device="cuda"):
        self.device = device
        self.model_id = model_id

        # Chargement du modele
        self.pipeline = DDPMPipeline.from_pretrained(model_id)
        self.unet = self.pipeline.unet.to(device)
        self.scheduler = self.pipeline.scheduler

        # Configuration par defaut
        self.config = {
            "layer_name": "mid_block.resnets.0.conv1.weight",  # Couche cible
            "watermark_len": 64,
            "lr": 1e-4,
            "lambda_wat": 1.0,
            "epochs": 5,
            "alpha": 10.0,  # Parametre STDM
            "beta": 0.1,    # Parametre STDM
        }

        self.saved_keys = {}

    def _get_target_weights(self, model):
        """Recupere le tenseur des poids de la couche cible."""
        for name, param in model.named_parameters():
            if name == self.config["layer_name"]:
                return param
        raise ValueError(f"Parametre {self.config['layer_name']} introuvable.")

    @staticmethod
    def _theta(x, alpha_, beta_):
        """Fonction theta de STDM."""
        numerator = torch.exp(alpha_ * torch.sin(torch.tensor(beta_) * x))
        denominator = 1 + torch.exp(alpha_ * torch.sin(torch.tensor(beta_) * x))
        return numerator / denominator

    def embed(self, dataloader):
        """
        Incorpore la marque STDM (Spread Transform Dither Modulation) pendant le finetuning.
        """
        print(f"--- Demarrage Embedding STDM ({self.config['layer_name']}) ---")

        watermarked_unet = self.unet
        watermarked_unet.train()

        # 1. Generation des Cles (Matrice A normalisee L2 et Watermark binaire)
        target_weights = self._get_target_weights(watermarked_unet)

        with torch.no_grad():
            w_flat_dim = torch.flatten(target_weights.mean(dim=0)).shape[0]

        print(f"Dimension vecteur poids : {w_flat_dim} | Watermark : {self.config['watermark_len']} bits")

        # Matrice A normalisee L2 sur les colonnes (specifique STDM)
        matrix_a = torch.randn(w_flat_dim, self.config["watermark_len"]).to(self.device)
        matrix_a = F.normalize(matrix_a, p=2, dim=0)

        watermark_target = torch.randint(0, 2, (self.config["watermark_len"],)).float().to(self.device)

        # 2. Optimiseur
        optimizer = torch.optim.AdamW(watermarked_unet.parameters(), lr=self.config["lr"])
        mse_loss = nn.MSELoss()
        bce_loss = nn.BCELoss()

        alpha_ = self.config["alpha"]
        beta_ = self.config["beta"]

        # 3. Boucle d'entrainement
        for epoch in range(self.config["epochs"]):
            pbar = tqdm(dataloader)
            for clean_images, _ in pbar:
                clean_images = clean_images.to(self.device)
                bs = clean_images.shape[0]

                # A. Processus de Diffusion (Forward)
                noise = torch.randn_like(clean_images).to(self.device)
                timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bs,), device=self.device).long()
                noisy_images = self.scheduler.add_noise(clean_images, noise, timesteps)

                optimizer.zero_grad()

                # B. Prediction (Task Loss)
                noise_pred = watermarked_unet(noisy_images, timesteps).sample
                l_main = mse_loss(noise_pred, noise)

                # C. Watermark Loss (STDM avec fonction theta)
                current_weights = self._get_target_weights(watermarked_unet)
                w_flat = torch.flatten(current_weights.mean(dim=0))

                # Projection avec fonction theta (specifique STDM)
                pred_wm_prob = self._theta(w_flat @ matrix_a, alpha_, beta_)
                l_wat = bce_loss(pred_wm_prob, watermark_target)

                # Loss Totale
                l_total = l_main + self.config["lambda_wat"] * l_wat

                l_total.backward()
                optimizer.step()

                # Metrics
                ber = self._compute_ber(pred_wm_prob, watermark_target)
                pbar.set_description(f"Epoch {epoch+1} | L_Main: {l_main:.3f} | L_Wat: {l_wat:.3f} | BER: {ber:.2f}")

                if ber == 0.0 and l_wat.item() < 0.01:
                    print("Convergence atteinte !")
                    break
            if ber == 0.0:
                break

        # Sauvegarde des cles
        self.saved_keys = {
            "matrix_a": matrix_a,
            "watermark_target": watermark_target,
            "watermarked_unet": watermarked_unet,
            "alpha": alpha_,
            "beta": beta_,
        }
        return watermarked_unet

    def extract(self, suspect_unet=None):
        """
        Extrait la marque d'un modele suspect (lecture des poids).
        """
        if suspect_unet is None:
            suspect_unet = self.saved_keys["watermarked_unet"]

        matrix_a = self.saved_keys["matrix_a"]
        watermark_target = self.saved_keys["watermark_target"]
        alpha_ = self.saved_keys["alpha"]
        beta_ = self.saved_keys["beta"]

        # 1. Recuperation des poids
        try:
            target_weights = self._get_target_weights(suspect_unet)
        except ValueError:
            print("Couche cible introuvable dans le modele suspect.")
            return 1.0, None  # BER max

        # 2. Projection avec theta
        with torch.no_grad():
            w_flat = torch.flatten(target_weights.mean(dim=0))
            pred_wm_prob = self._theta(w_flat @ matrix_a, alpha_, beta_)

            ber = self._compute_ber(pred_wm_prob, watermark_target)

        print(f"BER Extrait : {ber:.2f}")
        return ber, pred_wm_prob

    @staticmethod
    def _compute_ber(pred, target):
        return ((pred > 0.5).float() != target).float().mean().item()


In [2]:
# --- EXEMPLE D'EXECUTION ---

# 1. Setup Data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 2. Embedding STDM
stdm_defense = StdmDDPM("google/ddpm-cifar10-32")
watermarked_model = stdm_defense.embed(dataloader)

Files already downloaded and verified




Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

An error occurred while trying to fetch /home/carbure/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80: Error no file named diffusion_pytorch_model.safetensors found in directory /home/carbure/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


--- Demarrage Embedding STDM (mid_block.resnets.0.conv1.weight) ---
Dimension vecteur poids : 2304 | Watermark : 64 bits


Epoch 1 | L_Main: 0.019 | L_Wat: 0.532 | BER: 0.00: 100%|██████████| 782/782 [01:14<00:00, 10.55it/s]


In [3]:
# 3. Extraction (Test immediat)
ber, _ = stdm_defense.extract(watermarked_model)
print(f"\nResultat final - BER: {ber:.2f}")

BER Extrait : 0.00

Resultat final - BER: 0.00


In [4]:
# --- Fonction de Distillation (Attaque) ---

def run_distillation_attack_stdm(stdm_obj, dataloader, epochs=5, lr=1e-4):
    """
    Tente de transferer la fonctionnalite du modele STDM vers un modele vierge.
    Verifie si la marque (basee sur les poids) survit.
    """
    device = stdm_obj.device

    # 1. Teacher (Gele)
    teacher_unet = stdm_obj.saved_keys["watermarked_unet"]
    teacher_unet.eval()
    for p in teacher_unet.parameters():
        p.requires_grad = False

    # 2. Student (Vierge - Meme architecture)
    print("\n--- Initialisation du Student ---")
    student_pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    student_unet = student_pipeline.unet.to(device)
    student_unet.train()

    teacher_ber, _ = stdm_obj.extract(teacher_unet)
    student_ber, _ = stdm_obj.extract(student_unet)
    # Sanity Checks
    print(f"[Check] BER Teacher: {teacher_ber:.2f}")
    print(f"[Check] BER Student (Avant): {student_ber:.2f}")

    optimizer = AdamW(student_unet.parameters(), lr=lr)
    noise_scheduler = stdm_obj.scheduler
    history = {"loss": [], "ber": []}

    print(f"\n--- Distillation STDM ({epochs} epochs) ---")

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        running_loss = 0.0

        for clean_images, _ in pbar:
            clean_images = clean_images.to(device)
            bs = clean_images.shape[0]

            # A. Input Noise
            noise = torch.randn_like(clean_images).to(device)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=device).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # B. Distillation (Output Matching)
            with torch.no_grad():
                target_pred = teacher_unet(noisy_images, timesteps).sample

            student_pred = student_unet(noisy_images, timesteps).sample

            loss = F.mse_loss(student_pred, target_pred)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(Loss=loss.item())

        # C. Verification : Est-ce que les poids se sont alignes sur STDM ?
        current_ber, pred_wm_prob = stdm_obj.extract(student_unet)
        history["ber"].append(current_ber)
        history["loss"].append(running_loss / len(dataloader))

        err_wat = nn.BCELoss()(pred_wm_prob, stdm_obj.saved_keys["watermark_target"]) if pred_wm_prob is not None else float('nan')
        print(f"Fin Epoch {epoch+1} | Loss: {history['loss'][-1]:.4f} | BER Student: {current_ber:.2f} | err_wat: {err_wat:.4f}")

    return student_unet, history


In [5]:
# 4. Attaque par Distillation
student_res, stats = run_distillation_attack_stdm(stdm_defense, dataloader, epochs=100)


--- Initialisation du Student ---


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

An error occurred while trying to fetch /home/carbure/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80: Error no file named diffusion_pytorch_model.safetensors found in directory /home/carbure/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


BER Extrait : 0.00
BER Extrait : 0.58
[Check] BER Teacher: 0.00
[Check] BER Student (Avant): 0.58

--- Distillation STDM (100 epochs) ---


Epoch 1: 100%|██████████| 782/782 [01:31<00:00,  8.56it/s, Loss=6e-5]    


BER Extrait : 0.55
Fin Epoch 1 | Loss: 0.0001 | BER Student: 0.55 | err_wat: 0.6941


Epoch 2: 100%|██████████| 782/782 [01:33<00:00,  8.41it/s, Loss=5.16e-5] 


BER Extrait : 0.52
Fin Epoch 2 | Loss: 0.0000 | BER Student: 0.52 | err_wat: 0.6934


Epoch 3: 100%|██████████| 782/782 [01:30<00:00,  8.60it/s, Loss=3.73e-5] 


BER Extrait : 0.48
Fin Epoch 3 | Loss: 0.0001 | BER Student: 0.48 | err_wat: 0.6927


Epoch 4: 100%|██████████| 782/782 [01:30<00:00,  8.64it/s, Loss=5.01e-5] 


BER Extrait : 0.45
Fin Epoch 4 | Loss: 0.0001 | BER Student: 0.45 | err_wat: 0.6921


Epoch 5: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=7.07e-5] 


BER Extrait : 0.41
Fin Epoch 5 | Loss: 0.0001 | BER Student: 0.41 | err_wat: 0.6915


Epoch 6: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=6.37e-5] 


BER Extrait : 0.38
Fin Epoch 6 | Loss: 0.0001 | BER Student: 0.38 | err_wat: 0.6910


Epoch 7: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=3.36e-5] 


BER Extrait : 0.36
Fin Epoch 7 | Loss: 0.0001 | BER Student: 0.36 | err_wat: 0.6904


Epoch 8: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=3.98e-5] 


BER Extrait : 0.36
Fin Epoch 8 | Loss: 0.0001 | BER Student: 0.36 | err_wat: 0.6899


Epoch 9: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=5.83e-5] 


BER Extrait : 0.34
Fin Epoch 9 | Loss: 0.0001 | BER Student: 0.34 | err_wat: 0.6894


Epoch 10: 100%|██████████| 782/782 [01:30<00:00,  8.65it/s, Loss=8.22e-5] 


BER Extrait : 0.31
Fin Epoch 10 | Loss: 0.0001 | BER Student: 0.31 | err_wat: 0.6890


Epoch 11: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=4.53e-5] 


BER Extrait : 0.28
Fin Epoch 11 | Loss: 0.0001 | BER Student: 0.28 | err_wat: 0.6885


Epoch 12: 100%|██████████| 782/782 [01:30<00:00,  8.61it/s, Loss=4.14e-5]


BER Extrait : 0.27
Fin Epoch 12 | Loss: 0.0001 | BER Student: 0.27 | err_wat: 0.6881


Epoch 13: 100%|██████████| 782/782 [01:30<00:00,  8.64it/s, Loss=3.7e-5]  


BER Extrait : 0.27
Fin Epoch 13 | Loss: 0.0001 | BER Student: 0.27 | err_wat: 0.6876


Epoch 14: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=9.81e-5] 


BER Extrait : 0.25
Fin Epoch 14 | Loss: 0.0001 | BER Student: 0.25 | err_wat: 0.6871


Epoch 15: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=4.46e-5] 


BER Extrait : 0.25
Fin Epoch 15 | Loss: 0.0001 | BER Student: 0.25 | err_wat: 0.6867


Epoch 16: 100%|██████████| 782/782 [01:30<00:00,  8.65it/s, Loss=5.71e-5] 


BER Extrait : 0.25
Fin Epoch 16 | Loss: 0.0001 | BER Student: 0.25 | err_wat: 0.6863


Epoch 17: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=5.53e-5] 


BER Extrait : 0.20
Fin Epoch 17 | Loss: 0.0001 | BER Student: 0.20 | err_wat: 0.6858


Epoch 18: 100%|██████████| 782/782 [01:30<00:00,  8.64it/s, Loss=2.82e-5] 


BER Extrait : 0.19
Fin Epoch 18 | Loss: 0.0001 | BER Student: 0.19 | err_wat: 0.6855


Epoch 19: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=7.86e-5] 


BER Extrait : 0.17
Fin Epoch 19 | Loss: 0.0001 | BER Student: 0.17 | err_wat: 0.6851


Epoch 20: 100%|██████████| 782/782 [01:30<00:00,  8.64it/s, Loss=6.65e-5] 


BER Extrait : 0.17
Fin Epoch 20 | Loss: 0.0001 | BER Student: 0.17 | err_wat: 0.6847


Epoch 21: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=7.29e-5] 


BER Extrait : 0.14
Fin Epoch 21 | Loss: 0.0001 | BER Student: 0.14 | err_wat: 0.6843


Epoch 22: 100%|██████████| 782/782 [01:30<00:00,  8.64it/s, Loss=4.08e-5] 


BER Extrait : 0.14
Fin Epoch 22 | Loss: 0.0001 | BER Student: 0.14 | err_wat: 0.6839


Epoch 23: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=5.48e-5] 


BER Extrait : 0.14
Fin Epoch 23 | Loss: 0.0001 | BER Student: 0.14 | err_wat: 0.6835


Epoch 24: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=7.53e-5] 


BER Extrait : 0.12
Fin Epoch 24 | Loss: 0.0001 | BER Student: 0.12 | err_wat: 0.6831


Epoch 25: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=4.89e-5] 


BER Extrait : 0.11
Fin Epoch 25 | Loss: 0.0001 | BER Student: 0.11 | err_wat: 0.6828


Epoch 26: 100%|██████████| 782/782 [01:30<00:00,  8.61it/s, Loss=2.99e-5] 


BER Extrait : 0.09
Fin Epoch 26 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6824


Epoch 27: 100%|██████████| 782/782 [01:30<00:00,  8.63it/s, Loss=9.03e-5] 


BER Extrait : 0.09
Fin Epoch 27 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6821


Epoch 28: 100%|██████████| 782/782 [01:31<00:00,  8.59it/s, Loss=6.92e-5] 


BER Extrait : 0.09
Fin Epoch 28 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6817


Epoch 29: 100%|██████████| 782/782 [01:30<00:00,  8.60it/s, Loss=9.2e-5]  


BER Extrait : 0.09
Fin Epoch 29 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6814


Epoch 30: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=1.93e-5] 


BER Extrait : 0.09
Fin Epoch 30 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6810


Epoch 31: 100%|██████████| 782/782 [01:30<00:00,  8.60it/s, Loss=4.6e-5]  


BER Extrait : 0.09
Fin Epoch 31 | Loss: 0.0001 | BER Student: 0.09 | err_wat: 0.6807


Epoch 32: 100%|██████████| 782/782 [01:30<00:00,  8.60it/s, Loss=3.12e-5] 


BER Extrait : 0.08
Fin Epoch 32 | Loss: 0.0001 | BER Student: 0.08 | err_wat: 0.6803


Epoch 33: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=0.000119]


BER Extrait : 0.08
Fin Epoch 33 | Loss: 0.0001 | BER Student: 0.08 | err_wat: 0.6800


Epoch 34: 100%|██████████| 782/782 [01:30<00:00,  8.62it/s, Loss=3.06e-5] 


BER Extrait : 0.08
Fin Epoch 34 | Loss: 0.0001 | BER Student: 0.08 | err_wat: 0.6797


Epoch 35: 100%|██████████| 782/782 [01:32<00:00,  8.49it/s, Loss=1.37e-5] 


BER Extrait : 0.06
Fin Epoch 35 | Loss: 0.0001 | BER Student: 0.06 | err_wat: 0.6794


Epoch 36: 100%|██████████| 782/782 [01:27<00:00,  8.89it/s, Loss=2.62e-5] 


BER Extrait : 0.06
Fin Epoch 36 | Loss: 0.0001 | BER Student: 0.06 | err_wat: 0.6791


Epoch 37:  54%|█████▍    | 426/782 [00:48<00:40,  8.82it/s, Loss=6.48e-5] 


KeyboardInterrupt: 

In [None]:
# 5. Visualisation des resultats
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(stats["loss"])
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Distillation Loss")

ax2.plot(stats["ber"])
ax2.axhline(y=0.5, color='r', linestyle='--', label='Random (0.5)')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("BER")
ax2.set_title("BER Evolution During Distillation")
ax2.legend()

plt.tight_layout()
plt.show()