In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import psutil


class VAE(nn.Module):
    def __init__(self, latent_dim=200):
        super(VAE, self).__init__()

        # Encoder: 4 Conv2d layers with stride=2, kernel=4, padding=1
        # Input: (3, 64, 64) -> (32, 32, 32) -> (64, 16, 16) -> (128, 8, 8) -> (256, 4, 4)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 0
            nn.ReLU(),                                                # 1
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 2
            nn.ReLU(),                                                # 3
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 4
            nn.ReLU(),                                                # 5
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 6
            nn.ReLU(),                                                # 7
        )

        # 256 * 4 * 4 = 4096
        self.fc_mu = nn.Linear(4096, latent_dim)
        self.fc_logvar = nn.Linear(4096, latent_dim)

        # Decoder input
        self.decoder_input = nn.Linear(latent_dim, 4096)

        # Decoder: 4 ConvTranspose2d layers
        # (256, 4, 4) -> (128, 8, 8) -> (64, 16, 16) -> (32, 32, 32) -> (3, 64, 64)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 0
            nn.ReLU(),                                                          # 1
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 2
            nn.ReLU(),                                                          # 3
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 4
            nn.ReLU(),                                                          # 5
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),     # 6
            nn.Sigmoid(),                                                       # 7
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 4096)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 256, 4, 4)  # Reshape to (batch, 256, 4, 4)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from copy import deepcopy
from torch.optim import AdamW
import hmac
import hashlib

# --- HufuNet Autoencoder ---

