In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from tqdm import tqdm
from utils import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug
import wandb
import copy
import random
from reparam_module import ReparamModule
from torchvision.utils import save_image

In [2]:
import warnings

# Hide DeprecationWarning messages for the remainder of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

In [3]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True,
    so every non-empty value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False because bool('') was False before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface of the distillation / evaluation script.
parser = argparse.ArgumentParser(description='Parameter Processing')

# --- dataset and model selection ---
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')

parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet')

parser.add_argument('--model', type=str, default='ConvNet', help='model')

parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')

# --- evaluation settings ---
parser.add_argument('--eval_mode', type=str, default='S',
                    help='eval_mode, check utils.py for more info')

parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on')

parser.add_argument('--eval_it', type=int, default=100, help='how often to evaluate')

parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform')

# --- learning rates ---
parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images')
parser.add_argument('--lr_lr', type=float, default=1e-05, help='learning rate for updating... learning rate')
parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate')

parser.add_argument('--lr_init', type=float, default=0.01, help='how to init lr (alpha)')

# --- batch sizes ---
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')

parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
                    help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')

# NOTE: --dsa is deliberately kept as the *string* 'True'/'False'; code that
# consumes it must compare against 'True' rather than test its truthiness.
parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
                    help='whether to use differentiable Siamese augmentation.')

parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
                    help='differentiable Siamese augmentation strategy')

# --- paths ---
parser.add_argument('--data_path', type=str, default='data', help='dataset path')
parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')

# --- expert-trajectory settings ---
parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are')
parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')

parser.add_argument('--zca', action='store_true', help="do ZCA whitening")

parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")

# Accepts `--no_aug`, `--no_aug True` and `--no_aug False`; the last form used
# to be parsed as True because of type=bool.
parser.add_argument('--no_aug', type=_str2bool, nargs='?', const=True, default=False,
                    help='this turns off diff aug during distillation')

# --- texture distillation ---
parser.add_argument('--texture', action='store_true', help="will distill textures instead")
parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas')
parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration')


# --- ablation knobs ---
parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)')

parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc')




_StoreTrueAction(option_strings=['--force_save'], dest='force_save', nargs=0, const=True, default=False, type=None, choices=None, help='this will save images for 50ipc', metavar=None)

In [4]:
import sys

# Emulate the command line of distill2.py so that parser.parse_args() can be
# used from inside the notebook.
sys.argv="distill2.py --dataset=CIFAR10 --model=ResNet18 --ipc=100 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path=D:/research_2022/mtt-distillation-main/mtt-distillation-main/buffer --data_path=D:/research_2022/mtt-distillation-main/mtt-distillation-main/data".split()

# Report how many arguments were injected, then list each one with its index.
print (f'Number of arguments: {len(sys.argv)}')
for position, token in enumerate(sys.argv):
  print(f' Argument {position} is: {token}')

# Output directory; a later cell rebinds this name to another location.
save_dir = 'D:/research_2022/PyTorch-GAN-master/PyTorch-GAN-master/implementations/cgan/data/model/'

Number of arguments: 13
 Argument 0 is: distill2.py
 Argument 1 is: --dataset=CIFAR10
 Argument 2 is: --model=ResNet18
 Argument 3 is: --ipc=100
 Argument 4 is: --syn_steps=20
 Argument 5 is: --expert_epochs=3
 Argument 6 is: --max_start_epoch=20
 Argument 7 is: --zca
 Argument 8 is: --lr_img=1000
 Argument 9 is: --lr_lr=1e-05
 Argument 10 is: --lr_teacher=0.01
 Argument 11 is: --buffer_path=D:/research_2022/mtt-distillation-main/mtt-distillation-main/buffer
 Argument 12 is: --data_path=D:/research_2022/mtt-distillation-main/mtt-distillation-main/data


In [7]:
import torchvision.transforms as transforms
import torch
import io

def _onto_gpu0(storage, loc):
    """map_location hook for torch.load: move each loaded storage to CUDA device 0."""
    return storage.cuda(0)


# Tensors saved as images_best.pt / labels_best.pt, loaded straight onto GPU 0.
x_train=torch.load('D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/images_best.pt',map_location=_onto_gpu0)
y_train=torch.load('D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/labels_best.pt',map_location=_onto_gpu0)

