In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

from federated_learning.attacks.MemberInference.metrics import *
from federated_learning.attacks.MemberInference.train import *

from federated_learning.model import LeNet_Small_Quant, mlleaks_mlp

import federated_learning

# Select the compute device: CUDA first, then Apple MPS, otherwise CPU.
# The explicit CPU branch matters on machines with neither accelerator,
# where an unconditional "mps" fallback names a device that no tensor or
# module can actually be moved to.
_mps_backend = getattr(torch.backends, "mps", None)  # attribute is missing on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
# Experiment configuration.
seed = 0          # fixed seed so weight initialisation is repeatable across runs
n_epochs = 3      # epoch count handed to train() for the target and shadow models
batch_size = 32   # batch size used by every DataLoader in this notebook
# NOTE(review): the saved training outputs in this notebook stay at ~10%
# accuracy with a single class predicted; lr=0.01 may be too high for Adam
# here (its documented default is 1e-3) -- confirm before trusting results.
lr = 0.01         # learning rate shared by the target, shadow and attack optimizers
k = 3             # passed as n_in to the attack MLP and as k to the attack helpers

# Seed both RNGs before any model is constructed; without this, the two
# networks below start from different weights on every kernel restart.
np.random.seed(seed)
torch.manual_seed(seed)

target_net = LeNet_Small_Quant().to(device)
shadow_net = LeNet_Small_Quant().to(device)

In [3]:
# Load the CIFAR-10 draw used for the target model and keep the first
# user's partition (only one user is requested).
train_part, test_part = federated_learning.load_cifar10(
    num_users=1,
    n_class=10,
    n_samples=1000,
    even_split=True,
)
X_train, y_train = train_part
X_test, y_test = test_part
target_X = X_train[0]
target_y = y_train[0]

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Load the CIFAR-10 draw used for the shadow model and keep the first
# user's partition. X_train / y_train / X_test / y_test are rebound here.
# NOTE(review): the arguments are identical to the target load; whether the
# two draws contain different samples depends on load_cifar10 -- confirm.
train_part, test_part = federated_learning.load_cifar10(
    num_users=1,
    n_class=10,
    n_samples=1000,
    even_split=True,
)
X_train, y_train = train_part
X_test, y_test = test_part
shadow_X = X_train[0]
shadow_y = y_train[0]

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Human-readable names for the ten label indices, in index order.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def imshow(img):
    """Render a single image array with matplotlib.

    The array is divided by 255.0 and its first axis is moved to the last
    position before being passed to ``plt.imshow``; the figure is then shown
    immediately.

    Args:
        img: array-like image whose first axis is moved to the end
            (presumably channels-first pixel data in 0-255 -- confirm).
    """
    axis_last = np.moveaxis(img / 255.0, 0, -1)
    plt.imshow(axis_last)
    plt.show()

