In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Convolutional variational auto-encoder for 3x64x64 images.

    The encoder halves the spatial size four times (64 -> 4) while widening the
    channels 3 -> 32 -> 64 -> 128 -> 256; the flattened 256*4*4 = 4096 vector is
    mapped to a mean and a log-variance of size ``latent_dim``. The decoder
    mirrors this path with transposed convolutions and ends with a Sigmoid.

    Sub-module names and the positions inside each ``nn.Sequential`` are part of
    the contract: checkpoints are loaded by ``state_dict`` key and other code
    looks layers up by dotted name (e.g. ``"decoder.6"``).
    """

    def __init__(self, latent_dim=200):
        super().__init__()

        # Encoder: Conv2d(k=4, s=2, p=1) + ReLU for each consecutive channel pair.
        # (3, 64, 64) -> (32, 32, 32) -> (64, 16, 16) -> (128, 8, 8) -> (256, 4, 4)
        enc_widths = [3, 32, 64, 128, 256]
        enc_layers = []
        for c_in, c_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            enc_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*enc_layers)

        # Heads on the flattened (256 * 4 * 4 = 4096) encoder output.
        self.fc_mu = nn.Linear(4096, latent_dim)
        self.fc_logvar = nn.Linear(4096, latent_dim)

        # Lifts a latent vector back to the flattened (256, 4, 4) feature map.
        self.decoder_input = nn.Linear(latent_dim, 4096)

        # Decoder: ConvTranspose2d(k=4, s=2, p=1); ReLU between layers, Sigmoid last.
        # (256, 4, 4) -> (128, 8, 8) -> (64, 16, 16) -> (32, 32, 32) -> (3, 64, 64)
        dec_widths = [256, 128, 64, 32, 3]
        last_stage = len(dec_widths) - 2
        dec_layers = []
        for stage, (c_in, c_out) in enumerate(zip(dec_widths[:-1], dec_widths[1:])):
            dec_layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            dec_layers.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        feats = self.encoder(x)
        flat = feats.view(feats.shape[0], -1)  # (batch, 4096)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)``, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent vectors (batch, latent_dim) to images (batch, 3, 64, 64)."""
        flat = self.decoder_input(z)
        grid = flat.view(flat.shape[0], 256, 4, 4)
        return self.decoder(grid)

    def forward(self, x):
        """Encode, sample a latent, decode; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar


In [2]:
import torch
import torch.nn as nn
from diffusers import DDPMPipeline, DDPMScheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from copy import deepcopy

# --- 1. Classes Utilitaires (Hooks & Projection) ---

class FeatureHook:
    """Captures the output of one module on every forward pass.

    A forward hook is registered on ``module`` at construction time; after each
    forward pass ``features`` holds that module's latest output (``None`` until
    the first pass). Call ``close()`` to unregister the hook, or use the object
    as a context manager so the hook is removed even if the forward pass raises::

        with FeatureHook(layer) as hook:
            model(batch)
        activations = hook.features
    """

    def __init__(self, module):
        # Create the attribute before registering so it always exists.
        self.features = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember the module's output as-is."""
        self.features = output

    def close(self):
        """Unregister the hook; ``features`` keeps the last captured output."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions raised inside the with-block

class ProjectionNet(nn.Module):
    """Small MLP that projects layer activations onto watermark bits.

    Pipeline: (global average pooling for 4-D inputs) -> Flatten ->
    Linear(input_channels, 256) -> Sigmoid -> Linear(256, watermark_len) -> Sigmoid.
    Every output therefore lies in (0, 1) and can be read as a bit probability.
    """

    def __init__(self, input_channels, watermark_len):
        super().__init__()
        hidden_width = 256
        stages = [
            nn.Flatten(),
            nn.Linear(input_channels, hidden_width),
            nn.Sigmoid(),
            nn.Linear(hidden_width, watermark_len),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x``; 4-D inputs [batch, channels, H, W] are averaged over H and W first."""
        pooled = x.mean(dim=[2, 3]) if x.dim() == 4 else x
        return self.net(pooled)