# From here on, save_dir points at the directory the tensors were read from.
save_dir='D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/'

# Read the labels file under Mix/ fully into memory and deserialize it from the
# in-memory buffer; the result is only displayed, not bound to a name.
with open('D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/Mix/labels_best.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())
print(buffer)
torch.load(buffer)

<_io.BytesIO object at 0x000002711FBE0900>


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,

In [None]:
# Parse the command line that an earlier cell injected into sys.argv.
args = parser.parse_args()

# --dsa is parsed as the string 'True'/'False'. Convert it to a real boolean,
# otherwise every `if args.dsa:` below is truthy even for 'False'.
args.dsa = (args.dsa == 'True')

model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
print("mode before",args.eval_mode)
# Only the first architecture of the pool is evaluated in this notebook.
model_eval=model_eval_pool[0]
print(model_eval_pool)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Learning rate used to train the evaluation networks (taken from --lr_teacher).
syn_lr = torch.tensor(args.lr_teacher).to(args.device)

# eval_mode is overwritten only after the eval pool has been built from the
# original value.
args.eval_mode='train'
print("MODE",args.eval_mode)
if args.dsa:
    # args.epoch_eval_train = 1000
    args.dc_aug_param = None

args.dsa_param = ParamDiffAug()

dsa_params = args.dsa_param

# best_acc = {m: 0 for m in model_eval_pool}

# best_std = {m: 0 for m in model_eval_pool}

print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s'%(args.model, model_eval))
if args.dsa:
    print('DSA augmentation strategy: \n', args.dsa_strategy)
#     print('DSA augmentation parameters: \n', args.dsa_param.__dict__)
# else:
#     print('DC augmentation parameters: \n', args.dc_aug_param)
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)

accs_test = []
accs_train = []
for it_eval in range(args.num_eval):
    net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model

    eval_labs = y_train
    image_save = x_train
    # Work on private copies so the evaluation cannot modify the loaded
    # tensors. The copies were created before but never actually passed on.
    image_syn_eval, label_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(eval_labs.detach())

    args.lr_net = syn_lr.item()
    _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)
    accs_test.append(acc_test)
    accs_train.append(acc_train)

accs_test = np.array(accs_test)
accs_train = np.array(accs_train)
acc_test_mean = np.mean(accs_test)
acc_test_std = np.std(accs_test)
# Only the network trained in the LAST iteration is saved; the mean test
# accuracy is embedded in the file name. The previous pattern
# 'net_conv_%acc_test_mean.pth' was parsed as the '%a' conversion and produced
# a garbled name.
torch.save(net_eval.state_dict(), os.path.join(save_dir, 'net_conv_%.4f.pth' % acc_test_mean))
# if acc_test_mean > best_acc[model_eval]:
#     best_acc[model_eval] = acc_test_mean
#     best_std[model_eval] = acc_test_std
#     save_this_it = True
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), model_eval, acc_test_mean, acc_test_std))

mode before S
['ResNet18']
MODE train
-------------------------
Evaluation
model_train = ResNet18, model_eval = ResNet18
DSA augmentation strategy: 
 color_crop_cutout_flip_scale_rotate
Files already downloaded and verified
Files already downloaded and verified
Train ZCA


100%|█████████████████████████████████████████████████████████████████████████| 50000/50000 [00:04<00:00, 11701.23it/s]


Test ZCA


100%|█████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 10929.06it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1001/1001 [08:41<00:00,  1.92it/s]


[2022-10-24 17:07:14] Evaluate_00: epoch = 1000 train time = 521 s train loss = 0.002155 train acc = 1.0000, test acc = 0.4702


  3%|██▊                                                                             | 35/1001 [00:28<10:35,  1.52it/s]

