In [None]:
from membership_inference import DataSet, ShadowModel, AttackerModel

import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [None]:
# Fetch both MNIST splits into ./data (download=True is passed for each).
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)

# Materialise the images (`.data`) and labels (`.targets`) of each split as
# NumPy arrays so they can be indexed with plain Python index lists later on.
X_train, y_train = np.array(trainset.data), np.array(trainset.targets)
X_test, y_test = np.array(testset.data), np.array(testset.targets)

# Sanity check: shapes of the training split, then of the test split.
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

In [None]:
# Fixed seed so the victim / shadow / eval split is the same on every
# Restart & Run All. A dedicated Random instance is used so the global
# `random` module state is left untouched.
SPLIT_SEED = 0
N_VICTIM = 1000  # training-split images used to train the victim model
N_SHADOW = 1000  # test-split images handed to the shadow models
N_EVAL = 1000    # test-split images kept aside for validation / attack evaluation

split_rng = random.Random(SPLIT_SEED)

# Victim indices are drawn from the training split. Shadow and eval indices are
# drawn in a single sample() call (i.e. without replacement) from the test
# split and then cut in two, so the two index lists cannot overlap.
victim_idx = split_rng.sample(range(X_train.shape[0]), k=N_VICTIM)
attack_idx = split_rng.sample(range(X_test.shape[0]), k=N_SHADOW + N_EVAL)
shadow_idx = attack_idx[:N_SHADOW]
eval_idx = attack_idx[N_SHADOW:]

X_victim = X_train[victim_idx]
y_victim = y_train[victim_idx]

X_shadow = X_test[shadow_idx]
y_shadow = y_test[shadow_idx]

X_eval = X_test[eval_idx]
y_eval = y_test[eval_idx]

# Sanity check: each subset should contain the configured number of samples.
print(X_victim.shape, y_victim.shape)
print(X_shadow.shape, y_shadow.shape)
print(X_eval.shape, y_eval.shape)

In [None]:
# Preprocessing pipeline.
#   ToTensor : converts an image to a float tensor (uint8 pixel values in
#              0..255 are rescaled to 0..1).
#   Normalize: standardises with a hard-coded mean and std of 0.5,
#              i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

# Loader over the samples the victim model is trained on.
victimset = DataSet(X_victim, y_victim, transform=transform)
victimloader = DataLoader(victimset, batch_size=4, shuffle=True, num_workers=2)

# Loader over the held-out eval samples.
valset = DataSet(X_eval, y_eval, transform=transform)
valloader = DataLoader(valset, batch_size=4, shuffle=True, num_workers=2)

In [None]:
class Net(nn.Module):
    """Small CNN used as the victim classifier.

    Two 3x3 convolutions, one 2x2 max-pool and two fully connected layers.
    ``forward`` returns the raw 10-way logits (no softmax is applied).
    The two dropout modules are constructed but are not applied in
    ``forward``.
    """

    def __init__(self):
        # Zero-argument super() binds to this class object itself. The
        # two-argument form looks up the *global* name `Net` at call time,
        # which breaks once a later cell rebinds `Net` to another class.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)   # 1x28x28  -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3)  # 32x26x26 -> 64x24x24
        self.pool = nn.MaxPool2d(2, 2)     # 64x24x24 -> 64x12x12
        self.dropout1 = nn.Dropout2d()     # kept as an attribute; unused in forward()
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        self.dropout2 = nn.Dropout2d()     # kept as an attribute; unused in forward()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of 1x28x28 images to a (batch, 10) tensor of logits."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the pooled feature maps to (batch, 12*12*64) for fc1.
        x = x.view(-1, 12 * 12 * 64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


victim_net = Net()

In [None]:
def collect_outputs(model, loader):
    """Run ``model`` over every batch of ``loader`` without tracking gradients.

    The model is put in eval mode for the pass and its previous
    training/eval mode is restored afterwards.

    Returns:
        (outputs, labels): the model outputs and the matching labels, each
        concatenated over all batches in iteration order.
    """
    was_training = model.training
    model.eval()
    batch_outputs, batch_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            batch_outputs.append(model(inputs))
            batch_labels.append(labels)
    model.train(was_training)
    return torch.cat(batch_outputs), torch.cat(batch_labels)


def top1_accuracy(outputs, labels):
    """Fraction of rows of ``outputs`` whose argmax equals the label."""
    return accuracy_score(np.array(torch.argmax(outputs, dim=1)), np.array(labels))


# Cross-entropy loss (applied to the raw logits returned by the network)
criterion = nn.CrossEntropyLoss()
# Stochastic gradient descent with momentum
optimizer = optim.SGD(victim_net.parameters(), lr=0.005, momentum=0.9)

for epoch in range(20):  # loop over the dataset multiple times
    victim_net.train()
    running_loss = 0.0
    n_batches = 0
    for inputs, labels in victimloader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = victim_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean loss once per epoch. The previous "every 2000
    # mini-batches" condition never fired because an epoch has far fewer
    # batches than that (1000 samples at batch size 4).
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, n_batches, running_loss / max(n_batches, 1)))

    # Accuracy on the held-out eval loader after this epoch.
    test_preds, test_label = collect_outputs(victim_net, valloader)
    print(top1_accuracy(test_preds, test_label))