class HufuEncoder(nn.Module):
    """Encoder part of HufuNet autoencoder."""
    def __init__(self, in_channels=3, latent_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, latent_dim, 3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.encoder(x)


class HufuDecoder(nn.Module):
    """Decoder part of HufuNet autoencoder."""
    def __init__(self, in_channels=3, latent_dim=64):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(x)


class HufuAutoencoder(nn.Module):
    """Complete HufuNet autoencoder."""
    def __init__(self, in_channels=3, latent_dim=64):
        super().__init__()
        self.encoder = HufuEncoder(in_channels, latent_dim)
        self.decoder = HufuDecoder(in_channels, latent_dim)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


# --- Classe Principale HufuNet DDPM ---

class HufuVAE:
    def __init__(self, model, device="cuda", secret_key="2020"):
        self.device = device
        self.model = model
        self.secret_key = secret_key


        # Configuration par defaut
        self.config = {
            "in_channels": 3,
            "latent_dim":32,
            "lr": 1e-4,
            "lr_ae": 1e-3,         # Learning rate for autoencoder pretraining
            "ae_epochs": 1,        # Epochs to pretrain autoencoder
            "epochs": 100,           # Epochs for finetuning with watermark
            "beta_kl":1.0,
            "mse_threshold":0.01
        }

        self.saved_keys = {}

    def _get_conv_params(self, model):
        """
        Extract all conv weight parameters from the model.
        Returns a flat parameter vector and layer info for reconstruction.
        """
        all_params = []
        layers_info = []

        for name, param in model.named_parameters():
            if 'conv' in name.lower() and 'weight' in name:
                all_params.append(param.view(-1))
                layers_info.append({
                    'name': name,
                    'shape': param.shape,
                    'numel': param.numel()
                })

        if len(all_params) == 0:
            # Fallback: use all parameters
            for name, param in model.named_parameters():
                if 'weight' in name:
                    all_params.append(param.view(-1))
                    layers_info.append({
                        'name': name,
                        'shape': param.shape,
                        'numel': param.numel()
                    })

        param_vector = torch.cat(all_params)
        return param_vector, layers_info

    def _hash_position(self, decoder_value, index, total_params):
        """
        Compute hash-based position for embedding using HMAC-SHA256.
        """
        message = str(int(decoder_value * 1000) ^ index).encode()
        mac = hmac.new(self.secret_key.encode(), message, hashlib.sha256)
        position = int(mac.hexdigest(), 16) % total_params
        return position

    def _train_autoencoder(self, autoencoder, dataloader, epochs=5):
        """
        Pre-train the HufuNet autoencoder.
        """
        print("--- Pre-training HufuNet Autoencoder ---")



        optimizer = torch.optim.Adam(autoencoder.parameters(), lr=self.config["lr_ae"])
        criterion = nn.MSELoss()

        for epoch in range(epochs):
            pbar = tqdm(dataloader, desc=f"AE Epoch {epoch+1}/{epochs}")
            total_loss = 0
            for images, _ in pbar:
                images = images.to(self.device)
                # Normalize to [0, 1] for autoencoder
                images_norm = (images + 1) / 2

                optimizer.zero_grad()
                _, reconstructed = autoencoder(images_norm)
                loss = criterion(reconstructed, images_norm)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix(Loss=loss.item())

            print(f"Epoch {epoch+1} - Avg Loss: {total_loss/len(dataloader):.4f}")

        return autoencoder
    def get_encoder_parameters(self, encoder):
        """
        Extract encoder parameters into a flat vector for embedding.
        """

        encoder_params = []
        for name, param in encoder.named_parameters():
            if 'weight' in name:
                encoder_params.append(param.view(-1).detach())
        encoder_vector = torch.cat(encoder_params)
        return encoder_vector

    def get_decoder_parameters(self, decoder):
        """
        Extract decoder parameters into a flat vector for hashing.
        """
        decoder_params = []
        for name, param in decoder.named_parameters():
            if 'weight' in name:
                decoder_params.append(param.view(-1).detach())
        decoder_vector = torch.cat(decoder_params)
        return decoder_vector

    def embedded_positons_in_model(self, model, encoder_vector, decoder_vector):
        """
        Embeds encoder parameters into the model weights using hash-based positioning.
        """
        param_vector, layer_info = self._get_conv_params(model)
        total_params = param_vector.numel()
        watermark_size = encoder_vector.numel()

        bitmap = torch.zeros(total_params, dtype=torch.bool, device=self.device)
        embedded_positions = []

        with torch.no_grad():
            for i in tqdm(range(min(watermark_size, total_params // 10)), desc="Embedding"):
                decoder_val = decoder_vector[i % len(decoder_vector)].item()
                position = self._hash_position(decoder_val, i, total_params)

                original_pos = position
                while bitmap[position]:
                    position = (position + 1) % total_params
                    if position == original_pos:
                        break

                embedded_positions.append(position)
                bitmap[position] = True


        return embedded_positions, bitmap, layer_info

    def _embed_params_into_model(self, model, encoder_vector, embedded_positions):
        """
        Embeds encoder parameters into model weights at the specified positions.
        Returns the modified model.
        """
        # Build direct mapping: global_param_index -> encoder_value
        position_to_encoder_value = {
            pos: encoder_vector[i]
            for i, pos in enumerate(embedded_positions)
            if i < len(encoder_vector)
        }

        # Get model layers info
        _, layers_info = self._get_conv_params(model)

        # Embed into model
        param_idx = 0
        with torch.no_grad():
            for info in layers_info:
                for name, param in model.named_parameters():
                    if name == info['name']:
                        param_flat = param.view(-1)
                        for j in range(info['numel']):
                            global_idx = param_idx + j
                            if global_idx in position_to_encoder_value:
                                param_flat[j] = position_to_encoder_value[global_idx]
                        break
                param_idx += info['numel']

        return model
    def _extract_and_evaluate(self, model, embedded_positions, decoder, dataloader):
        """
        Extracts parameters from model at given positions,
        reconstructs the encoder and full autoencoder,
        evaluates MSE on dataloader.
        Returns reconstructed autoencoder and avg MSE.
        """
        # Get current model parameters
        param_vector, _ = self._get_conv_params(model)

        # Extract values at embedded positions
        extracted_params = torch.zeros(len(embedded_positions), device=self.device)
        with torch.no_grad():
            for i, pos in enumerate(embedded_positions):
                if pos < len(param_vector):
                    extracted_params[i] = param_vector[pos]

        # Reconstruct encoder by loading extracted values into architecture
        reconstructed_encoder = HufuEncoder(
            in_channels=self.config["in_channels"],
            latent_dim=self.config["latent_dim"]
        ).to(self.device)

        offset = 0
        with torch.no_grad():
            for name, param in reconstructed_encoder.named_parameters():
                if 'weight' in name:
                    numel = param.numel()
                    chunk = extracted_params[offset:offset + numel]
                    if len(chunk) == numel:
                        param.copy_(chunk.view(param.shape))
                    offset += numel

        # Rebuild full autoencoder with reconstructed encoder + owner decoder
        reconstructed_autoencoder = HufuAutoencoder(
            in_channels=self.config["in_channels"],
            latent_dim=self.config["latent_dim"]
        ).to(self.device)
        reconstructed_autoencoder.encoder = reconstructed_encoder
        reconstructed_autoencoder.decoder = decoder

        # Evaluate MSE on dataloader
        reconstructed_autoencoder.eval()
        criterion = nn.MSELoss()
        total_loss = 0
        n_batches = 0

        with torch.no_grad():
            for images, _ in dataloader:
                images = images.to(self.device)
                images_norm = (images + 1) / 2
                _, reconstructed = reconstructed_autoencoder(images_norm)
                loss = criterion(reconstructed, images_norm)
                total_loss += loss.item()
                n_batches += 1
                if n_batches >= 10:
                    break

        avg_mse = total_loss / n_batches
        print(f"Reconstructed autoencoder MSE: {avg_mse:.4f}")

        return reconstructed_autoencoder, avg_mse

    def embed(self, dataloader):
        """
        Embeds HufuNet encoder into DDPM model weights.
        """
        print(f"--- Demarrage Embedding HufuNet ---")

        autoencoder = HufuAutoencoder(
            in_channels=self.config["in_channels"],
            latent_dim=self.config["latent_dim"]
        ).to(self.device)



        # 1. Train autoencoder
        autoencoder = self._train_autoencoder(autoencoder, dataloader, self.config["ae_epochs"])
        encoder = autoencoder.encoder
        decoder = autoencoder.decoder

        encoder_vector = self.get_encoder_parameters(encoder)
        decoder_vector = self.get_decoder_parameters(decoder)
        watermark_size = encoder_vector.numel()



        # 4. Get model parameters
        watermarked_model = self.model
        embedded_positions, bitmap, layers_info=self.embedded_positons_in_model( watermarked_model, encoder_vector, decoder_vector)








        # 8. Fine-tune to maintain model quality
        print("\n--- Fine-tuning watermarked model ---")
        watermarked_model.train()
        optimizer = torch.optim.AdamW(watermarked_model.parameters(), lr=self.config["lr"])
        mse_loss = nn.MSELoss()

        for epoch in range(self.config["epochs"]):
            encoder_vector = self.get_encoder_parameters(encoder)

            watermarked_model=self._embed_params_into_model(watermarked_model, encoder_vector, embedded_positions)
            pbar = tqdm(dataloader, desc=f"Finetune Epoch {epoch+1}")
            for clean_images, _ in pbar:



                clean_images = clean_images.to(self.device)

                optimizer.zero_grad()
                recon, mu, logvar = watermarked_model(clean_images)

                # Reconstruction loss
                l_recon = F.mse_loss(recon, clean_images)

                # KL divergence
                l_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

                # VAE total task loss
                loss = l_recon + self.config["beta_kl"] * l_kl

                loss.backward()
                optimizer.step()

                pbar.set_postfix(Loss=loss.item())
            reconstructed_autoencoder, avg_mse=self._extract_and_evaluate(watermarked_model, embedded_positions, decoder, dataloader)
            print(f"Epoch {epoch+1} - Finetune Loss: {loss.item():.4f} | Extracted Autoencoder MSE: {avg_mse:.4f}")
            if avg_mse <  self.config["mse_threshold"]:
                print(f"MSE= {avg_mse} is acceptable, stopping fine-tuning to preserve watermark integrity.")
                break
            else:
                print(f"MSE= {avg_mse} is too high, continuing fine-tuning to improve autoencoder reconstruction.")
                autoencoder = self._train_autoencoder(reconstructed_autoencoder, dataloader, self.config["ae_epochs"])
                encoder = autoencoder.encoder


        # Save keys
        self.saved_keys = {
            "autoencoder": reconstructed_autoencoder,
            "encoder": reconstructed_autoencoder.encoder,
            "decoder": reconstructed_autoencoder.decoder,
            # "encoder_vector": encoder_vector,
            # "decoder_vector": decoder_vector,
            "embedded_positions": embedded_positions,
            "layers_info": layers_info,
            "watermarked_model": watermarked_model,
        }
        torch.save(self.saved_keys, "hufu_VAE_model_checkpoint.pt")
        return watermarked_model

    # def extract(self, model=None):
    #     """
    #     Extracts the embedded encoder from a suspect model.
    #     """
    #     if model is None:
    #         model = self.saved_keys["watermarked_model"]
    #
    #     decoder_vector = self.saved_keys["decoder_vector"]
    #     embedded_positions = self.saved_keys["embedded_positions"]
    #     encoder_vector = self.saved_keys["encoder_vector"]
    #
    #     # Get model parameters
    #     param_vector, _ = self._get_conv_params(model)
    #
    #     # Extract encoder parameters
    #     extracted_encoder = torch.zeros_like(encoder_vector[:len(embedded_positions)])
    #
    #     with torch.no_grad():
    #         for i, pos in enumerate(embedded_positions):
    #             if pos < len(param_vector) and i < len(extracted_encoder):
    #                 extracted_encoder[i] = param_vector[pos]
    #
    #     # Compute correlation/similarity with original encoder
    #     original_encoder = encoder_vector[:len(embedded_positions)].to(self.device)
    #     extracted_encoder = extracted_encoder.to(self.device)
    #
    #     mse = F.mse_loss(extracted_encoder, original_encoder).item()
    #     correlation = F.cosine_similarity(
    #         extracted_encoder.unsqueeze(0),
    #         original_encoder.unsqueeze(0)
    #     ).item()
    #
    #     print(f"Extraction MSE: {mse:.4f}")
    #     print(f"Correlation with original: {correlation:.4f}")
    #
    #     return mse, correlation, extracted_encoder


    def extract(self, model=None, dataloader=None):
        """
        Extracts the embedded encoder from a suspect model and tests its functionality
        by reconstructing images using the owner's decoder.
        """
        if model is None:
            model = self.saved_keys["watermarked_model"]

        decoder = self.saved_keys["decoder"]
        embedded_positions = self.saved_keys["embedded_positions"]



        reconstructed_autoencoder, avg_mse=self._extract_and_evaluate(model, embedded_positions, decoder, dataloader)
        extracted_encoder = reconstructed_autoencoder.encoder


        return avg_mse, extracted_encoder





In [8]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from torchvision import transforms

import os
import torch
import gc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load from Hugging Face (no Google Drive issues)
print("Loading dataset...")


from datasets import load_from_disk
hf_dataset = load_from_disk("celeba_local")
# os.makedirs("./celeba_images/all", exist_ok=True)
# for i, item in enumerate(hf_dataset):
#     item['image'].save(f"./celeba_images/all/{i:06d}.jpg")
# del hf_dataset


class CelebAWrapper(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        if self.transform:
            image = self.transform(image)
        return image, 0

dataset = CelebAWrapper(hf_dataset, transform)
# del hf_dataset
# gc.collect()
# dataset = datasets.ImageFolder("./celeba_images", transform=transform)
print("Dataset loaded!")

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
print("loader loaded!")


Loading dataset...
Dataset loaded!
loader loaded!


In [9]:
# --- EXEMPLE D'EXECUTION ---
#load model
latent_dim = 200

# Initialize the model
model = VAE(latent_dim=latent_dim)

# Load the trained weights
model_path = "./vae_celeba_latent_200_epochs_10_batch_64_subset_80000.pth"
model.load_state_dict(torch.load(model_path))
model.to(device)

# 2. Embedding HufuNet
hufu_defense = HufuVAE(model=model, device=device)
watermarked_model = hufu_defense.embed(dataloader)

--- Demarrage Embedding HufuNet ---
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:27<00:00, 36.06it/s, Loss=0.00234]


Epoch 1 - Avg Loss: 0.0052


Embedding: 100%|██████████| 37728/37728 [00:00<00:00, 51254.76it/s]



--- Fine-tuning watermarked model ---


Finetune Epoch 1: 100%|██████████| 3166/3166 [02:02<00:00, 25.82it/s, Loss=0.338]


Reconstructed autoencoder MSE: 0.0625
Epoch 1 - Finetune Loss: 0.3376 | Extracted Autoencoder MSE: 0.0625
MSE= 0.06253124102950096 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.42it/s, Loss=0.00157]


Epoch 1 - Avg Loss: 0.0019


Finetune Epoch 2: 100%|██████████| 3166/3166 [01:24<00:00, 37.49it/s, Loss=0.316]


Reconstructed autoencoder MSE: 0.1341
Epoch 2 - Finetune Loss: 0.3161 | Extracted Autoencoder MSE: 0.1341
MSE= 0.13408446907997132 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:17<00:00, 40.63it/s, Loss=0.00125] 


Epoch 1 - Avg Loss: 0.0016


Finetune Epoch 3: 100%|██████████| 3166/3166 [01:23<00:00, 37.92it/s, Loss=0.297]


Reconstructed autoencoder MSE: 0.0607
Epoch 3 - Finetune Loss: 0.2969 | Extracted Autoencoder MSE: 0.0607
MSE= 0.06066616475582123 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 40.07it/s, Loss=0.00103] 


Epoch 1 - Avg Loss: 0.0012


Finetune Epoch 4: 100%|██████████| 3166/3166 [01:23<00:00, 37.80it/s, Loss=0.318]


Reconstructed autoencoder MSE: 0.0347
Epoch 4 - Finetune Loss: 0.3184 | Extracted Autoencoder MSE: 0.0347
MSE= 0.03467031754553318 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.17it/s, Loss=0.00102] 