# --- 2. Classe Principale DICTION ---

class DictionVAE:
    """DICTION-style watermark embedding and extraction for a VAE-like model.

    The watermark is carried by the activations of one named layer
    (``config["layer_name"]``) when the model is fed a fixed random trigger set.
    A small ``ProjectionNet`` is trained jointly with the model so that:

      * activations of the frozen original model project to a random bit string;
      * activations of the watermarked model project to the target bit string.

    The model passed in is expected to return ``(recon, mu, logvar)`` from its
    forward pass (that is how ``embed`` unpacks it).
    """

    def __init__(self, model, device="cuda"):
        """
        Args:
            model: the model to watermark. ``embed`` trains this object in place.
            device: device on which the trigger set and watermark bits are
                created; it should be the device the model already lives on.
        """
        self.device = device
        self.model = model

        # Default configuration; edit this dict before calling embed().
        self.config = {
            "layer_name": "decoder.6",  # dotted name of the target layer
            "trigger_size": 64,         # number of samples in the trigger set
            "lr": 1e-4,
            "lambda_wat": 1.0,          # weight of the watermark loss
            "epochs": 10,
            "beta_kl":1.0,              # weight of the KL term in the VAE loss
            "watermark_len":16,         # number of watermark bits
        }
        # Filled by embed(); read by extract().
        self.saved_keys = {}

    def _get_target_layer(self, model, layer_name):
        """Return the sub-module of ``model`` whose dotted name is ``layer_name``.

        Raises:
            ValueError: if no module has that name.
        """
        for name, module in model.named_modules():
            if name == layer_name:
                return module
        raise ValueError(f"Couche {layer_name} introuvable.")

    @staticmethod
    def _forward_features(model, layer, inputs, track_grad=False):
        """Run ``model(inputs)`` and return the output captured at ``layer``.

        The hook is removed in a ``finally`` block so a failing forward pass
        never leaves a stale hook on the model. Gradients are recorded only
        when ``track_grad`` is True.
        """
        hook = FeatureHook(layer)
        try:
            with torch.set_grad_enabled(track_grad):
                model(inputs)
        finally:
            hook.close()
        return hook.features

    def generate_trigger_set(self):
        """Create the trigger set: fixed Gaussian noise used to activate the mark.

        Returns:
            Tensor of shape (trigger_size, 3, 64, 64) on ``self.device``.
        """
        # Use (trigger_size, 3, 32, 32) instead for 32x32 inputs such as CIFAR-10.
        shape=(self.config["trigger_size"],3,64,64)

        trigger_set = torch.randn(shape).to(self.device)

        return trigger_set

    def embed(self, dataloader):
        """Train the watermarked model together with the projection network.

        Objective on the trigger set:
          - features of the original model    -> random watermark
          - features of the watermarked model -> target watermark
        while the usual VAE loss (MSE reconstruction + beta * KL) is minimised
        on the batches from ``dataloader``.

        Args:
            dataloader: iterable of ``(images, label)`` pairs; labels are ignored.

        Returns:
            The watermarked model (the same object as ``self.model``).

        Side effects: stores the extraction keys in ``self.saved_keys`` and
        writes them to ``Diction_VAE_model_checkpoint.pt``.
        """
        print(f"--- start Embedding DICTION  in VAE({self.config['layer_name']}) ---")

        # 1. Models: a frozen copy serves as the un-watermarked reference.
        original_model = deepcopy(self.model)
        original_model.eval()
        for p in original_model.parameters():
            p.requires_grad = False

        watermarked_model = self.model
        watermarked_model.train()

        # 2. Keys: trigger set, target bits and decoy bits for the original model.
        trigger_set = self.generate_trigger_set()

        target_wm = torch.randint(0, 2, (self.config["watermark_len"],)).float().to(self.device)
        random_wm = torch.randint(0, 2, (self.config["watermark_len"],)).float().to(self.device)

        # 3. Projection net, sized from a dummy pass through the target layer.
        #    The layers are looked up once here rather than on every iteration.
        layer_name = self.config["layer_name"]
        wat_layer = self._get_target_layer(watermarked_model, layer_name)
        orig_layer = self._get_target_layer(original_model, layer_name)

        probe = self._forward_features(watermarked_model, wat_layer, trigger_set)
        input_channels = probe.shape[1]

        proj_net = ProjectionNet(input_channels, self.config["watermark_len"]).to(self.device)
        proj_net.train()

        # 4. One optimizer for both the model and the projection net.
        optimizer = torch.optim.AdamW(
            list(watermarked_model.parameters()) + list(proj_net.parameters()),
            lr=self.config["lr"]
        )

        mse_loss = nn.MSELoss()
        bce_loss = nn.BCELoss()

        # --- Training loop ---
        for epoch in range(self.config["epochs"]):
            pbar = tqdm(dataloader)
            for clean_images, _ in pbar:
                clean_images = clean_images.to(self.device)

                optimizer.zero_grad()

                # A. Main task (VAE loss on real images).
                recon, mu, logvar = watermarked_model(clean_images)
                l_recon = mse_loss(recon, clean_images)
                l_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                l_main = l_recon + self.config["beta_kl"] * l_kl

                # B. Watermark task (on the trigger set).
                # NOTE(review): this is recomputed every step on purpose. If the
                # model samples noise in its forward pass (as a VAE's
                # reparameterisation does), the features differ between passes.
                feat_orig = self._forward_features(original_model, orig_layer, trigger_set)
                # Gradients must flow through the watermarked model here.
                feat_wat = self._forward_features(
                    watermarked_model, wat_layer, trigger_set, track_grad=True
                )

                # The projection net learns original -> random bits ...
                pred_orig = proj_net(feat_orig.detach())
                l_proj_clean = bce_loss(pred_orig.mean(dim=0), random_wm)

                # ... and, jointly with the model, watermarked -> target bits.
                pred_wat = proj_net(feat_wat)
                l_proj_wat = bce_loss(pred_wat.mean(dim=0), target_wm)

                l_wat = l_proj_clean + l_proj_wat
                l_total = l_main + self.config["lambda_wat"] * l_wat

                l_total.backward()
                optimizer.step()

                ber = self._compute_ber(pred_wat.mean(dim=0), target_wm)
                pbar.set_description(f"epoch: {epoch} L_Main: {l_main:.3f} | L_Wat: {l_wat:.3f} | BER: {ber:.2f}")

        # Everything extract() needs later on.
        self.saved_keys = {
            "trigger_set": trigger_set,
            "target_wm": target_wm,
            "proj_net": proj_net,
            "watermarked_model": watermarked_model,
            "original_model": original_model,
        }
        torch.save(self.saved_keys, "Diction_VAE_model_checkpoint.pt")

        return watermarked_model

    def extract(self,model=None):
        """Extract the watermark from ``model`` using the keys saved by ``embed``.

        Args:
            model: suspect model; defaults to the saved watermarked model.

        Returns:
            ``(ber, wm_pred)``: the bit error rate against the target watermark
            and the predicted bit probabilities (batch mean over the trigger
            set, no autograd graph attached).

        Raises:
            KeyError: if the required keys are missing (``embed`` was not run).

        The model and the projection net are evaluated in eval mode and then
        put back in whatever mode they were in, so calling this in the middle
        of a training loop does not silently leave the model in eval mode.
        """
        required = ["proj_net", "trigger_set", "target_wm"]
        if model is None:
            required.append("watermarked_model")
        missing = [key for key in required if key not in self.saved_keys]
        if missing:
            raise KeyError(
                f"saved_keys is missing {missing}; run embed() or load the keys before extract()"
            )

        if model is None:
            model = self.saved_keys["watermarked_model"]

        print("--- Extraction de la marque ---")
        proj_net = self.saved_keys["proj_net"]
        trigger_set = self.saved_keys["trigger_set"]
        target_wm = self.saved_keys["target_wm"]

        model_was_training = model.training
        proj_was_training = proj_net.training
        model.eval()
        proj_net.eval()
        try:
            # 1. Hook the suspect model and 2. feed it the trigger set.
            target_layer = self._get_target_layer(model, self.config["layer_name"])
            features = self._forward_features(model, target_layer, trigger_set)

            # 3. Projection; inference only, so no graph is recorded.
            with torch.no_grad():
                wm_pred = proj_net(features).mean(dim=0)
        finally:
            model.train(model_was_training)
            proj_net.train(proj_was_training)

        ber = self._compute_ber(wm_pred, target_wm)

        print(f"BER Extrait : {ber:.2f}")
        return ber, wm_pred

    @staticmethod
    def _compute_ber(pred, target):
        """Fraction of bits where ``pred`` thresholded at 0.5 differs from ``target``."""
        return ((pred > 0.5).float() != target).float().mean().item()



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- EXEMPLE D'UTILISATION ---

