In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import psutil


class VAE(nn.Module):
    def __init__(self, latent_dim=200):
        super(VAE, self).__init__()

        # Encoder: 4 Conv2d layers with stride=2, kernel=4, padding=1
        # Input: (3, 64, 64) -> (32, 32, 32) -> (64, 16, 16) -> (128, 8, 8) -> (256, 4, 4)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 0
            nn.ReLU(),                                                # 1
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 2
            nn.ReLU(),                                                # 3
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 4
            nn.ReLU(),                                                # 5
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 6
            nn.ReLU(),                                                # 7
        )

        # 256 * 4 * 4 = 4096
        self.fc_mu = nn.Linear(4096, latent_dim)
        self.fc_logvar = nn.Linear(4096, latent_dim)

        # Decoder input
        self.decoder_input = nn.Linear(latent_dim, 4096)

        # Decoder: 4 ConvTranspose2d layers
        # (256, 4, 4) -> (128, 8, 8) -> (64, 16, 16) -> (32, 32, 32) -> (3, 64, 64)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 0
            nn.ReLU(),                                                          # 1
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 2
            nn.ReLU(),                                                          # 3
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 4
            nn.ReLU(),                                                          # 5
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),     # 6
            nn.Sigmoid(),                                                       # 7
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 4096)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 256, 4, 4)  # Reshape to (batch, 256, 4, 4)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from copy import deepcopy
from torch.optim import AdamW

# --- Riga Extractor Network ---