In [6]:
def _build_loader(features, labels, batch_size, shuffle=False):
    """Wrap feature/label arrays in a TensorDataset and a DataLoader.

    Args:
        features: array-like inputs; converted to a float tensor.
        labels: array-like targets; converted to a long tensor.
        batch_size: batch size for the returned loader.
        shuffle: forwarded to ``DataLoader``; defaults to no shuffling.

    Returns:
        A ``(dataset, loader)`` pair.
    """
    dataset = torch.utils.data.TensorDataset(
        torch.tensor(features).float(), torch.tensor(labels).long()
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


# Split point for the shadow data: first half is used for training ("in"),
# second half is held out ("out"). Integer division avoids a float round-trip.
total_size = shadow_X.shape[0]
indices = list(range(total_size))
split = total_size // 2

train_idx, out_idx = indices[:split], indices[split:]

# NOTE(review): these samplers are not passed to any loader in this cell;
# they are kept only because other cells may still reference the names.
train_sampler = SubsetRandomSampler(train_idx)
out_sampler = SubsetRandomSampler(out_idx)

shadow_train = shadow_X[:split], shadow_y[:split]
shadow_out = shadow_X[split:], shadow_y[split:]

# The target data gets its own midpoint so its in/out halves stay balanced
# even if its size ever differs from the shadow data's.
target_split = target_X.shape[0] // 2
target_train = target_X[:target_split], target_y[:target_split]
target_out = target_X[target_split:], target_y[target_split:]

shadow_train_dataset, shadow_train_loader = _build_loader(*shadow_train, batch_size)
shadow_out_dataset, shadow_out_loader = _build_loader(*shadow_out, batch_size)

target_train_dataset, target_train_loader = _build_loader(*target_train, batch_size)
target_out_dataset, target_out_loader = _build_loader(*target_out, batch_size)

# X_test / y_test come from the most recent load_cifar10 call.
test_dataset, test_loader = _build_loader(X_test, y_test, batch_size, shuffle=True)

In [7]:
# Display the label array of the shadow model's training half.
shadow_train_labels = shadow_train[1]
shadow_train_labels

array([5, 7, 0, ..., 3, 9, 8])

In [8]:
# Target and shadow classifiers: same loss/optimizer recipe, separate instances.
target_loss, shadow_loss = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
target_optim = optim.Adam(params=target_net.parameters(), lr=lr)
shadow_optim = optim.Adam(params=shadow_net.parameters(), lr=lr)

# Attack model: an MLP built with n_in=k, trained with a loss that combines
# a sigmoid with binary cross-entropy.
attack_net = mlleaks_mlp(n_in=k)
attack_net = attack_net.to(device)
attack_loss = nn.BCEWithLogitsLoss()
attack_optim = optim.Adam(params=attack_net.parameters(), lr=lr)

In [9]:
# Train the shadow model on its training half; test_loader is passed as the
# second loader. The call's return value is the cell output.
train(
    shadow_net,
    shadow_train_loader,
    test_loader,
    shadow_optim,
    shadow_loss,
    n_epochs,
    classes=classes,
)

[0/3]
Training:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 100.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 0.0
Accuracy of truck: 0.0
Accuracy: 10.82
Test:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 100.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 0.0
Accuracy of truck: 0.0
Accuracy: 10.0
[1/3]
Training:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 100.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 0.0
Accuracy of truck: 0.0
Accuracy: 10.82
Test:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 100.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 0.0
Accu

(10.82, 10.0)

In [10]:
# Train the attack model against the shadow model, giving it the shadow
# model's training-half and held-out-half loaders.
train_attacker(
    attack_net,
    shadow_net,
    shadow_train_loader,
    shadow_out_loader,
    attack_optim,
    attack_loss,
    n_epochs=10,
    k=k,
)

In [11]:
# Train the target model on its training half; test_loader is passed as the
# second loader. The call's return value is the cell output.
train(
    target_net,
    target_train_loader,
    test_loader,
    target_optim,
    target_loss,
    n_epochs,
    classes=classes,
)

[0/3]
Training:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 0.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 100.0
Accuracy of ship: 0.0
Accuracy of truck: 0.0
Accuracy: 10.0
Test:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.1
Accuracy of deer: 0.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 99.9
Accuracy of ship: 0.0
Accuracy of truck: 0.0
Accuracy: 10.0
[1/3]
Training:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 0.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 100.0
Accuracy of truck: 0.0
Accuracy: 10.48
Test:
Accuracy of airplane: 0.0
Accuracy of automobile: 0.0
Accuracy of bird: 0.0
Accuracy of cat: 0.0
Accuracy of deer: 0.0
Accuracy of dog: 0.0
Accuracy of frog: 0.0
Accuracy of horse: 0.0
Accuracy of ship: 100.0
Accura

(10.48, 10.0)

In [12]:
# Run the attack model against the target model, using the target's
# training-half and held-out-half loaders.
eval_attack_model(
    attack_net,
    target_net,
    target_train_loader,
    target_out_loader,
    k=k,
)

# Evaluate the target model on its training loader and on the test loader,
# keeping whatever eval_target_model returns for each.
target_eval = {}
for banner, split_name, split_loader in (
    ("\nPerformance on training set: ", "train", target_train_loader),
    ("\nPerformance on test set: ", "test", test_loader),
):
    print(banner)
    target_eval[split_name] = eval_target_model(target_net, split_loader, classes=None)

train_accuracy = target_eval["train"]
test_accuracy = target_eval["test"]


Performance on training set: 
Accuracy: 10.48

Performance on test set: 
Accuracy: 10.0