Epoch 1 - Avg Loss: 0.0011


Finetune Epoch 5: 100%|██████████| 3166/3166 [01:24<00:00, 37.67it/s, Loss=0.261]


Reconstructed autoencoder MSE: 0.0599
Epoch 5 - Finetune Loss: 0.2612 | Extracted Autoencoder MSE: 0.0599
MSE= 0.05985407680273056 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 39.92it/s, Loss=0.000853]


Epoch 1 - Avg Loss: 0.0011


Finetune Epoch 6: 100%|██████████| 3166/3166 [01:24<00:00, 37.67it/s, Loss=0.319]


Reconstructed autoencoder MSE: 0.0734
Epoch 6 - Finetune Loss: 0.3189 | Extracted Autoencoder MSE: 0.0734
MSE= 0.07336710318922997 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:17<00:00, 40.66it/s, Loss=0.000881]


Epoch 1 - Avg Loss: 0.0010


Finetune Epoch 7: 100%|██████████| 3166/3166 [01:23<00:00, 37.96it/s, Loss=0.332]


Reconstructed autoencoder MSE: 0.0238
Epoch 7 - Finetune Loss: 0.3318 | Extracted Autoencoder MSE: 0.0238
MSE= 0.02378745935857296 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 39.75it/s, Loss=0.000972]


