In [1]:
import random

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from matplotlib import pyplot as plt
import numpy as np

In [2]:
from attack_splitnn.attack.model_inversion.fsha import FSHA
from attack_splitnn.attack.model_inversion.fshamnist import Resnet, Decoder, Discriminator, Pilot
from attack_splitnn.utils import DataSet

In [3]:
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)

X_train = np.array(trainset.data)
y_train = np.array(trainset.targets)

X_test = np.array(testset.data)
y_test = np.array(testset.targets)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

#public_idx = random.sample(list(np.where(y_train == 3)[0]) + list(np.where(y_train == 7)[0]), k=15)
#private_idx = random.sample(list(np.where(y_test == 3)[0]) + list(np.where(y_test == 7)[0]), k=15)
public_idx = random.sample(list(range(y_train.shape[0])), k=15)
private_idx = random.sample(list(range(y_test.shape[0])), k=15)

X_public = X_train[public_idx]
y_public = y_train[public_idx]
X_private = X_test[private_idx]
y_private = y_test[private_idx]

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.1307, 0.3081)])

publicset = DataSet(X_public, y_public, transform=transform)
publicloader = torch.utils.data.DataLoader(publicset, batch_size=64, shuffle=True, num_workers=1)

privateset = DataSet(X_private, y_private, transform=transform)
privateloader = torch.utils.data.DataLoader(privateset, batch_size=64, shuffle=True, num_workers=1)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [5]:
f = Resnet(level=4)
tilde_f = Pilot(level=4)
decoder = Decoder(input_dim=256, level=4)
D = Discriminator(level=4)

optimizers = [optim.Adam, optim.Adam, optim.Adam]

In [6]:
fsha = FSHA(privateloader, publicloader,
            f, tilde_f, D, decoder,
            optimizers, wgan=True, gradient_penalty=500)

In [7]:
fsha.train(epochs=10000, verbose=100)

f_loss:0.09375400841236115 tilde_f_loss:1.0684651136398315 D_loss:-0.04273427277803421 loss_c:1.1504851579666138
f_loss:11.475993156433105 tilde_f_loss:0.9742318987846375 D_loss:-9.458589553833008 loss_c:1.1381759643554688
f_loss:551244.0 tilde_f_loss:0.6492138504981995 D_loss:-464594.625 loss_c:0.9608574509620667
f_loss:2652885248.0 tilde_f_loss:0.5386390686035156 D_loss:-2112669696.0 loss_c:0.9722080826759338
f_loss:660334837760.0 tilde_f_loss:0.4792214334011078 D_loss:-523507302400.0 loss_c:0.9387363791465759
f_loss:49951811829760.0 tilde_f_loss:0.44034668803215027 D_loss:-46003289849856.0 loss_c:0.8705729842185974
f_loss:1384855027843072.0 tilde_f_loss:0.411140501499176 D_loss:-1318214852149248.0 loss_c:0.8826361894607544
f_loss:1.966884568170496e+16 tilde_f_loss:0.3884895443916321 D_loss:-1.887277349339136e+16 loss_c:0.8765168190002441
f_loss:1.7837263750273434e+17 tilde_f_loss:0.3723907768726349 D_loss:-1.7251179385008947e+17 loss_c:0.8939623236656189
f_loss:1.318674669200474e+18

RuntimeError: DataLoader worker (pid(s) 28580) exited unexpectedly

In [8]:
for (x_private, label_private), (x_public, label_public) in\
                    zip(privateloader, publicloader):
    break;

x_recovered, control = fsha.attack(x_private)

RuntimeError: DataLoader worker (pid(s) 11944) exited unexpectedly

In [None]:
for i in range(1,5): 
    plt.subplot(1,5,i) # 横並びに表示するためのおまじない
    plt.imshow(x_private[i-1].reshape(28,28),cmap='gray_r')
plt.title("x_private")
plt.show()

for i in range(1,5): 
    plt.subplot(1,5,i) # 横並びに表示するためのおまじない
    plt.imshow(x_recovered[i-1].reshape(28,28),cmap='gray_r')
plt.title("x_recovered")
plt.show()

for i in range(1,5): 
    plt.subplot(1,5,i) # 横並びに表示するためのおまじない
    plt.imshow(control[i-1].reshape(28,28),cmap='gray_r')
plt.title("control")
plt.show()