print('Finished Training')

# Outputs on the data the victim model was trained on ("in" samples).
in_preds, in_label = collect_outputs(victim_net, victimloader)
print(top1_accuracy(in_preds, in_label))

# Outputs on the held-out eval data ("out" samples).
out_preds, out_label = collect_outputs(victim_net, valloader)
print(top1_accuracy(out_preds, out_label))

# Shadow model

In [None]:
# CNN used for the shadow models. Note: this rebinds the name `Net`, which an
# earlier cell of this notebook also defines.
class Net(nn.Module):
    """Small CNN with dropout applied in ``forward``.

    Two 3x3 convolutions, one 2x2 max-pool, dropout, and two fully connected
    layers with dropout in between. ``forward`` returns the raw 10-way logits
    (no softmax is applied).
    """

    def __init__(self):
        # Zero-argument super() does not depend on the global name `Net`,
        # which is bound more than once in this notebook.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)   # 1x28x28  -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3)  # 32x26x26 -> 64x24x24
        self.pool = nn.MaxPool2d(2, 2)     # 64x24x24 -> 64x12x12
        # Channel-wise dropout: applied to the 4-D pooled feature maps.
        self.dropout1 = nn.Dropout2d()
        self.fc1 = nn.Linear(12 * 12 * 64, 128)
        # Plain dropout: applied to the flat (batch, 128) activations.
        # nn.Dropout2d on a 2-D input is deprecated by PyTorch (it warns and
        # is slated to become an error); nn.Dropout is the documented
        # behaviour-preserving replacement.
        self.dropout2 = nn.Dropout()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of 1x28x28 images to a (batch, 10) tensor of logits."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        # Flatten the pooled feature maps to (batch, 12*12*64) for fc1.
        x = x.view(-1, 12 * 12 * 64)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

In [None]:
# Six freshly constructed Net instances are passed to ShadowModel together
# with the positional argument 500 and the same preprocessing transform that
# the victim loaders use.
sm = ShadowModel(
    [Net() for _ in range(6)],
    500,
    shadow_transform=transform,
)
y_shadow = np.array(y_shadow)
# `result` is consumed by the attack-model cell (it is passed to AttackerModel.fit).
result = sm.fit_transform(X_shadow, y_shadow, num_itr=10)

# Attack model

In [None]:
from sklearn.svm import SVC

# One default-configured SVC per key of the shadow-model result.
models = [SVC() for _ in result.keys()]
am = AttackerModel(models)
am.fit(result)

# Attack predictions for the victim's training samples ("in") and for the
# held-out samples ("out"), using the victim outputs collected earlier.
attack_pred_in = am.predict(in_preds, in_label)
attack_pred_out = am.predict(out_preds, out_label)

In [None]:
# Attack accuracy = (correct on "in" samples + correct on "out" samples) / all samples.
# Assumes the attacker predicts 1 for "member" and 0 for "non-member", as the
# original formula implies - TODO confirm against AttackerModel.predict.
# Sample counts are read from the prediction arrays instead of being
# hard-coded, so the figure stays correct if the split sizes change.
n_in, n_out = len(attack_pred_in), len(attack_pred_out)
print("accuracy is ", (sum(attack_pred_in) + n_out - sum(attack_pred_out)) / (n_in + n_out))