Epoch 1 - Avg Loss: 0.0009


Finetune Epoch 8: 100%|██████████| 3166/3166 [01:25<00:00, 37.14it/s, Loss=0.326]


Reconstructed autoencoder MSE: 0.0581
Epoch 8 - Finetune Loss: 0.3263 | Extracted Autoencoder MSE: 0.0581
MSE= 0.0580789964646101 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:20<00:00, 39.52it/s, Loss=0.000754]


Epoch 1 - Avg Loss: 0.0009


Finetune Epoch 9: 100%|██████████| 3166/3166 [01:24<00:00, 37.29it/s, Loss=0.291]


Reconstructed autoencoder MSE: 0.0447
Epoch 9 - Finetune Loss: 0.2907 | Extracted Autoencoder MSE: 0.0447
MSE= 0.04472903162240982 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.08it/s, Loss=0.000718]


Epoch 1 - Avg Loss: 0.0009


Finetune Epoch 10: 100%|██████████| 3166/3166 [01:25<00:00, 37.17it/s, Loss=0.293]


Reconstructed autoencoder MSE: 0.0123
Epoch 10 - Finetune Loss: 0.2927 | Extracted Autoencoder MSE: 0.0123
MSE= 0.012346486002206803 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:20<00:00, 39.25it/s, Loss=0.000715]


