In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
class CIFAR10_CNN(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(8)

        self.fc1 = nn.Linear(8 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):

        x = torch.relu(F.max_pool2d(self.batchnorm1(self.conv1(x)), kernel_size=2, stride=2))
        x = torch.relu(F.max_pool2d(self.batchnorm2(self.conv2(x)), kernel_size=2, stride=2))

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x

In [8]:
transform_test = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                      (0.2470, 0.2435, 0.2616)),
])

data_dir = "./data"
train_ds = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_test)
test_ds  = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)

In [9]:
device = 'cpu'

model = CIFAR10_CNN()
model.to(device)

optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

epochs = 5

for i in range(1, epochs + 1):

    model.train()
    running_train_loss = 0.0

    for images, labels in train_loader:

        images = images.to(device)
        labels = labels.to(device)

        preds = model(images)

        optimizer.zero_grad()
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        running_train_loss += loss

    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:

            images = images.to(device)
            labels = labels.to(device)

            preds = model(images)
            loss = loss_fn(preds, labels)

            running_val_loss += loss

    print(f"{i}: {running_val_loss=}, {running_train_loss=}")

1: running_val_loss=tensor(224.8092), running_train_loss=tensor(1285.4264, grad_fn=<AddBackward0>)
2: running_val_loss=tensor(218.5477), running_train_loss=tensor(1042.9135, grad_fn=<AddBackward0>)
3: running_val_loss=tensor(279.1239), running_train_loss=tensor(944.3522, grad_fn=<AddBackward0>)
4: running_val_loss=tensor(186.5254), running_train_loss=tensor(882.8271, grad_fn=<AddBackward0>)
5: running_val_loss=tensor(261.0594), running_train_loss=tensor(847.1922, grad_fn=<AddBackward0>)


In [10]:
torch.save(model.state_dict(), './models/cifar_model.pt')

In [31]:
model.fc1.weight.view(1, -1).shape

torch.Size([1, 131072])

In [78]:
class WatermarkExtractor(nn.Module):

    def __init__(self, n_params, n_bits):
        super().__init__()
        self.X = nn.Parameter(torch.randn(n_params, n_bits) * 0.1)

    def forward(self, layer_weight: torch.Tensor):

        flattened = layer_weight.view(1, -1)
        probs = torch.sigmoid(torch.matmul(flattened, self.X))
        return probs

In [79]:
def watermark_loss(extractor: WatermarkExtractor, weights: torch.Tensor, target_bits: torch.Tensor):

    probs = extractor(weights)
    assert probs.shape == target_bits.shape

    target_bits = target_bits.float()

    return F.binary_cross_entropy(probs, target_bits)


In [80]:
target_bits = torch.tensor([[1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1]])
target_bits.shape

torch.Size([1, 16])

In [84]:
r, c = model.fc1.weight.shape
wm = WatermarkExtractor(r*c, 16)

In [None]:

fc1_weights = model.fc1.weight

extractor_optimzer = optim.SGD(wm.parameters(), lr=0.1)

for i in range(100):
    
    loss = watermark_loss(wm, fc1_weights, target_bits)
    
    print(f"{i}: {loss}")

    extractor_optimzer.zero_grad()
    loss.backward()
    extractor_optimzer.step()





0: 0.8869643211364746
1: 0.7162444591522217
2: 0.5824374556541443
3: 0.4786510169506073
4: 0.39871537685394287
5: 0.33732837438583374
6: 0.29002583026885986
7: 0.2531982660293579
8: 0.22408445179462433
9: 0.20066756010055542
10: 0.18150900304317474
11: 0.16558396816253662
12: 0.15215860307216644
13: 0.14069890975952148
14: 0.13080903887748718
15: 0.12219193577766418
16: 0.11461876332759857
17: 0.10791292786598206
18: 0.10193528234958649
19: 0.09657395631074905
20: 0.09173943847417831
21: 0.0873580127954483
22: 0.08336956799030304
23: 0.07972396165132523
24: 0.07637917250394821
25: 0.07329951971769333
26: 0.07045510411262512
27: 0.06782010197639465
28: 0.06537248194217682
29: 0.06309293210506439
30: 0.060964807868003845
31: 0.05897374451160431
32: 0.057107165455818176
33: 0.05535336211323738
34: 0.05370289087295532
35: 0.05214671790599823
36: 0.05067726597189903
37: 0.04928745701909065
38: 0.04797092080116272
39: 0.046721864491701126
40: 0.04553564265370369
41: 0.04440752789378166
42: 0

In [97]:
model2 = CIFAR10_CNN()
fc1_model2_weights = model2.fc1.weight

res = wm(fc1_model2_weights)
print(target_bits)
print(res)
(res >= 0.5).int() == target_bits

tensor([[1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1]])
tensor([[0.2509, 0.4852, 0.5434, 0.7013, 0.3283, 0.8512, 0.2311, 0.7508, 0.7840,
         0.6704, 0.5247, 0.5230, 0.3145, 0.7994, 0.4796, 0.4615]],
       grad_fn=<SigmoidBackward0>)


tensor([[False, False,  True,  True,  True, False, False,  True, False,  True,
         False, False, False,  True, False, False]])