In [None]:
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The skip path is the identity unless the stride or the channel count
    changes, in which case a 1x1 convolution plus batch norm projects the input.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: only the first convolution may change the resolution.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip branch: an empty Sequential acts as the identity.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The last convolution widens the features to ``expansion * planes``
    channels. The skip path is the identity unless the stride or the channel
    count changes, in which case a 1x1 convolution plus batch norm is used.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The 3x3 convolution is the only one that carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip branch: an empty Sequential acts as the identity.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem (no max-pool) and four residual stages.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            the constructor signature ``block(in_planes, planes, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the final feature map is
        # 4x4, so this equals the former avg_pool2d(out, 4); other input
        # resolutions no longer break the linear layer.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    """ResNet built from BasicBlock with stage depths [2, 2, 2, 2]."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet34():
    """ResNet built from BasicBlock with stage depths [3, 4, 6, 3]."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])


def ResNet50():
    """ResNet built from Bottleneck with stage depths [3, 4, 6, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])


def ResNet101():
    """ResNet built from Bottleneck with stage depths [3, 4, 23, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])


def ResNet152():
    """ResNet built from Bottleneck with stage depths [3, 8, 36, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])


def test():
    """Smoke test: push one random 3x32x32 image through ResNet18 and print the output size."""
    model = ResNet18()
    logits = model(torch.randn(1, 3, 32, 32))
    print(logits.size())

# test()

In [None]:




import torchvision.transforms as transforms
# train_mean = np.array([125.307, 122.950, 113.865])
# train_std = np.array([62.993, 62.089, 66.705])
# test_mean = np.array([126.025, 123.708, 114.854])
# test_std = np.array([62.896, 61.937, 66.706])

# Per-channel statistics on the [0, 1] scale that transforms.ToTensor()
# produces. They must NOT be divided by 255 again: that rescaling only applied
# to the 0..255 values kept in the commented-out lines above.
train_mean = np.array([0.4914, 0.4822, 0.4465])
train_std = np.array([0.2023, 0.1994, 0.2010])
test_mean = np.array([0.4914, 0.4822, 0.4465])
test_std = np.array([0.2023, 0.1994, 0.2010])

# NOTE(review): neither `normalize` nor `train_transformer` is applied to the
# tensor dataset built below; they are only defined here.
normalize = transforms.Normalize(train_mean, train_std)
train_transformer = transforms.Compose([

    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

import torch.utils.data as data

# Image / label tensors loaded directly onto CUDA device 0.
x_train=torch.load('D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/images_best.pt',map_location=lambda storage, loc: storage.cuda(0))
y_train=torch.load('D:/research_2022/mtt-distillation-main/mtt-distillation-main/logged_files/CIFAR10/project/labels_best.pt',map_location=lambda storage, loc: storage.cuda(0))


train_x =x_train
#train_x = torch.stack([(x) for x in x_gen])
train_y = y_train
print(train_x.shape)
print(train_y.shape)

# Images and labels are indexed with the same permutation, so they must have
# the same length.
assert len(train_x) == len(train_y), "images and labels differ in length"

# One random permutation is enough (it used to be drawn and applied twice);
# the DataLoader reshuffles every epoch anyway.
full_indices = np.arange(len(train_x))
np.random.shuffle(full_indices)
tensor_x = train_x[full_indices]
tensor_y = train_y[full_indices]

trainset = data.TensorDataset(tensor_x, tensor_y)  # create your datset
trainloader = data.DataLoader(trainset, batch_size=25, shuffle=True, num_workers=0)

In [None]:
# Optimisation settings for training `net` on the tensor dataset.
# NOTE(review): `net` is created in a later cell, so this cell only works after
# that cell has run. Re-run this cell whenever `net` is rebuilt, otherwise the
# optimizer keeps updating the parameters of the previous network.
Epoch = int(args.epoch_eval_train)
# Single milestone just past the halfway point; not consumed in this cell.
lr_schedule = [Epoch//2+1]

args.lr_net = syn_lr.item()
lr = float(args.lr_net)
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=0.0005,
)

criterion = nn.CrossEntropyLoss()
criterion = criterion.to(args.device)

In [None]:
import time
import sys

'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
# Fresh network to be trained by the loop at the end of this cell.
net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)


# Reuse the device chosen when the arguments were parsed ('cuda' when
# available, otherwise 'cpu') instead of hard-coding 'cuda'.
device = args.device
net = net_eval
net = net.to(device)
if device == 'cuda':
    # DataParallel wraps the very same module object, so `net_eval` still
    # shares its parameters with `net`.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True


# The terminal width cannot be queried from a notebook, so it is fixed here.
# _, term_width = os.popen('stty size', 'r').read().split()
# term_width = int(term_width)
term_width=400

TOTAL_BAR_LENGTH = 40.
# Timestamps shared between successive progress_bar() calls.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """Write a one-line text progress bar for step ``current`` of ``total`` to stdout.

    The line shows the bar, the time since the previous call, the time since
    the bar was started (``current == 0``), the optional ``msg`` and a
    ``current+1/total`` counter. It ends with a carriage return until the
    last step, which ends with a newline.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    pending = int(TOTAL_BAR_LENGTH - filled) - 1

    # A negative repeat count yields an empty string, like an empty range().
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * pending + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    fields = ['  Step: %s' % format_time(step_time),
              ' | Tot: %s' % format_time(tot_time)]
    if msg:
        fields.append(' | ' + msg)
    status = ''.join(fields)
    sys.stdout.write(status)

    # Pad the line out to the assumed terminal width.
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()

def format_time(seconds):
    """Render a duration in seconds using at most its two largest non-zero units.

    Units are D, h, m, s and ms, e.g. ``3661`` -> ``'1h1m'`` and
    ``1.5`` -> ``'1s500ms'``. Returns ``'0ms'`` when every unit is zero.
    """
    remaining = seconds
    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    # Keep only the first two positive units, largest first.
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0]
    return ''.join(shown[:2]) or '0ms'