Epoch 1 - Avg Loss: 0.0009


Finetune Epoch 11: 100%|██████████| 3166/3166 [01:24<00:00, 37.53it/s, Loss=0.33] 


Reconstructed autoencoder MSE: 0.0243
Epoch 11 - Finetune Loss: 0.3302 | Extracted Autoencoder MSE: 0.0243
MSE= 0.0242686090990901 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.11it/s, Loss=0.000733]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 12: 100%|██████████| 3166/3166 [01:25<00:00, 37.20it/s, Loss=0.298]


Reconstructed autoencoder MSE: 0.0204
Epoch 12 - Finetune Loss: 0.2976 | Extracted Autoencoder MSE: 0.0204
MSE= 0.02036541383713484 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:20<00:00, 39.52it/s, Loss=0.000713]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 13: 100%|██████████| 3166/3166 [01:25<00:00, 37.17it/s, Loss=0.307]


Reconstructed autoencoder MSE: 0.0308
Epoch 13 - Finetune Loss: 0.3074 | Extracted Autoencoder MSE: 0.0308
MSE= 0.030797526985406876 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 39.65it/s, Loss=0.00066] 


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 14: 100%|██████████| 3166/3166 [01:25<00:00, 37.15it/s, Loss=0.295]


Reconstructed autoencoder MSE: 0.0225
Epoch 14 - Finetune Loss: 0.2948 | Extracted Autoencoder MSE: 0.0225
MSE= 0.0224681967869401 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.16it/s, Loss=0.000699]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 15: 100%|██████████| 3166/3166 [01:23<00:00, 38.00it/s, Loss=0.293]