# 1. Data Loader
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from torchvision import transforms


import torch
import gc

# Free Python garbage and cached CUDA blocks left over from earlier runs.
gc.collect()
torch.cuda.empty_cache()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing: resize to 64x64, convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
# NOTE(review): with inputs in [0, 1] this Normalize yields values in [-1, 1],
# whereas the VAE decoder defined above ends with a Sigmoid (outputs in (0, 1)).
# That looks like a range mismatch for the MSE reconstruction loss - confirm how
# the pretrained checkpoint was trained before changing it.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the "train" split of the dataset from the Hugging Face Hub.
print("Loading dataset...")
hf_dataset = load_dataset("nielsr/CelebA-faces", split="train")

class CelebAWrapper(torch.utils.data.Dataset):
    """Adapts an indexable collection of ``{'image': ...}`` records to a
    PyTorch dataset yielding ``(image, 0)`` pairs.

    The constant ``0`` is a placeholder label so the dataset can be consumed by
    loops that unpack ``(images, labels)``.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, 0)``; the transform is applied only when it is truthy."""
        sample = self.dataset[idx]['image']
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0

# Wrap the Hugging Face split so each item is an (image, 0) pair with the
# preprocessing transform applied.
dataset = CelebAWrapper(hf_dataset, transform)
print("Dataset loaded!")