# Loss read by train() and test() below through the global name `criterion`.
# This rebinds the `criterion` that the optimizer cell had moved to args.device.
criterion = nn.CrossEntropyLoss()
# Alternative optimizer / scheduler settings, currently disabled:
# optimizer = optim.SGD(net.parameters(), lr=0.00002,
#                       momentum=0.9, weight_decay=5e-4)
#optimizer = optim.Adam(net.parameters(), lr=2e-7)
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training
def train(epoch):
    """Run one optimisation pass over ``trainloader`` and report running loss / accuracy.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``device``
    and ``trainloader``. ``epoch`` is only used for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    n_batches = len(trainloader)
    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = net(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Running statistics over the batches seen so far in this epoch.
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        status = 'Loss: %.3f | TRAIN Acc: %.3f%% (%d/%d)' % (
            running_loss/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(step, n_batches, status)


def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Uses the module-level ``net``, ``criterion``, ``device`` and
    ``testloader``. The best accuracy so far is kept in the global
    ``best_acc``; a checkpoint is written to ``./checkpoint/ckpt.pth`` only
    when the current accuracy exceeds it.

    Args:
        epoch: epoch index stored inside the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            # Evaluation runs on the un-augmented inputs: random augmentation
            # here made the reported accuracy stochastic.
            # NOTE(review): restore the DiffAugment call if augmented
            # evaluation was intended.
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | TEST Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    # best_acc used to be reset to -10 on every call, so every epoch counted as
    # "best". Read the running best instead, tolerating a missing global on
    # the first call.
    previous_best = globals().get('best_acc', float('-inf'))
    if acc > previous_best:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

start_epoch=1
# `scheduler` was stepped below without ever being created. Build the cosine
# schedule from the disabled line earlier in this cell, with T_max equal to
# the 200 epochs run here.
# NOTE(review): `optimizer` must have been built from the current `net`.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)
    scheduler.step()
accs_train = []

# PRIVACY

In [None]:
# teacher_model = alexnet.AlexNet(num_classes=10)
# teacher_checkpoint = 'experiments/base_cnn/epoch399'
# utilse.load_checkpoint(teacher_checkpoint, teacher_model)

# %% Inference Attack HZ Class


class InferenceAttack_HZ(nn.Module):
    """Attack network that maps (model output, label encoding) pairs to a score in (0, 1).

    Two MLP encoders, one per input, each produce a 64-dimensional embedding.
    The concatenated embeddings go through a third MLP and a sigmoid.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Encoder for the first forward() argument.
        self.features = nn.Sequential(
            nn.Linear(self.num_classes, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
        )

        # Encoder for the second forward() argument.
        self.labels = nn.Sequential(
            nn.Linear(num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        # Head applied to the two concatenated 64-dimensional embeddings.
        self.combine = nn.Sequential(
            nn.Linear(64 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Re-initialise every layer: N(0, 0.01) weights, zero biases. The
        # parameter names are printed as they are visited.
        for name, param in self.named_parameters():
            print(f'\t {name}')
            role = name.split('.')[-1]
            if role == 'weight':
                nn.init.normal_(param, std=0.01)
            elif role == 'bias':
                nn.init.zeros_(param)

        self.output = nn.Sigmoid()

    def forward(self, x, labels):
        """Return a (batch, 1) tensor of sigmoid scores for the given pairs."""
        embedded_x = self.features(x)
        embedded_labels = self.labels(labels)
        joint = torch.cat((embedded_x, embedded_labels), 1)
        return self.output(self.combine(joint))


# %% Status Func

def report_str(batch_idx, data_time, batch_time, losses, top1, top5):
    """Format one status line from the batch index, timings, loss and accuracies.

    ``top1`` and ``top5`` are fractions and are printed as percentages; the
    Top5 part is omitted when ``top5`` is None.
    """
    line = (
        f'({batch_idx:4d})'
        f' Data: {data_time:.2f}s | Batch: {batch_time:.2f}s'
        f' || Loss: {losses:.3f} | Top1: {top1 * 100:.2f}%'
    )
    if top5 is not None:
        line += f' | Top5: {top5 * 100:.2f}%'
    return line



In [None]:
from sklearn.metrics import top_k_accuracy_score
from utilz import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
# %% privacy_train
# torch.Size([128, 3, 32, 32]) torch.Size([128]) torch.Size([128, 3, 32, 32]) torch.Size([128])
# PRED torch.Size([256, 10])
# infer_in torch.Size([256])
def privacy_train(trainloader, model, inference_model, criterion, optimizer, use_cuda, num_batchs):
    """Train the attack network ``inference_model`` against a frozen ``model``.

    Args:
        trainloader: iterable yielding
            ``(batch_idx, ((tr_input, tr_target), (te_input, te_target)))``,
            e.g. ``enumerate(zip(loader_a, loader_b))``.
        model: target classifier; put in eval mode and only queried.
        inference_model: attack network called as
            ``inference_model(model_outputs, label_encoding)``.
        criterion: loss between the attack output and the 0/1 source labels.
        optimizer: optimizer over the attack network's parameters.
        use_cuda: when true, every batch is moved to the GPU first.
        num_batchs: the loop stops once ``batch_idx - first_batch_idx``
            exceeds this value.

    Returns:
        ``(losses.avg, top1.avg)``: average attack loss and attack accuracy.
    """
    num_classes=10

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    mtop1_a = AverageMeter()
    mtop5_a = AverageMeter()

    inference_model.train()
    model.eval()
    # switch to evaluate mode

    end = time.time()
    first_id = -1
    for batch_idx, ((tr_input, tr_target), (te_input, te_target)) in trainloader:
        # measure data loading time
        if first_id == -1:
            first_id = batch_idx

        data_time.update(time.time() - end)

        if use_cuda:
            tr_input = tr_input.cuda()
            te_input = te_input.cuda()
            tr_target = tr_target.cuda()
            te_target = te_target.cuda()

        # Query the target model on both halves at once. It is frozen, so no
        # graph is built and no gradients accumulate in its parameters.
        model_input = torch.cat((tr_input, te_input))
        with torch.no_grad():
            pred_outputs = model(model_input)

        # Everything created below lives on the device of the model output
        # instead of unconditionally on CUDA, so use_cuda=False works as well.
        device = pred_outputs.device
        infer_input = torch.cat((tr_target, te_target)).to(device)

        # Top-1 / top-5 accuracy of the target model on the mixed batch.
        # TODO fix
        # mtop1, mtop5 = accuracy(pred_outputs.data, infer_input.data, topk=(1, 5))
        y_true = infer_input.detach().cpu().numpy()
        y_score = pred_outputs.detach().cpu().numpy()
        mtop1 = top_k_accuracy_score(y_true=y_true, y_score=y_score,
                                     k=1, labels=range(num_classes))
        mtop5 = top_k_accuracy_score(y_true=y_true, y_score=y_score,
                                     k=5, labels=range(num_classes))

        mtop1_a.update(mtop1, model_input.size(0))
        mtop5_a.update(mtop5, model_input.size(0))

        # Label encoding fed to the attack network: -1 everywhere, +1 at the
        # index of the ground-truth class.
        infer_input_one_hot = torch.full((infer_input.size(0), num_classes), -1.0, device=device)
        infer_input_one_hot.scatter_(1, infer_input.type(torch.int64).view([-1, 1]), 1.0)

        member_output = inference_model(pred_outputs, infer_input_one_hot)

        # Attack targets: 0 for samples from the first loader (tr_*), 1 for
        # samples from the second loader (te_*).
        # NOTE(review): with this encoding the output is high for te_*
        # samples, despite the "is_member" name. Kept as in the original.
        is_member_labels = torch.cat((
            torch.zeros(tr_input.size(0), 1, device=device),
            torch.ones(te_input.size(0), 1, device=device),
        ))

        loss = criterion(member_output, is_member_labels)

        # measure accuracy and record loss
        prec1 = np.mean((member_output.detach().cpu().numpy() > 0.5) == is_member_labels.cpu().numpy())
        losses.update(loss.item(), model_input.size(0))
        top1.update(prec1, model_input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx - first_id > num_batchs:
            break

        # plot progress
        if batch_idx % 50 == 0:
            print( losses.avg, top1.avg)
            #print(report_str(batch_idx, data_time.avg, batch_time.avg, losses.avg, top1.avg, None))

    return losses.avg, top1.avg



In [None]:
import torchvision.transforms as transforms

import torchvision.datasets as datasets

# Per-channel CIFAR10 statistics on the 0..255 scale.
train_mean = np.array([125.307, 122.950, 113.865])
train_std = np.array([62.993, 62.089, 66.705])
test_mean = np.array([126.025, 123.708, 114.854])
test_std = np.array([62.896, 61.937, 66.706])

# Normalize mean std to 0..1 from 0..255
train_mean /= 255
train_std /= 255
test_mean /= 255
test_std /= 255

print('Hard code CIFAR10 train/test mean/std for next time')

# The dataset below asked for `transform_test`, which was never defined; only
# `transform_train` existed. Define it with the same test-set statistics.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(test_mean, test_std),
])
# Previous name of the same pipeline, kept so existing references still work.
transform_train = transform_test

# CIFAR10 training split, downloaded on demand.
# NOTE(review): this loader is later zipped with `testloader` from
# get_dataset(); confirm both apply the same preprocessing, otherwise the
# attack may separate the two sources by preprocessing instead of membership.
# TODO check loader for trainloader_private
trainset_private =datasets.CIFAR10(root='./data-cifar10', train=True,
        download=True, transform=transform_test)
trainloader_private = torch.utils.data.DataLoader(trainset_private, batch_size=128, shuffle=True)

In [None]:
# inference_model = torch.nn.DataParallel(inferenece_model).cuda()
import torch.optim as optim
import time

# Target model for the attack.
model=net_eval
LR = 0.05
EPOCHS = 5
n_params = sum(p.numel() for p in model.parameters())
print('\tTotal params: %.2fM' % (n_params / 1000000.0))

# Losses. Only criterion_attack is passed to privacy_train below.
criterion = nn.CrossEntropyLoss()
criterion_attack = nn.MSELoss()
private_train_criterion = nn.MSELoss()

# Attack network plus one optimizer per model. Only optimizer_mem is stepped
# in this cell.
inference_model = InferenceAttack_HZ(10).cuda()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
optimizer_mem = optim.Adam(inference_model.parameters(), lr=0.00001)

best_acc = 0.0
start_epoch = 0

# Train and val
for epoch in range(start_epoch, EPOCHS):
    #adjust_learning_rate(optimizer, epoch)

    print(f'\nEpoch: [{epoch + 1:d} | {EPOCHS:d}] ')

    # Pair each batch of trainloader_private with one batch of testloader.
    train_private_enum = enumerate(zip(trainloader_private, testloader))
    privacy_loss, privacy_acc = privacy_train(
        train_private_enum,
        model,
        inference_model,
        criterion_attack,
        optimizer_mem,
        True,
        100,
    )
    print(f'Privacy Res: {privacy_acc * 100:.2f}% ')