Reconstructed autoencoder MSE: 0.0250
Epoch 15 - Finetune Loss: 0.2926 | Extracted Autoencoder MSE: 0.0250
MSE= 0.025029807537794112 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.43it/s, Loss=0.000672]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 16: 100%|██████████| 3166/3166 [01:23<00:00, 37.88it/s, Loss=0.3]  


Reconstructed autoencoder MSE: 0.0287
Epoch 16 - Finetune Loss: 0.3005 | Extracted Autoencoder MSE: 0.0287
MSE= 0.02872729189693928 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 39.74it/s, Loss=0.000632]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 17: 100%|██████████| 3166/3166 [01:22<00:00, 38.27it/s, Loss=0.35] 


Reconstructed autoencoder MSE: 0.0499
Epoch 17 - Finetune Loss: 0.3499 | Extracted Autoencoder MSE: 0.0499
MSE= 0.0498591635376215 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:17<00:00, 40.87it/s, Loss=0.000712]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 18: 100%|██████████| 3166/3166 [01:23<00:00, 37.91it/s, Loss=0.327]


Reconstructed autoencoder MSE: 0.0103
Epoch 18 - Finetune Loss: 0.3271 | Extracted Autoencoder MSE: 0.0103
MSE= 0.010330616123974323 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:18<00:00, 40.08it/s, Loss=0.000583]


Epoch 1 - Avg Loss: 0.0007


Finetune Epoch 19: 100%|██████████| 3166/3166 [01:23<00:00, 38.11it/s, Loss=0.31] 


Reconstructed autoencoder MSE: 0.0960
Epoch 19 - Finetune Loss: 0.3097 | Extracted Autoencoder MSE: 0.0960
MSE= 0.09601198583841324 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:17<00:00, 40.88it/s, Loss=0.000677]


Epoch 1 - Avg Loss: 0.0009


Finetune Epoch 20: 100%|██████████| 3166/3166 [01:22<00:00, 38.15it/s, Loss=0.304]


Reconstructed autoencoder MSE: 0.0129
Epoch 20 - Finetune Loss: 0.3043 | Extracted Autoencoder MSE: 0.0129
MSE= 0.012871145736426115 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:19<00:00, 40.01it/s, Loss=0.000835]


Epoch 1 - Avg Loss: 0.0008


Finetune Epoch 21: 100%|██████████| 3166/3166 [01:23<00:00, 37.75it/s, Loss=0.283]


Reconstructed autoencoder MSE: 0.0174
Epoch 21 - Finetune Loss: 0.2829 | Extracted Autoencoder MSE: 0.0174
MSE= 0.01736978329718113 is too high, continuing fine-tuning to improve autoencoder reconstruction.
--- Pre-training HufuNet Autoencoder ---


AE Epoch 1/1: 100%|██████████| 3166/3166 [01:17<00:00, 41.04it/s, Loss=0.000688]


Epoch 1 - Avg Loss: 0.0007


Finetune Epoch 22: 100%|██████████| 3166/3166 [01:22<00:00, 38.19it/s, Loss=0.289]


Reconstructed autoencoder MSE: 0.0084
Epoch 22 - Finetune Loss: 0.2886 | Extracted Autoencoder MSE: 0.0084
MSE= 0.008433844894170761 is acceptable, stopping fine-tuning to preserve watermark integrity.


In [30]:
# --- Fonction de Distillation (Attaque) ---

def run_distillation_attack_hufu(hufu_obj, dataloader, epochs=5, lr=1e-4):
    """
    Tente de transferer la fonctionnalite du modele HufuNet vers un modele vierge.
    Verifie si le watermark survit.
    """
    device = hufu_obj.device
    checkpoint = torch.load("hufu_VAE_model_checkpoint.pt", weights_only=False)

    # 1. Teacher (Gele)
    # teacher = checkpoint["watermarked_model"]
    teacher=hufu_obj.saved_keys["watermarked_model"]
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False

    # 2. Student (Vierge - Meme architecture)
    print("\n--- Initialisation du Student ---")
    student=VAE(latent_dim=latent_dim)
    student.load_state_dict(torch.load(model_path))
    student.to(device)
    student.train()


    mse_teacher, _ = hufu_obj.extract(teacher, dataloader)
    print("\n[Check] Teacher: mse_hufu_teacher: {:.4f}".format(mse_teacher))

    mse_student, _ = hufu_obj.extract(student, dataloader)
    print("\n[Check] Student (before): mse_hufu_student: {:.4f}".format(mse_student))

    optimizer = AdamW(student.parameters(), lr=lr)

    history = {"loss": [], "MSE": []}

    print(f"\n--- Distillation HufuNet ({epochs} epochs) ---")

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        running_loss = 0.0

        for clean_images, _ in pbar:
            clean_images = clean_images.to(device)

            with torch.no_grad():
                target_pred,_,_ = teacher(clean_images)

            student_pred,_,_ = student(clean_images)

            loss = F.mse_loss(student_pred, target_pred)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(Loss=loss.item())

        # Check correlation
        avg_mse, _ = hufu_obj.extract(student, dataloader)
        history["MSE"].append(avg_mse)
        history["loss"].append(running_loss / len(dataloader))

        print(f"Fin Epoch {epoch+1} | Loss: {history['loss'][-1]:.4f} | MSE_hufu_autoencoder: {avg_mse:.4f}")


    return student, history