# Shuffled mini-batches of 64 samples; reused by the embedding and the
# distillation cells below.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
print("loader loaded!")






Loading dataset...
Dataset loaded!
loader loaded!


In [4]:
# Load the pretrained VAE
latent_dim = 200

# Initialize the model
model = VAE(latent_dim=latent_dim)

# Load the trained weights. map_location keeps the load working when the
# checkpoint was saved on a GPU but this session runs on a different device.
model_path = "./vae_celeba_latent_200_epochs_10_batch_64_subset_80000.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# 2. Instantiation & embedding.
# Pass the device explicitly: DictionVAE defaults to "cuda", which would place
# the trigger set on a different device than the model on a CPU-only machine.
diction = DictionVAE(model, device=device)

# Embed (returns the watermarked model; `model` itself is trained in place)
watermarked_model = diction.embed(dataloader)

# 3. Extraction (immediate sanity check on the freshly watermarked model)
ber, _ = diction.extract(watermarked_model)
print(f"BER sur le mod√®le tatou√© : {ber:.2f}")

--- start Embedding DICTION  in VAE(decoder.6) ---


epoch: 0 L_Main: 0.297 | L_Wat: 0.005 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:31<00:00, 34.63it/s]
epoch: 1 L_Main: 0.315 | L_Wat: 0.001 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:41<00:00, 31.25it/s]
epoch: 2 L_Main: 0.306 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:40<00:00, 31.48it/s]
epoch: 3 L_Main: 0.273 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:47<00:00, 29.52it/s]
epoch: 4 L_Main: 0.300 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [02:03<00:00, 25.63it/s]
epoch: 5 L_Main: 0.293 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:52<00:00, 28.20it/s]
epoch: 6 L_Main: 0.279 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:32<00:00, 34.11it/s]
epoch: 7 L_Main: 0.344 | L_Wat: 0.000 | BER: 0.00: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:44<00:00, 30.39it/s]
epoch: 8 L_Main:

