# DAWN: Dynamic Adversarial Watermarking of Neural Networks

Implementation of the work presented in the paper [DAWN: Dynamic Adversarial Watermarking of Neural Networks](https://arxiv.org/pdf/1906.00830.pdf).

## Paper Abstract
Training machine learning (ML) models is expensive in terms of
computational power, amounts of labeled data and human expertise.
Thus, ML models constitute intellectual property (IP) and business
value for their owners. Embedding digital watermarks during model
training allows a model owner to later identify their models in
case of theft or misuse. However, model functionality can also be
stolen via model extraction, where an adversary trains a surrogate
model using results returned from a prediction API of the original
model. Recent work has shown that model extraction is a realistic
threat. Existing watermarking schemes are ineffective against IP
theft via model extraction since it is the adversary who trains the
surrogate model. In this paper, we introduce DAWN (Dynamic
Adversarial Watermarking of Neural Networks), the first approach
to use watermarking to deter model extraction IP theft. Unlike prior
watermarking schemes, DAWN does not impose changes to the
training process but it operates at the prediction API of the protected
model, by dynamically changing the responses for a small subset of
queries (e.g., <0.5%) from API clients. This set is a watermark that
will be embedded in case a client uses its queries to train a surrogate
model. We show that DAWN is resilient against two state-of-the-art
model extraction attacks, effectively watermarking all extracted
surrogate models, allowing model owners to reliably demonstrate
ownership, incurring negligible loss of
prediction accuracy (0.03%-0.5%).

## Goals

The authors of the paper pursue two goals:

1. Provide a robust defense against model extraction attacks.
2. Use watermarking so that the model owner can reliably demonstrate ownership of an extracted model.
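
The mechanism described in the abstract (altering the response for a small, secret subset of queries at the prediction API) can be illustrated with a keyed hash. The following is a minimal sketch using only the Python standard library, not the code of `mlmodelwatermarking`. The function names `is_watermarked`, `watermarked_label` and `answer_query` are hypothetical, and the byte widths and the label-selection rule are assumptions made for illustration. The `key` and `probability` arguments presumably play the roles of `key_dawn` and `probability_dawn` used later in this document.

```python
import hashlib
import hmac


def is_watermarked(key: bytes, query: bytes, probability: float) -> bool:
    """Decide deterministically whether a query belongs to the watermark set."""
    digest = hmac.new(key, query, hashlib.sha256).digest()
    # Assumption of this sketch: map the first 8 bytes of the MAC to [0, 1).
    fraction = int.from_bytes(digest[:8], "big") / 2**64
    return fraction < probability


def watermarked_label(key: bytes, query: bytes, true_label: int,
                      num_classes: int) -> int:
    """Pick a keyed, deterministic label that always differs from the true one."""
    digest = hmac.new(key, b"label" + query, hashlib.sha256).digest()
    # The offset lies in 1..num_classes-1, so the result never equals true_label.
    offset = 1 + int.from_bytes(digest[:8], "big") % (num_classes - 1)
    return (true_label + offset) % num_classes


def answer_query(key: bytes, query: bytes, true_label: int,
                 num_classes: int = 10, probability: float = 0.01) -> int:
    """Return the API response: the true label, or an altered one for watermark queries."""
    if is_watermarked(key, query, probability):
        return watermarked_label(key, query, true_label, num_classes)
    return true_label
```

Because the decision depends only on the secret key and the query bytes, the same query always receives the same answer, so a client cannot detect altered responses by repeating a query.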

## Import packages

In [1]:
import string
import random
import torch
from warnings import simplefilter

from mlmodelwatermarking.marktorch import Trainer
from mlmodelwatermarking.verification import verify
from mlmodelwatermarking import TrainingWMArgs

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

## Model architecture and dataset construction

In [2]:
class LeNet(nn.Module):
    """ MNIST model """
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def load_MNIST():
    """ Load MNIST dataset
    Returns:
    trainset (object): training split (80% of the MNIST training data)
    valset (object): validation split (remaining 20%)
    testset (object): MNIST test dataset

    """
    transformation = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    dataset = torchvision.datasets.MNIST('/tmp/',
                                         train=True,
                                         download=True,
                                         transform=transformation)
    size_split = int(len(dataset) * 0.8)
    trainset, valset = torch.utils.data.random_split(
        dataset, [size_split, len(dataset) - size_split])

    testset = torchvision.datasets.MNIST('/tmp/',
                                         train=False,
                                         download=True,
                                         transform=transformation)

    return trainset, valset, testset

simplefilter(action='ignore', category=UserWarning)
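
The input size `320` of `fc1` follows from the shapes produced by the two convolution and pooling stages of `LeNet`. The arithmetic can be checked with the short sketch below, which assumes the PyTorch defaults used above (stride 1 and no padding for `nn.Conv2d`, pooling stride equal to the window size). The helpers `conv_out` and `pool_out` are hypothetical names introduced only for this illustration.

```python
def conv_out(size: int, kernel: int) -> int:
    """Output width of a stride-1, unpadded convolution."""
    return size - kernel + 1


def pool_out(size: int, window: int) -> int:
    """Output width of a non-overlapping max pool."""
    return size // window


side = 28                                # MNIST images are 28x28
side = pool_out(conv_out(side, 5), 2)    # conv1 + pool: 28 -> 24 -> 12
side = pool_out(conv_out(side, 5), 2)    # conv2 + pool: 12 -> 8 -> 4
flat_features = 20 * side * side         # conv2 outputs 20 channels
print(flat_features)                     # 320
```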

## Main Watermarking

In [3]:
def default_key(length: int):
    elements = string.ascii_uppercase + string.digits
    return ''.join(random.choices(elements, k=length))


"""Testing of watermarking for MNIST model."""

# Train a clean (non-watermarked) baseline model first
trainset, valset, testset = load_MNIST()
model = LeNet()
args = TrainingWMArgs(
        trigger_technique='merrer',
        optimizer='SGD',
        lr=0.01,
        gpu=torch.cuda.is_available(),
        epochs=10,
        nbr_classes=10,
        batch_size=64,
        watermark=False)

trainer_clean = Trainer(
                model=model,
                args=args,
                trainset=trainset,
                valset=valset,
                testset=testset)
trainer_clean.train()
original_model = trainer_clean.get_model()

args = TrainingWMArgs(
            nbr_classes=10,
            key_dawn=default_key(255),
            probability_dawn=0.01,
            trigger_technique='dawn',
            metric='accuracy')

trainer = Trainer(
            model=original_model,
            trainset=trainset,
            args=args)

ownership, wm_model = trainer.get_model()
triggerloader = torch.utils.data.DataLoader(
                    ownership['inputs'],
                    batch_size=32,
                    shuffle=False)  # keep order aligned with ownership['labels']
results = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
    for inputs in triggerloader:
        pred = wm_model(inputs.to(device))
        results += list(torch.argmax(pred, 1).cpu().numpy())

verification = verify(
                ownership['labels'],
                results,
                number_labels=args.nbr_classes,
                metric='accuracy',
                dawn=True)

# This should return True
print(f"Model is stolen: {verification['is_stolen']}")

results = []
with torch.no_grad():
    for inputs in triggerloader:
        pred = original_model(inputs.to(device))
        results += list(torch.argmax(pred, 1).cpu().numpy())

verification = verify(
                ownership['labels'],
                results,
                number_labels=args.nbr_classes,
                metric='accuracy')

# This should return False
print(f"Model is stolen: {verification['is_stolen']}")


INFO:logger:Training
Validation accuracy: 98.3167: 100%|█████████████| 10/10 [02:26<00:00, 14.67s/it]
INFO:logger:Generation of the trigers


Model is stolen: True
Model is stolen: False
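
The intuition behind the two results above is that a surrogate trained on watermarked responses reproduces the altered labels on the trigger inputs, while an independent model agrees with them only by chance. The sketch below illustrates one way such a decision could be made, using a one-sided binomial test against chance-level agreement. It is not the implementation of the library's `verify` function. The names `agreement`, `chance_p_value` and `looks_stolen`, as well as the threshold `alpha`, are assumptions made for this example.

```python
from math import exp, lgamma, log


def agreement(expected, predicted) -> float:
    """Fraction of trigger inputs on which the model returns the watermark label."""
    matches = sum(int(e) == int(p) for e, p in zip(expected, predicted))
    return matches / len(expected)


def chance_p_value(expected, predicted, num_classes: int) -> float:
    """Probability of at least the observed number of matches if each
    prediction matched its watermark label with probability 1/num_classes."""
    n = len(expected)
    k = sum(int(e) == int(p) for e, p in zip(expected, predicted))
    p = 1.0 / num_classes
    total = 0.0
    for i in range(k, n + 1):
        # Binomial term computed in log space to avoid overflow for large n.
        log_term = (lgamma(n + 1) - lgamma(i + 1) - lgamma(n - i + 1)
                    + i * log(p) + (n - i) * log(1.0 - p))
        total += exp(log_term)
    return total


def looks_stolen(expected, predicted, num_classes: int,
                 alpha: float = 1e-6) -> bool:
    """Flag the model when chance agreement is implausible (assumed threshold)."""
    return chance_p_value(expected, predicted, num_classes) < alpha
```

Under these assumptions, a model that matches every watermark label on 50 trigger inputs is flagged, whereas a model that matches none of them is not.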