In [31]:
# 5. Attaque par Distillation
student_res, stats = run_distillation_attack_hufu(hufu_defense, dataloader, epochs=100)


--- Initialisation du Student ---
Reconstructed autoencoder MSE: 0.0234

[Check] Teacher: mse_hufu_teacher: 0.0234
Reconstructed autoencoder MSE: 0.1601

[Check] Student (before): mse_hufu_student: 0.1601

--- Distillation HufuNet (100 epochs) ---


Epoch 1: 100%|██████████| 3166/3166 [01:25<00:00, 37.04it/s, Loss=0.00486]


Reconstructed autoencoder MSE: 0.1556
Fin Epoch 1 | Loss: 0.0082 | MSE_hufu_autoencoder: 0.1556


Epoch 2: 100%|██████████| 3166/3166 [01:25<00:00, 37.05it/s, Loss=0.00706]


Reconstructed autoencoder MSE: 0.1550
Fin Epoch 2 | Loss: 0.0069 | MSE_hufu_autoencoder: 0.1550


Epoch 3: 100%|██████████| 3166/3166 [01:25<00:00, 36.93it/s, Loss=0.0078] 


Reconstructed autoencoder MSE: 0.1602
Fin Epoch 3 | Loss: 0.0068 | MSE_hufu_autoencoder: 0.1602


Epoch 4: 100%|██████████| 3166/3166 [01:25<00:00, 36.85it/s, Loss=0.00752]


Reconstructed autoencoder MSE: 0.1618
Fin Epoch 4 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1618


Epoch 5: 100%|██████████| 3166/3166 [01:24<00:00, 37.36it/s, Loss=0.00602]


Reconstructed autoencoder MSE: 0.1595
Fin Epoch 5 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1595


Epoch 6: 100%|██████████| 3166/3166 [01:24<00:00, 37.41it/s, Loss=0.00702]


Reconstructed autoencoder MSE: 0.1596
Fin Epoch 6 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1596


Epoch 7: 100%|██████████| 3166/3166 [01:25<00:00, 37.13it/s, Loss=0.00697]


Reconstructed autoencoder MSE: 0.1611
Fin Epoch 7 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1611


Epoch 8: 100%|██████████| 3166/3166 [01:24<00:00, 37.39it/s, Loss=0.00889]


Reconstructed autoencoder MSE: 0.1605
Fin Epoch 8 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1605


Epoch 9: 100%|██████████| 3166/3166 [01:24<00:00, 37.57it/s, Loss=0.00736]


Reconstructed autoencoder MSE: 0.1604
Fin Epoch 9 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1604


Epoch 10: 100%|██████████| 3166/3166 [01:24<00:00, 37.45it/s, Loss=0.00516]


Reconstructed autoencoder MSE: 0.1532
Fin Epoch 10 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1532


Epoch 11: 100%|██████████| 3166/3166 [01:24<00:00, 37.44it/s, Loss=0.00696]


Reconstructed autoencoder MSE: 0.1586
Fin Epoch 11 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1586


Epoch 12: 100%|██████████| 3166/3166 [01:24<00:00, 37.36it/s, Loss=0.00536]


Reconstructed autoencoder MSE: 0.1522
Fin Epoch 12 | Loss: 0.0067 | MSE_hufu_autoencoder: 0.1522


Epoch 13: 100%|██████████| 3166/3166 [01:30<00:00, 35.14it/s, Loss=0.00632]


Reconstructed autoencoder MSE: 0.1598
Fin Epoch 13 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1598