class RigaExtractor(nn.Module):
    """Extractor network that projects weights to watermark space."""
    def __init__(self, weight_size, watermark_len):
        super().__init__()
        self.fc1 = nn.Linear(weight_size, 100, bias=False)
        self.fc2 = nn.Linear(100, watermark_len, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.sig(out)
        out = self.fc2(out)
        out = self.sig(out)
        return out


# --- Riga Detector Network ---

class RigaDetector(nn.Module):
    """Detector network that distinguishes watermarked from non-watermarked weights."""
    def __init__(self, weight_size):
        super().__init__()
        self.fc1 = nn.Linear(weight_size, 100, bias=False)
        self.fc2 = nn.Linear(100, 1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.sig(out)
        out = self.fc2(out)
        out = self.sig(out)
        return out


# --- Classe Principale Riga DDPM ---

class RigaVAE:
    def __init__(self, model, device="cuda"):
        self.device = device
        self.model = model


        # Configuration par defaut
        self.config = {
            "layer_name": "decoder.6.weight",  # Couche cible
            "watermark_len": 16,
            "lr": 1e-4,
            "lr_det": 1e-3,        # Learning rate for detector
            "lambda_1": 1.0,       # Watermark loss weight
            "lambda_2": 1.0,       # Adversarial loss weight
            "epochs": 10,
            "clip_value": 0.01,
            "beta_kl":1.0# Weight clipping for detector (WGAN-style)
        }

        self.saved_keys = {}

    def _get_target_weights(self, model):
        """Recupere le tenseur des poids de la couche cible."""
        for name, param in model.named_parameters():
            if name == self.config["layer_name"]:
                return param
        raise ValueError(f"Parametre {self.config['layer_name']} introuvable.")

    def embed(self, dataloader):
        """
        Incorpore la marque Riga pendant le finetuning.
        Utilise un entrainement adversarial avec un detecteur.
        """
        print(f"--- Demarrage Embedding Riga ({self.config['layer_name']}) ---")

        watermarked_model = self.model
        watermarked_model.train()

        # 1. Get initial weights
        target_weights = self._get_target_weights(watermarked_model)
        with torch.no_grad():
            init_w = torch.flatten(target_weights.mean(dim=0)).clone()
            weight_size = len(init_w)

        print(f"Dimension vecteur poids : {weight_size} | Watermark : {self.config['watermark_len']} bits")

        # 2. Generate watermarks
        watermark_target = torch.randint(0, 2, (self.config["watermark_len"],)).float().to(self.device)
        watermark_random = torch.randint(0, 2, (self.config["watermark_len"],)).float().to(self.device)

        # 3. Initialize networks
        extractor = RigaExtractor(weight_size, self.config["watermark_len"]).to(self.device)
        detector = RigaDetector(weight_size).to(self.device)

        # 4. Optimizers
        optimizer = torch.optim.AdamW([
            {'params': watermarked_model.parameters(), 'lr': self.config["lr"]},
            {'params': extractor.parameters(), 'lr': self.config["lr"], 'betas': (0.5, 0.999)}
        ])
        optimizer_det = torch.optim.Adam(
            detector.parameters(),
            lr=self.config["lr_det"],
            betas=(0.5, 0.999)
        )

        mse_loss = nn.MSELoss()
        bce_loss = nn.BCELoss()

        # 5. Training loop
        for epoch in range(self.config["epochs"]):
            pbar = tqdm(dataloader)
            for clean_images, _ in pbar:
                clean_images = clean_images.to(self.device)


                # Get current weights
                current_weights = self._get_target_weights(watermarked_model)
                w = torch.flatten(current_weights.mean(dim=0))

                # Sort weights for detector (as in original Riga)
                w_sorted = torch.sort(w.detach())[0]
                init_w_sorted = torch.sort(init_w.detach())[0]

                # === A. Train Detector ===
                optimizer_det.zero_grad()

                # Detector should output 1 for non-watermarked, 0 for watermarked
                out_det_wat = detector(w_sorted)
                out_det_non = detector(init_w_sorted)

                loss_det_non = bce_loss(out_det_non, torch.ones(1).to(self.device))
                loss_det_wat = bce_loss(out_det_wat, torch.zeros(1).to(self.device))
                loss_det = loss_det_non + loss_det_wat

                loss_det.backward(retain_graph=True)
                optimizer_det.step()

                # Clip detector weights (WGAN-style)
                with torch.no_grad():
                    for param in detector.parameters():
                        param.clamp_(-self.config["clip_value"], self.config["clip_value"])

                # === B. Train VAE + Extractor ===

                optimizer.zero_grad()


                recon, mu, logvar = watermarked_model(clean_images)

                # Reconstruction loss
                l_recon = F.mse_loss(recon, clean_images)

                # KL divergence
                l_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

                # VAE total task loss
                l_main = l_recon + self.config["beta_kl"] * l_kl


                # Watermark extraction loss
                w_current = torch.flatten(self._get_target_weights(watermarked_model).mean(dim=0))
                out_watermark = extractor(w_current)
                init_out_watermark = extractor(init_w.detach())

                loss_ext_wat = bce_loss(out_watermark, watermark_target)
                loss_ext_init = bce_loss(init_out_watermark, watermark_random)

                # Adversarial loss (fool detector - make it think watermarked is non-watermarked)
                w_sorted_current = torch.sort(w_current.detach())[0]
                out_det_current = detector(w_sorted_current)
                loss_adv = bce_loss(out_det_current, torch.ones(1).to(self.device))

                # Total loss
                l_total = (l_main +
                          self.config["lambda_1"] * (loss_ext_wat + loss_ext_init) -
                          self.config["lambda_2"] * loss_adv)

                l_total.backward()
                optimizer.step()

                # Metrics
                ber = self._compute_ber(out_watermark, watermark_target)
                pbar.set_description(
                    f"Epoch {epoch+1} | L_Main: {l_main:.3f} | L_Ext: {loss_ext_wat:.3f} | "
                    f"L_Det: {loss_det:.3f} | BER: {ber:.2f}"
                )

                # if ber == 0.0 and loss_ext_wat.item() < 0.01:
                #     print("Convergence atteinte !")
                #     break
            # if ber == 0.0:
            #     break

        # Save keys
        self.saved_keys = {
            "watermark_target": watermark_target,
            "watermarked_model": watermarked_model,
            "extractor": extractor,
            "detector": detector,
            "init_w": init_w,
        }
        torch.save(self.saved_keys, "Riga_VAE_model_checkpoint.pt")
        return watermarked_model

    def extract(self, model=None):
        """
        Extrait la marque d'un modele suspect via le reseau extracteur.
        """
        if model is None:
            print('suspected is none')
            model = self.saved_keys["watermarked_model"]

        extractor = self.saved_keys["extractor"]
        watermark_target = self.saved_keys["watermark_target"]

        extractor.eval()

        # Get weights
        try:
            target_weights = self._get_target_weights(model)
        except ValueError:
            print("Couche cible introuvable dans le modele suspect.")
            return 1.0, None

        # Extract watermark
        with torch.no_grad():
            w = torch.flatten(target_weights.mean(dim=0))
            pred_wm = extractor(w)
            ber = self._compute_ber(pred_wm, watermark_target)

        print(f"BER Extrait : {ber:.2f}")
        return ber, pred_wm

    def detect(self, model=None):
        """
        Detecte si un modele est watermarke via le reseau detecteur.
        Returns: detection score (0 = watermarked, 1 = non-watermarked)
        """
        if model is None:
            print("suspected is none")
            model = self.saved_keys["watermarked_model"]

        detector = self.saved_keys["detector"]
        detector.eval()

        try:
            target_weights = self._get_target_weights(model)
            print("ok")
        except ValueError:
            print("error")
            return 0.5

        with torch.no_grad():
            w = torch.flatten(target_weights.mean(dim=0))
            w_sorted = torch.sort(w)[0]
            detection_score = detector(w_sorted).item()

        print(f"Detection score: {detection_score:.4f} (0=watermarked, 1=non-watermarked)")
        return detection_score

    @staticmethod
    def _compute_ber(pred, target):
        return ((pred > 0.5).float() != target).float().mean().item()


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from torchvision import transforms

import os
import torch
import gc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load from Hugging Face (no Google Drive issues)
print("Loading dataset...")


from datasets import load_from_disk
hf_dataset = load_from_disk("celeba_local")
# os.makedirs("./celeba_images/all", exist_ok=True)
# for i, item in enumerate(hf_dataset):
#     item['image'].save(f"./celeba_images/all/{i:06d}.jpg")
# del hf_dataset


class CelebAWrapper(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        if self.transform:
            image = self.transform(image)
        return image, 0

dataset = CelebAWrapper(hf_dataset, transform)
# del hf_dataset
# gc.collect()
# dataset = datasets.ImageFolder("./celeba_images", transform=transform)
print("Dataset loaded!")

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
print("loader loaded!")


Loading dataset...
Dataset loaded!
loader loaded!


In [4]:
# --- EXEMPLE D'EXECUTION ---

# 1. Setup Data
latent_dim = 200

# Initialize the model
model = VAE(latent_dim=latent_dim)

# Load the trained weights
model_path = "./vae_celeba_latent_200_epochs_10_batch_64_subset_80000.pth"
model.load_state_dict(torch.load(model_path))
model.to(device)

# 2. Embedding Riga
riga_defense = RigaVAE(model)
watermarked_model = riga_defense.embed(dataloader)

--- Demarrage Embedding Riga (decoder.6.weight) ---
Dimension vecteur poids : 48 | Watermark : 16 bits


Epoch 1 | L_Main: 0.329 | L_Ext: 0.008 | L_Det: 1.378 | BER: 0.00: 100%|██████████| 3166/3166 [01:30<00:00, 34.90it/s]
Epoch 2 | L_Main: 0.336 | L_Ext: 0.001 | L_Det: 1.376 | BER: 0.00: 100%|██████████| 3166/3166 [02:06<00:00, 25.10it/s]
Epoch 3 | L_Main: 0.309 | L_Ext: 0.000 | L_Det: 1.375 | BER: 0.00: 100%|██████████| 3166/3166 [03:26<00:00, 15.35it/s]
Epoch 4 | L_Main: 0.314 | L_Ext: 0.000 | L_Det: 1.374 | BER: 0.00: 100%|██████████| 3166/3166 [03:52<00:00, 13.60it/s]
Epoch 5 | L_Main: 0.299 | L_Ext: 0.000 | L_Det: 1.374 | BER: 0.00: 100%|██████████| 3166/3166 [04:52<00:00, 10.81it/s]
Epoch 6 | L_Main: 0.280 | L_Ext: 0.000 | L_Det: 1.374 | BER: 0.00: 100%|██████████| 3166/3166 [01:26<00:00, 36.42it/s]
Epoch 7 | L_Main: 0.294 | L_Ext: 0.000 | L_Det: 1.373 | BER: 0.00: 100%|██████████| 3166/3166 [01:37<00:00, 32.55it/s]
Epoch 8 | L_Main: 0.355 | L_Ext: 0.000 | L_Det: 1.373 | BER: 0.00: 100%|██████████| 3166/3166 [02:15<00:00, 23.34it/s]
Epoch 9 | L_Main: 0.289 | L_Ext: 0.000 | L_Det: 

In [5]:
# 3. Extraction (Test immediat)
ber, _ = riga_defense.extract(watermarked_model)
print(f"\nResultat final - BER: {ber:.2f}")

# 4. Detection
print("\n--- Detection Test ---")
print("Watermarked model:")
riga_defense.detect(watermarked_model)

# Test with fresh model
print("\nFresh (non-watermarked) model:")
fresh_model = VAE(latent_dim=latent_dim)
fresh_model.load_state_dict(torch.load(model_path))
fresh_model.to(device)

riga_defense.detect(fresh_model)

BER Extrait : 0.00

Resultat final - BER: 0.00

--- Detection Test ---
Watermarked model:
ok
Detection score: 0.4821 (0=watermarked, 1=non-watermarked)

Fresh (non-watermarked) model:
ok
Detection score: 0.4891 (0=watermarked, 1=non-watermarked)


0.48905032873153687

In [12]:
# --- Fonction de Distillation (Attaque) ---

def run_distillation_attack_riga(riga_obj, dataloader, epochs=5, lr=1e-4):
    """
    Tente de transferer la fonctionnalite du modele Riga vers un modele vierge.
    Verifie si la marque et la detection survivent.
    """
    device = riga_obj.device
    checkpoint = torch.load("Riga_VAE_model_checkpoint.pt", weights_only=False)
    # 1. Teacher (Gele)
    teacher = checkpoint["watermarked_model"]
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False

    # 2. Student (Vierge - Meme architecture)
    print("\n--- Initialisation du Student ---")
    student=VAE(latent_dim=latent_dim)
    student.load_state_dict(torch.load(model_path))
    student.to(device)
    student.train()

    teacher_ber, _ = riga_obj.extract(teacher)
    student_ber, _ = riga_obj.extract(student)
    print(f"[Check] BER Teacher: {teacher_ber:.2f}")
    print(f"[Check] BER Student (Avant): {student_ber:.2f}")

    print("\n[Check] Detection scores:")
    print("Teacher:", end=" ")
    riga_obj.detect(teacher)
    print("Student (before):", end=" ")
    riga_obj.detect(student)

    optimizer = AdamW(student.parameters(), lr=lr)

    history = {"loss": [], "ber": [], "detection": []}

    print(f"\n--- Distillation Riga ({epochs} epochs) ---")
    a=0
    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        running_loss = 0.0

        for clean_images, _ in pbar:
            clean_images = clean_images.to(device)


            with torch.no_grad():
                target_pred,_,_ = teacher(clean_images)

            student_pred,_ ,_ = student(clean_images)

            loss = F.mse_loss(student_pred, target_pred)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(Loss=loss.item())

        # Check metrics
        current_ber, pred_wm = riga_obj.extract(student)
        detection_score = riga_obj.detect(student)

        history["ber"].append(current_ber)
        history["loss"].append(running_loss / len(dataloader))
        history["detection"].append(detection_score)

        err_wat = nn.BCELoss()(pred_wm, checkpoint["watermark_target"]).item() if pred_wm is not None else float('nan')
        print(f"Fin Epoch {epoch+1} | Loss: {history['loss'][-1]:.4f} | BER: {current_ber:.2f} | Detection: {detection_score:.4f}, err_wat: {nn.BCELoss()(pred_wm, checkpoint["watermark_target"]).item()}")
        if current_ber==0.0 and a>=1:
            print("✅ Marque récupérée avec succès par distillation !")
            break
        elif current_ber==0.0 and a<1 :
            a+=1
        else:
            a=0

    return student, history



In [13]:
# 5. Attaque par Distillation
student_res, stats = run_distillation_attack_riga(riga_defense, dataloader, epochs=100)


--- Initialisation du Student ---
BER Extrait : 0.00
BER Extrait : 0.44
[Check] BER Teacher: 0.00
[Check] BER Student (Avant): 0.44

[Check] Detection scores:
Teacher: ok
Detection score: 0.4821 (0=watermarked, 1=non-watermarked)
Student (before): ok
Detection score: 0.4891 (0=watermarked, 1=non-watermarked)

--- Distillation Riga (100 epochs) ---


Epoch 1: 100%|██████████| 3166/3166 [01:35<00:00, 33.29it/s, Loss=0.00805]


BER Extrait : 0.44
ok
Detection score: 0.4890 (0=watermarked, 1=non-watermarked)
Fin Epoch 1 | Loss: 0.0082 | BER: 0.44 | Detection: 0.4890, err_wat: 6.804908752441406


Epoch 2: 100%|██████████| 3166/3166 [01:36<00:00, 32.73it/s, Loss=0.00512]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 2 | Loss: 0.0070 | BER: 0.44 | Detection: 0.4889, err_wat: 6.696590423583984


Epoch 3: 100%|██████████| 3166/3166 [01:38<00:00, 32.26it/s, Loss=0.00715]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 3 | Loss: 0.0069 | BER: 0.44 | Detection: 0.4889, err_wat: 6.629148960113525


Epoch 4: 100%|██████████| 3166/3166 [01:38<00:00, 32.13it/s, Loss=0.00762]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 4 | Loss: 0.0069 | BER: 0.44 | Detection: 0.4889, err_wat: 6.610177993774414


Epoch 5: 100%|██████████| 3166/3166 [01:38<00:00, 32.30it/s, Loss=0.00679]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 5 | Loss: 0.0069 | BER: 0.44 | Detection: 0.4889, err_wat: 6.579821586608887


Epoch 6: 100%|██████████| 3166/3166 [01:37<00:00, 32.35it/s, Loss=0.0063] 


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 6 | Loss: 0.0069 | BER: 0.44 | Detection: 0.4889, err_wat: 6.5444183349609375


Epoch 7: 100%|██████████| 3166/3166 [01:27<00:00, 36.03it/s, Loss=0.006]  


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 7 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.542362213134766


Epoch 8: 100%|██████████| 3166/3166 [01:24<00:00, 37.34it/s, Loss=0.00669]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 8 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.514590263366699


Epoch 9: 100%|██████████| 3166/3166 [01:24<00:00, 37.34it/s, Loss=0.00618]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 9 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.478653907775879


Epoch 10: 100%|██████████| 3166/3166 [01:25<00:00, 37.24it/s, Loss=0.00548]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 10 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.499850273132324


Epoch 11: 100%|██████████| 3166/3166 [01:25<00:00, 37.07it/s, Loss=0.00655]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 11 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.504090309143066


Epoch 12: 100%|██████████| 3166/3166 [01:25<00:00, 37.08it/s, Loss=0.00514]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 12 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.463783264160156


Epoch 13: 100%|██████████| 3166/3166 [01:24<00:00, 37.33it/s, Loss=0.00482]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 13 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.446440696716309


Epoch 14: 100%|██████████| 3166/3166 [01:24<00:00, 37.56it/s, Loss=0.00611]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 14 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.437909126281738


Epoch 15: 100%|██████████| 3166/3166 [01:24<00:00, 37.53it/s, Loss=0.00741]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 15 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.441375732421875


Epoch 16: 100%|██████████| 3166/3166 [01:24<00:00, 37.53it/s, Loss=0.00647]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 16 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.43500280380249


Epoch 17: 100%|██████████| 3166/3166 [01:24<00:00, 37.41it/s, Loss=0.00647]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 17 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.3971967697143555


Epoch 18: 100%|██████████| 3166/3166 [01:24<00:00, 37.36it/s, Loss=0.0072] 


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 18 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.396803379058838


Epoch 19: 100%|██████████| 3166/3166 [01:24<00:00, 37.37it/s, Loss=0.00699]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 19 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.410480499267578


Epoch 20: 100%|██████████| 3166/3166 [01:24<00:00, 37.43it/s, Loss=0.00708]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 20 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.435699939727783


Epoch 21: 100%|██████████| 3166/3166 [01:41<00:00, 31.07it/s, Loss=0.00641]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 21 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.4040374755859375


Epoch 22: 100%|██████████| 3166/3166 [02:06<00:00, 24.94it/s, Loss=0.00563]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 22 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.396872520446777


Epoch 23: 100%|██████████| 3166/3166 [02:06<00:00, 25.05it/s, Loss=0.0082] 


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 23 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.403104305267334


Epoch 24: 100%|██████████| 3166/3166 [02:07<00:00, 24.84it/s, Loss=0.00801]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 24 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.3895745277404785


Epoch 25: 100%|██████████| 3166/3166 [02:31<00:00, 20.87it/s, Loss=0.00546]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 25 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.351120948791504


Epoch 26: 100%|██████████| 3166/3166 [03:02<00:00, 17.36it/s, Loss=0.00673]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 26 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.33709192276001


Epoch 27: 100%|██████████| 3166/3166 [03:23<00:00, 15.53it/s, Loss=0.00826]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 27 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.363386154174805


Epoch 28: 100%|██████████| 3166/3166 [03:38<00:00, 14.52it/s, Loss=0.00843]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 28 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.36587381362915


Epoch 29: 100%|██████████| 3166/3166 [03:43<00:00, 14.15it/s, Loss=0.00596]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 29 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.353678226470947


Epoch 30: 100%|██████████| 3166/3166 [03:42<00:00, 14.20it/s, Loss=0.00655]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 30 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.35421895980835


Epoch 31: 100%|██████████| 3166/3166 [03:39<00:00, 14.42it/s, Loss=0.00628]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 31 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.353683948516846


Epoch 32: 100%|██████████| 3166/3166 [03:41<00:00, 14.31it/s, Loss=0.0054] 


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 32 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.335412979125977


Epoch 33: 100%|██████████| 3166/3166 [03:42<00:00, 14.22it/s, Loss=0.00669]


BER Extrait : 0.44
ok
Detection score: 0.4889 (0=watermarked, 1=non-watermarked)
Fin Epoch 33 | Loss: 0.0068 | BER: 0.44 | Detection: 0.4889, err_wat: 6.336960792541504


Epoch 34:   0%|          | 0/3166 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# 6. Visualisation des resultats
import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))

ax1.plot(stats["loss"])
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Distillation Loss")

ax2.plot(stats["ber"])
ax2.axhline(y=0.5, color='r', linestyle='--', label='Random (0.5)')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("BER")
ax2.set_title("BER Evolution During Distillation")
ax2.legend()

ax3.plot(stats["detection"])
ax3.axhline(y=0.5, color='r', linestyle='--', label='Threshold')
ax3.set_xlabel("Epoch")
ax3.set_ylabel("Detection Score")
ax3.set_title("Detection Score (0=watermarked, 1=clean)")
ax3.legend()

plt.tight_layout()
plt.show()