--- Extraction de la marque ---
BER Extrait : 0.00
BER sur le mod√®le tatou√© : 0.00


In [5]:
import torch
import torch.nn.functional as F
from diffusers import UNet2DModel, DDPMScheduler
from torch.optim import AdamW
from tqdm import tqdm
import torch.optim as optim

def run_distillation_attack(diction_obj, dataloader, epochs=5, lr=1e-3):
    """Distil the watermarked teacher into a student using outputs only.

    The teacher is read from ``Diction_VAE_model_checkpoint.pt`` and frozen.
    The student is a ``VAE`` initialised from the pretrained weights at the
    notebook-level ``model_path`` (so it is the un-watermarked model, not a
    randomly initialised one). It is trained to match the teacher's
    reconstructions (weight 0.9) and the input images (weight 0.1). After every
    epoch the watermark is extracted from the student with
    ``diction_obj.extract`` and the bit error rate (BER) is recorded.

    Training stops early once the BER has been exactly 0 for two consecutive
    epochs.

    Args:
        diction_obj: object providing ``device`` and ``extract(model)``; its
            saved keys (trigger set, projection net, target bits) are used.
        dataloader: iterable of ``(images, label)`` pairs; labels are ignored.
        epochs: maximum number of epochs.
        lr: initial AdamW learning rate (divided by 10 after 5 epochs).

    Returns:
        ``(student, history)`` where ``history`` has per-epoch ``"loss"`` and
        ``"ber"`` lists.

    Relies on the notebook globals ``VAE``, ``latent_dim`` and ``model_path``.
    """
    device = diction_obj.device

    # --- 1. Teacher (frozen) ---
    # The checkpoint stores whole pickled objects, hence weights_only=False;
    # only load checkpoint files you created yourself. map_location puts every
    # stored tensor on the device used by diction_obj.
    checkpoint = torch.load(
        "Diction_VAE_model_checkpoint.pt", weights_only=False, map_location=device
    )
    teacher = checkpoint["watermarked_model"]
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False

    # --- 2. Student (starts from the pretrained, un-watermarked weights) ---
    print("\n--- Initialisation du Student ---")
    student=VAE(latent_dim=latent_dim)
    student.load_state_dict(torch.load(model_path, map_location=device))
    student.to(device)
    student.train()

    # --- 3. Sanity checks before distillation ---
    print("\n[Check 1] V√©rification du Teacher (Doit √™tre ~0.0)")
    ber_teacher, _ = diction_obj.extract(teacher)
    if ber_teacher > 0.05:
        print(f"‚ö†Ô∏è ATTENTION : Le Teacher n'est pas bien tatou√© (BER={ber_teacher:.2f})")
    else:
        print(f"‚úÖ Teacher OK (BER={ber_teacher:.2f})")

    print("\n[Check 2] V√©rification du Student (Doit √™tre ~0.5 - Al√©atoire)")
    ber_student_start, _ = diction_obj.extract(student)
    print(f"‚ÑπÔ∏è Student avant distillation : BER={ber_student_start:.2f} (Normal pour un mod√®le vierge)")

    # --- 4. Optimisation setup ---
    optimizer = AdamW(student.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5], gamma=0.1)

    history = {"loss": [], "ber": []}

    print(f"\n--- D√©marrage de la Distillation ({epochs} epochs) ---")
    zero_ber_streak = 0  # consecutive epochs that ended with BER == 0
    for epoch in range(epochs):
        # extract() may leave the student in eval mode; make sure every
        # training epoch really runs in train mode.
        student.train()

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        running_loss = 0.0

        for clean_images, _ in pbar:
            clean_images = clean_images.to(device)

            # Teacher target: black-box access, only the reconstruction is used.
            with torch.no_grad():
                target_pred ,_,_= teacher(clean_images)
            student_pred,_,_ = student(clean_images)

            # Mostly imitate the teacher, with a small pull towards the data.
            loss1 = F.mse_loss(student_pred, clean_images)
            loss2 = F.mse_loss(student_pred, target_pred)
            loss = 0.1* loss1 + 0.9 * loss2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(Loss_Distill=loss.item())

        # --- Watermark transfer check ---
        # Uses the trigger set and projection net held by diction_obj.
        print(f"\nCalcul du BER (err_wat) pour l'√©poque {epoch+1}...")
        current_ber, wat_ext = diction_obj.extract(student)
        scheduler.step()

        history["loss"].append(running_loss / len(dataloader))
        history["ber"].append(current_ber)

        # Reporting only: detach so no autograd graph is built or kept.
        with torch.no_grad():
            ext_wat_loss = F.binary_cross_entropy(
                wat_ext.detach(), checkpoint["target_wm"]
            ).item()

        print(f"üëâ Fin Epoch {epoch+1} | Loss: {history['loss'][-1]:.4f} | BER Student: {current_ber:.2f} | ext_wat: {ext_wat_loss:.4f}")

        # Success requires BER == 0 on two consecutive epochs, because a
        # single zero can be a fluke of the stochastic forward pass.
        if current_ber == 0.0:
            zero_ber_streak += 1
            if zero_ber_streak >= 2:
                print("‚úÖ Marque r√©cup√©r√©e avec succ√®s par distillation !")
                break
        else:
            zero_ber_streak = 0
    return student, history