Epoch 14: 100%|██████████| 3166/3166 [01:23<00:00, 37.78it/s, Loss=0.00672]


Reconstructed autoencoder MSE: 0.1541
Fin Epoch 14 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1541


Epoch 15: 100%|██████████| 3166/3166 [01:24<00:00, 37.52it/s, Loss=0.00448]


Reconstructed autoencoder MSE: 0.1583
Fin Epoch 15 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1583


Epoch 16: 100%|██████████| 3166/3166 [01:23<00:00, 37.74it/s, Loss=0.00487]


Reconstructed autoencoder MSE: 0.1600
Fin Epoch 16 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1600


Epoch 17: 100%|██████████| 3166/3166 [01:27<00:00, 36.15it/s, Loss=0.00637]


Reconstructed autoencoder MSE: 0.1520
Fin Epoch 17 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1520


Epoch 18: 100%|██████████| 3166/3166 [01:25<00:00, 37.06it/s, Loss=0.00714]


Reconstructed autoencoder MSE: 0.1587
Fin Epoch 18 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1587


Epoch 19: 100%|██████████| 3166/3166 [01:25<00:00, 36.96it/s, Loss=0.00779]


Reconstructed autoencoder MSE: 0.1617
Fin Epoch 19 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1617


Epoch 20: 100%|██████████| 3166/3166 [01:25<00:00, 37.18it/s, Loss=0.00836]


Reconstructed autoencoder MSE: 0.1579
Fin Epoch 20 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1579


Epoch 21: 100%|██████████| 3166/3166 [01:25<00:00, 37.13it/s, Loss=0.00701]


Reconstructed autoencoder MSE: 0.1630
Fin Epoch 21 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1630


Epoch 22: 100%|██████████| 3166/3166 [01:24<00:00, 37.31it/s, Loss=0.00742]


Reconstructed autoencoder MSE: 0.1630
Fin Epoch 22 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1630


Epoch 23: 100%|██████████| 3166/3166 [01:25<00:00, 37.18it/s, Loss=0.00892]


Reconstructed autoencoder MSE: 0.1551
Fin Epoch 23 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1551


Epoch 24: 100%|██████████| 3166/3166 [01:24<00:00, 37.25it/s, Loss=0.00555]


Reconstructed autoencoder MSE: 0.1635
Fin Epoch 24 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1635


Epoch 25: 100%|██████████| 3166/3166 [01:25<00:00, 37.09it/s, Loss=0.00606]


Reconstructed autoencoder MSE: 0.1652
Fin Epoch 25 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1652


Epoch 26: 100%|██████████| 3166/3166 [01:25<00:00, 37.12it/s, Loss=0.00696]


Reconstructed autoencoder MSE: 0.1588
Fin Epoch 26 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1588


Epoch 27: 100%|██████████| 3166/3166 [01:23<00:00, 37.78it/s, Loss=0.00518]


Reconstructed autoencoder MSE: 0.1703
Fin Epoch 27 | Loss: 0.0066 | MSE_hufu_autoencoder: 0.1703


Epoch 28:  45%|████▌     | 1429/3166 [00:37<00:45, 38.06it/s, Loss=0.00679]


KeyboardInterrupt: 

In [None]:
# 6. Visualisation des resultats
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(stats["loss"])
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Distillation Loss")

ax2.plot(stats["correlation"])
ax2.axhline(y=0.5, color='r', linestyle='--', label='Threshold (0.5)')
ax2.axhline(y=0.0, color='g', linestyle='--', label='No correlation (0.0)')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Correlation")
ax2.set_title("Encoder Correlation During Distillation")
ax2.legend()

plt.tight_layout()
plt.show()

In [None]:
# 7. Visualize Autoencoder Reconstruction
import matplotlib.pyplot as plt

autoencoder = hufu_defense.saved_keys["autoencoder"]
autoencoder.eval()

# Get some test images
test_images, _ = next(iter(dataloader))
test_images = test_images[:8].to(hufu_defense.device)
test_images_norm = (test_images + 1) / 2

with torch.no_grad():
    _, reconstructed = autoencoder(test_images_norm)

# Plot
fig, axes = plt.subplots(2, 8, figsize=(16, 4))

for i in range(8):
    # Original
    img = test_images_norm[i].cpu().permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Original')

    # Reconstructed
    rec = reconstructed[i].cpu().permute(1, 2, 0).numpy()
    axes[1, i].imshow(rec)
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstructed')

plt.tight_layout()
plt.show()