### This script is for processing predictions from attacks to further analyze the results.

In [9]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import pickle
import csv
from tqdm import tqdm
import torch


# add ../.. to the path (MIAE)
import sys
sys.path.append('../../')
from miae.utils.dataset_utils import dataset_split
from experiment.models import get_model



In [None]:
"""select which task to perform"""
task = 2

### Task 1: Show training and testing accuracy for all target we have

In [10]:
import csv

data_path = '/data/public/comp_mia_data/repeat_exp_set'
runs = [0, 1, 2, 3]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy(model, data, device):
    model.eval()
    correct = 0
    total = 0
    model.to(device)
    data_loader = torch.utils.data.DataLoader(data, batch_size=128, shuffle=False)
    with torch.inference_mode():
        for images, labels in tqdm(data_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    top1_accuracy = correct / total
    _, top3_predicted = torch.topk(outputs.data, 3, dim=1)
    top3_correct = torch.sum(top3_predicted == labels.unsqueeze(1)).item()
    top3_accuracy = top3_correct / total
    return top1_accuracy, top3_accuracy


def process_accuracy(arch, dataset, runs):
    num_classes = 10 if dataset == "cifar10" else 100
    input_size = 32
    print (f"Number of classes: {num_classes}")
    if arch == "resnet56":
        target_model = get_model("resnet56", num_classes=num_classes, input_size=input_size)
    elif arch == "vgg16":
        target_model = get_model("vgg16", num_classes=num_classes, input_size=input_size)
    elif arch == "mobilenet":
        target_model = get_model("mobilenet", num_classes=num_classes, input_size=input_size)
    elif arch == "wrn32_4":
        target_model = get_model("wrn32_4", num_classes=num_classes, input_size=input_size)

    train_accuracies = []
    train_accuracies_top3 = []
    test_accuracies = []
    test_accuracies_top3 = []

    for run in runs:  
        target_path = os.path.join(data_path, f"miae_experiment_aug_more_target_data_{run}/target")
        target_model_path = f"{target_path}/target_models/{dataset}/{arch}/target_model_{arch}{dataset}.pkl"
        target_train_data_path = f"{target_path}/{dataset}/target_trainset.pkl"
        target_test_data_path = f"{target_path}/{dataset}/target_testset.pkl"

        target_model.load_state_dict(torch.load(target_model_path))
        with open(target_train_data_path, 'rb') as f:
            target_train_data = pickle.load(f)
        with open(target_test_data_path, 'rb') as f:
            target_test_data = pickle.load(f)

        train_acc_ret = accuracy(target_model, target_train_data, device)
        test_acc_ret = accuracy(target_model, target_test_data, device)

        train_accuracies.append(train_acc_ret[0])
        train_accuracies_top3.append(train_acc_ret[1])
        test_accuracies.append(test_acc_ret[0])
        test_accuracies_top3.append(test_acc_ret[1])

    avg_train_accuracy = np.mean(train_accuracies)
    std_train_accuracy = np.std(train_accuracies)
    avg_train_accuracy_top3 = np.mean(train_accuracies_top3)
    std_train_accuracy_top3 = np.std(train_accuracies_top3)
    avg_test_accuracy = np.mean(test_accuracies)
    std_test_accuracy = np.std(test_accuracies)
    avg_test_accuracy_top3 = np.mean(test_accuracies_top3)
    std_test_accuracy_top3 = np.std(test_accuracies_top3)
    generalization_gap = avg_train_accuracy - avg_test_accuracy
    generalization_gap_std = std_train_accuracy - std_test_accuracy


    # average accuracy, std
    print(f"Average train accuracy: {avg_train_accuracy*100:.4f}% ± {std_train_accuracy*100:.4f}%")
    print(f"Average test accuracy: {avg_test_accuracy*100:.4f}% ± {std_test_accuracy*100:.4f}%")
    print(f"Generalization gap: {generalization_gap*100:.4f}% ± {generalization_gap_std*100:.4f}%")

    return avg_train_accuracy, std_train_accuracy, avg_train_accuracy_top3, std_train_accuracy_top3, avg_test_accuracy, std_test_accuracy, avg_test_accuracy_top3, std_test_accuracy_top3, generalization_gap, generalization_gap_std

arch_list = ["resnet56", "vgg16", "mobilenet", "wrn32_4"]
dataset_list = ["cifar10", "cifar100"]

header = ['Architecture', 'Dataset', 'Avg Train Accuracy', 'Std Train Accuracy', 'Avg Test Accuracy', 'Std Test Accuracy', 'Generalization Gap', 'Generalization Gap Std']
if task == 1:
    with open(save_path, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)

    for arch in arch_list:
        for dataset in dataset_list:
            print(f"Processing {arch} on {dataset}")
            avg_train_accuracy, std_train_accuracy, avg_train_accuracy_top3, std_train_accuracy_top3, avg_test_accuracy, std_test_accuracy, avg_test_accuracy_top3, std_test_accuracy_top3, generalization_gap, generalization_gap_std = process_accuracy(arch, dataset, runs)
            
            # save to csv
            save_path = f"{data_path}/target_training_statsd.csv"
            with open(save_path, mode='a') as f:
                writer = csv.writer(f)
                writer.writerow([arch, dataset, avg_train_accuracy, std_train_accuracy, avg_test_accuracy, std_test_accuracy, generalization_gap, generalization_gap_std])
    print(f"csv saved to {save_path}/miae_experiment_aug_more_target_data/target_training_statsd.csv")

Processing resnet56 on cifar10
Number of classes: 10


100%|██████████| 118/118 [00:03<00:00, 32.21it/s]
100%|██████████| 118/118 [00:03<00:00, 34.87it/s]
100%|██████████| 118/118 [00:03<00:00, 34.70it/s]
100%|██████████| 118/118 [00:03<00:00, 30.77it/s]
100%|██████████| 118/118 [00:03<00:00, 37.13it/s]
100%|██████████| 118/118 [00:03<00:00, 35.01it/s]
100%|██████████| 118/118 [00:03<00:00, 31.96it/s]
100%|██████████| 118/118 [00:03<00:00, 31.19it/s]


Average train accuracy: 92.3983% ± 5.5817%
Average test accuracy: 80.4217% ± 1.3478%
Generalization gap: 11.9767% ± 4.2339%
Processing resnet56 on cifar100
Number of classes: 100


100%|██████████| 118/118 [00:03<00:00, 30.57it/s]
100%|██████████| 118/118 [00:03<00:00, 30.80it/s]
100%|██████████| 118/118 [00:03<00:00, 31.46it/s]
100%|██████████| 118/118 [00:03<00:00, 33.43it/s]
100%|██████████| 118/118 [00:03<00:00, 35.64it/s]
100%|██████████| 118/118 [00:03<00:00, 33.27it/s]
100%|██████████| 118/118 [00:03<00:00, 36.75it/s]
100%|██████████| 118/118 [00:03<00:00, 36.41it/s]


Average train accuracy: 96.2133% ± 0.3325%
Average test accuracy: 46.8133% ± 0.9640%
Generalization gap: 49.4000% ± -0.6315%
Processing vgg16 on cifar10
Number of classes: 10


100%|██████████| 118/118 [00:02<00:00, 42.69it/s]
100%|██████████| 118/118 [00:02<00:00, 43.05it/s]
100%|██████████| 118/118 [00:02<00:00, 42.20it/s]
100%|██████████| 118/118 [00:02<00:00, 42.70it/s]
100%|██████████| 118/118 [00:02<00:00, 42.68it/s]
100%|██████████| 118/118 [00:02<00:00, 42.35it/s]
100%|██████████| 118/118 [00:02<00:00, 42.81it/s]
100%|██████████| 118/118 [00:02<00:00, 42.95it/s]


Average train accuracy: 99.7833% ± 0.0256%
Average test accuracy: 83.4717% ± 0.1850%
Generalization gap: 16.3117% ± -0.1594%
Processing vgg16 on cifar100
Number of classes: 100


100%|██████████| 118/118 [00:02<00:00, 42.27it/s]
100%|██████████| 118/118 [00:02<00:00, 42.04it/s]
100%|██████████| 118/118 [00:02<00:00, 43.03it/s]
100%|██████████| 118/118 [00:02<00:00, 42.96it/s]
100%|██████████| 118/118 [00:02<00:00, 41.97it/s]
100%|██████████| 118/118 [00:02<00:00, 42.06it/s]
100%|██████████| 118/118 [00:02<00:00, 43.19it/s]
100%|██████████| 118/118 [00:02<00:00, 41.95it/s]


Average train accuracy: 99.9383% ± 0.0029%
Average test accuracy: 51.3933% ± 0.2450%
Generalization gap: 48.5450% ± -0.2421%
Processing mobilenet on cifar10
Number of classes: 10


100%|██████████| 118/118 [00:03<00:00, 31.72it/s]
100%|██████████| 118/118 [00:03<00:00, 31.90it/s]
100%|██████████| 118/118 [00:03<00:00, 31.68it/s]
100%|██████████| 118/118 [00:03<00:00, 33.53it/s]
100%|██████████| 118/118 [00:03<00:00, 32.26it/s]
100%|██████████| 118/118 [00:03<00:00, 34.13it/s]
100%|██████████| 118/118 [00:03<00:00, 36.02it/s]
100%|██████████| 118/118 [00:03<00:00, 35.56it/s]


Average train accuracy: 95.3050% ± 0.2145%
Average test accuracy: 72.0000% ± 1.3491%
Generalization gap: 23.3050% ± -1.1346%
Processing mobilenet on cifar100
Number of classes: 100


100%|██████████| 118/118 [00:03<00:00, 36.16it/s]
100%|██████████| 118/118 [00:03<00:00, 38.01it/s]
100%|██████████| 118/118 [00:03<00:00, 36.25it/s]
100%|██████████| 118/118 [00:03<00:00, 33.26it/s]
100%|██████████| 118/118 [00:03<00:00, 34.23it/s]
100%|██████████| 118/118 [00:03<00:00, 34.19it/s]
100%|██████████| 118/118 [00:03<00:00, 33.09it/s]
100%|██████████| 118/118 [00:03<00:00, 33.20it/s]


Average train accuracy: 99.9450% ± 0.0247%
Average test accuracy: 36.0250% ± 0.8876%
Generalization gap: 63.9200% ± -0.8630%
Processing wrn32_4 on cifar10
Number of classes: 10


100%|██████████| 118/118 [00:05<00:00, 19.99it/s]
100%|██████████| 118/118 [00:05<00:00, 21.47it/s]
100%|██████████| 118/118 [00:05<00:00, 20.51it/s]
100%|██████████| 118/118 [00:05<00:00, 20.89it/s]
100%|██████████| 118/118 [00:05<00:00, 20.39it/s]
100%|██████████| 118/118 [00:05<00:00, 20.32it/s]
100%|██████████| 118/118 [00:05<00:00, 21.50it/s]
100%|██████████| 118/118 [00:05<00:00, 21.65it/s]


Average train accuracy: 85.2450% ± 2.4878%
Average test accuracy: 75.3733% ± 2.1915%
Generalization gap: 9.8717% ± 0.2963%
Processing wrn32_4 on cifar100
Number of classes: 100


100%|██████████| 118/118 [00:05<00:00, 20.94it/s]
100%|██████████| 118/118 [00:05<00:00, 20.60it/s]
100%|██████████| 118/118 [00:05<00:00, 20.50it/s]
100%|██████████| 118/118 [00:05<00:00, 20.44it/s]
100%|██████████| 118/118 [00:05<00:00, 21.41it/s]
100%|██████████| 118/118 [00:05<00:00, 20.20it/s]
100%|██████████| 118/118 [00:05<00:00, 20.73it/s]
100%|██████████| 118/118 [00:05<00:00, 19.85it/s]

Average train accuracy: 69.7783% ± 2.1496%
Average test accuracy: 39.8017% ± 0.8820%
Generalization gap: 29.9767% ± 1.2676%
csv saved to /data/public/comp_mia_data/repeat_exp_set/target_training_statsd.csv/miae_experiment_aug_more_target_data/target_training_statsd.csv





### loss distribution

In [11]:
target_path = '/data/public/comp_mia_data/miae_experiment_aug_more_target_data'
target_data_path = ''