In [6]:
# --- Run the distillation attack ---
# `diction` is the DictionVAE object created in the embedding cell above.
# `dataloader` is the CelebA loader built earlier in this notebook (not CIFAR-10).
# epochs=1000 is only an upper bound: run_distillation_attack contains a
# `break` that ends training early on success.

student_distilled, stats = run_distillation_attack(diction, dataloader, epochs=1000)


--- Initialisation du Student ---

[Check 1] V√©rification du Teacher (Doit √™tre ~0.0)
--- Extraction de la marque ---
BER Extrait : 0.00
‚úÖ Teacher OK (BER=0.00)

[Check 2] V√©rification du Student (Doit √™tre ~0.5 - Al√©atoire)
--- Extraction de la marque ---
BER Extrait : 0.25
‚ÑπÔ∏è Student avant distillation : BER=0.25 (Normal pour un mod√®le vierge)

--- D√©marrage de la Distillation (1000 epochs) ---


Epoch 1/1000: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:22<00:00, 38.15it/s, Loss_Distill=0.0349]



Calcul du BER (err_wat) pour l'√©poque 1...
--- Extraction de la marque ---
BER Extrait : 0.25
üëâ Fin Epoch 1 | Loss: 0.0341 | BER Student: 0.25 | ext_wat: 0.7143


Epoch 2/1000: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:22<00:00, 38.31it/s, Loss_Distill=0.0312]



Calcul du BER (err_wat) pour l'√©poque 2...
--- Extraction de la marque ---
BER Extrait : 0.00
üëâ Fin Epoch 2 | Loss: 0.0338 | BER Student: 0.00 | ext_wat: 0.0302


Epoch 3/1000: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:22<00:00, 38.28it/s, Loss_Distill=0.0353]



Calcul du BER (err_wat) pour l'√©poque 3...
--- Extraction de la marque ---
BER Extrait : 0.25
üëâ Fin Epoch 3 | Loss: 0.0337 | BER Student: 0.25 | ext_wat: 1.4904


Epoch 4/1000: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:23<00:00, 38.09it/s, Loss_Distill=0.0307]



Calcul du BER (err_wat) pour l'√©poque 4...
--- Extraction de la marque ---
BER Extrait : 0.00
üëâ Fin Epoch 4 | Loss: 0.0337 | BER Student: 0.00 | ext_wat: 0.0003


Epoch 5/1000: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3166/3166 [01:23<00:00, 38.10it/s, Loss_Distill=0.0334]


Calcul du BER (err_wat) pour l'√©poque 5...
--- Extraction de la marque ---
BER Extrait : 0.00
üëâ Fin Epoch 5 | Loss: 0.0337 | BER Student: 0.00 | ext_wat: 0.0003
‚úÖ Marque r√©cup√©r√©e avec succ√®s par distillation !



