In [None]:
import os
import yaml
import optuna
import json
import numpy as np
import pickle
import torch

from src.save_load import buildAuditMetadata, saveAudit
from src.train_models import trainTargetModel
from src.visualize_model import VisualizeModel
from src.utils import print_yaml, calculate_logits_and_inmask
from src.models.resnet18_model import ResNet18
from src.cifar_handler import CifarInputHandler

In [None]:
# -------------------------------- #
#  Load target model and metadata  #
# -------------------------------- #
# Folder under ./target that holds the artefacts of the audited target model.
target_folder = "resnet18-8bef88e056"
target_path = os.path.join("target", target_folder)

# Target metadata pickle (.pkl). A later cell reads `train_indices` and
# `test_indices` from the loaded object.
# SECURITY: pickle.load can execute arbitrary code while deserialising, so
# only point this at files produced by your own training runs.
metadata_pkl_path = os.path.join(target_path, "model_metadata.pkl")
with open(metadata_pkl_path, "rb") as f:
    metadata_pkl = pickle.load(f)

# Target metadata (.json). Its "train" section is copied into the train
# config further down.
# The file is decoded explicitly as UTF-8 so the result does not depend on
# the platform's default locale encoding.
metadata_path = os.path.join(target_path, "metadata.json")
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

In [None]:
def _show_config(banner, cfg, *keys):
    """Print `banner`, then dump `cfg` (or only `cfg[key]` for each key) via print_yaml."""
    print(banner)
    if not keys:
        print_yaml(cfg)
    for key in keys:
        print_yaml(cfg[key])


# ------------------- #
#   Load Audit yaml   #
# ------------------- #
config = None  # placeholder; not read by any of the cells shown in this notebook
with open("./audit.yaml") as file:
    audit_config = yaml.safe_load(file)

_show_config("-------------- Audit config --------------", audit_config, 'audit', 'target')

# Point the audit config at the target folder selected above
audit_config['target']['target_folder'] = target_path

_show_config("\n-------------- Updated audit config --------------", audit_config, 'audit', 'target')

# ------------------- #
#   Load Train yaml   #
# ------------------- #
config = None
with open("./train.yaml") as file:
    train_config = yaml.safe_load(file)

_show_config("\n-------------- Train config --------------", train_config)

# Copy the target model's training hyper-parameters into the train config so
# that later cells train with the same settings as the target.
train_config['train'].update(
    {
        name: metadata['train'][name]
        for name in ('epochs', 'batch_size', 'learning_rate', 'momentum')
    }
)

# Log into the target model's folder
train_config['run']['log_dir'] = target_path

_show_config("\n-------------- Updated train config --------------", train_config)

In [None]:
from LeakPro.leakpro.schemas import LeakProConfig
from LeakPro.leakpro.attacks.mia_attacks.lira import AttackLiRA
from LeakPro.leakpro.attacks.utils.shadow_model_handler import ShadowModelHandler
from LeakPro.leakpro.input_handler.mia_handler import MIAHandler
from src.cifar_handler import CifarInputHandler
# ----------------- #
#   Setup LeakPro   #
# ----------------- #

# Initializing: build the LeakPro config object from the (updated) audit YAML dict
leakpro_configs = LeakProConfig(**audit_config)
print("-------- LeakPro Configs --------")
print_yaml(leakpro_configs)

# MIA handler built from the LeakPro config and the CIFAR input-handler class
handler = MIAHandler(leakpro_configs, CifarInputHandler)

# Only the first entry of the audit's attack list is used in this notebook;
# this raises IndexError if the list is empty.
configs = handler.configs.audit.attack_list[0]
print("-------- Attack Configs --------")
print_yaml(configs)

# LiRA attack object; later cells use it for index sampling, logit rescaling
# and access to the audit dataset.
attack = AttackLiRA(handler=handler, configs=configs)

In [None]:
# -------------------------------- #
#   Setup Shadow models Training   #
# -------------------------------- #

# Number of shadow models to train, read from the selected attack's config
num_shadow_models = configs["num_shadow_models"]
# Online flag from the same attack config; it drives both include_* arguments
# below and is passed on to create_shadow_models in the training cell.
online = configs["online"]

# Indices the shadow models may be trained on. Train and test indices of the
# target are included only when `online` is truthy.
# NOTE(review): this call presumably samples randomly — keep it before the
# handler construction below unless the ordering is confirmed irrelevant.
attack_data_indices = attack.sample_indices_from_population(include_train_indices = online,
                                                        include_test_indices = online)

training_data_fraction = attack.training_data_fraction

# Shadow-model handler; its training hyper-parameters are overridden with the
# values from train_config["train"].
smh = ShadowModelHandler(handler)
smh.epochs = train_config["train"]["epochs"]
smh.batch_size = train_config['train']['batch_size']
smh.learning_rate = train_config['train']['learning_rate']
smh.momentum = train_config['train']['momentum']

In [None]:
# --------------------------- #
#   Train the Shadow models   #
# --------------------------- #
# Every argument comes from the setup cell above. The optional keyword
# arguments `verbose`, `incremental` and `shuffle_shift` were previously
# toggled here and are intentionally not passed.
shadow_model_indices = smh.create_shadow_models(
    num_models=num_shadow_models,
    shadow_population=attack_data_indices,
    training_fraction=training_data_fraction,
    online=online,
)

In [None]:
# ---------------------- #
#   Load Shadow Models   #
# ---------------------- #
# get_shadow_models is unpacked as a pair: the first element (the models) is
# kept and the second is discarded.
shadow_models, _ = smh.get_shadow_models(shadow_model_indices)

In [None]:
# Audit dataset as exposed by the attack; its "data" entry is what the rest
# of the notebook uses as the audit indices.
audit_dataset = attack.audit_dataset
audit_data_indices = audit_dataset["data"]

# Fetch the labels for exactly those audit entries
true_labels = handler.get_labels(audit_data_indices)
print(f"\nTrue labels fetched: {true_labels[:10]}")

In [None]:
# -------------------------------- #
#   Extract Shadow model Signals   #
# -------------------------------- #

# Load the stored logits of every trained shadow model, in index order
shadow_models_logits = [
    smh.load_logits(indx=sm_index) for sm_index in shadow_model_indices
]

# Rescale each model's logits against the true labels, stack them and
# transpose the result (the original author notes that rescale_logits flips them).
rescaled_sm_logits = np.transpose(
    np.array(
        [attack.rescale_logits(sm_logits, true_labels) for sm_logits in shadow_models_logits]
    )
)

# In-indices mask of the shadow models over the audit data.
# NOTE(review): a new ShadowModelHandler(handler) is constructed here instead
# of reusing `smh` — confirm whether that is intentional.
in_indices_masks = ShadowModelHandler(handler).get_in_indices_mask(
    shadow_model_indices, audit_dataset["data"]
)

print("\n--------- Shadow models logits extracted ---------")
print(f"sm in mask:\n {in_indices_masks[:10]}")
print(f"Shadow model logits:\n {rescaled_sm_logits[0]}")

In [None]:
# -------------------------------- #
#   Extract Target model Signals   #
# -------------------------------- #
# Logits of the target model, rescaled the same way as the shadow-model logits
target_logits = smh.load_logits(name="target")
rescaled_target_logits = attack.rescale_logits(target_logits, true_labels)

# In mask: the target's audit indices are its train indices followed by its
# test indices, both taken from the pickled target metadata.
train_indices = metadata_pkl.train_indices
test_indices = metadata_pkl.test_indices
target_audit_data_indices = np.concatenate([train_indices, test_indices])

# The mask is only valid when the audit dataset lists exactly these indices in
# the same order. Fail loudly on a mismatch: continuing would either hit a
# NameError on `target_in_mask` below or silently reuse a stale mask left in
# the kernel by an earlier run.
if not np.array_equal(audit_dataset["data"], target_audit_data_indices):
    raise ValueError(
        "audit_dataset does not match target_audit_data_indices "
        f"({len(audit_dataset['data'])} audit indices vs "
        f"{len(target_audit_data_indices)} target train+test indices)"
    )

# True where the audit index was part of the target's training set
target_in_mask = np.isin(target_audit_data_indices, train_indices)

print("\n--------- target model logits & in mask extracted ---------")
print(f"target in mask:\n {target_in_mask[:10]}")
print(f"target logits:\n {rescaled_target_logits[:10]}")
print(f"target logits shape: {rescaled_target_logits.shape}")

In [None]:
# ------------------------------ #
#   Save logits & indices mask   #
# ------------------------------ #
# Stored under its own name: `metadata` already holds the target's
# metadata.json, which the config cell reads, so rebinding it here would break
# a re-run of that cell.
audit_metadata = buildAuditMetadata(train_config, audit_config)

# Persist the rescaled target/shadow logits, both in-masks and the audit
# indices together with the audit metadata. saveAudit is unpacked as a pair
# (hash_id, save_dir).
hash_id, save_dir = saveAudit(
    audit_metadata,
    rescaled_target_logits,
    rescaled_sm_logits,
    in_indices_masks,
    target_in_mask,
    audit_data_indices,
)

In [None]:
# ------------------------------------------------------------------------------------------------- #
#   STANDALONE Shadow Model, Metadata  and Dataset Loader used for calculating Logits and in_mask   #
# ------------------------------------------------------------------------------------------------- #

# ---------------------------- #
#   Load Dataset from Pickle   #
# ---------------------------- #
# TODO: read the dataset name from the target metadata so this can be automated
dataset_name = "cifar10"
print(f"Dataset used by target model: {dataset_name}")

# When False, `dataset` must already exist in the kernel from an earlier run.
reload_dataset = True
if reload_dataset:
    data_path = "data"
    dataset_pkl_path = os.path.join(data_path, dataset_name + ".pkl")
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(dataset_pkl_path, "rb") as f:
        dataset = pickle.load(f)

    # Wrap dataset if needed: anything that is not already a UserDataset is
    # unpacked as a (data, targets) pair and wrapped.
    if not isinstance(dataset, CifarInputHandler.UserDataset):
        data_tensor, target_tensor = dataset
        dataset = CifarInputHandler.UserDataset(data_tensor, target_tensor)

# ------------------------------------------------------- #
#   Load Shadow Models and Calculate Logits and in_mask   #
# ------------------------------------------------------- #
# Path to the raw shadow models
shadow_models_path = "shadow_models/attack_objects/shadow_model"

# Path to processed shadow models
sm_path = os.path.join("processed_shadow_models", target_folder)
os.makedirs(sm_path, exist_ok=True)

reload_shadow_models = True
if reload_shadow_models:
    # Shadow models are stored as shadow_model_<i>.pkl / metadata_<i>.pkl with
    # consecutive indices starting at 0; stop at the first missing pair.
    i = 0
    while True:
        # Loop-specific names: the notebook-level `metadata_pkl` (target
        # metadata object) and `metadata` (target metadata JSON) are read by
        # other cells and must not be overwritten or deleted here.
        sm_weights_path = os.path.join(shadow_models_path, f"shadow_model_{i}.pkl")
        sm_metadata_path = os.path.join(shadow_models_path, f"metadata_{i}.pkl")

        if not (os.path.exists(sm_weights_path) and os.path.exists(sm_metadata_path)):
            break  # stop when no more models

        print(f"Loading shadow model {i}")

        # Load model weights onto the CPU.
        # NOTE(review): torch.load unpickles the file; consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(sm_weights_path, map_location="cpu")

        # AVAILABLE PARAM CHECKS: model_class, online, init_params["num_classes"]
        # Reinstantiate model
        num_classes = 10  # adapt if CIFAR-100
        model = ResNet18(num_classes=num_classes)
        model.load_state_dict(state_dict)
        model.eval()

        # Load this shadow model's metadata
        with open(sm_metadata_path, "rb") as f:
            sm_metadata = pickle.load(f)

        # Results are written under sm_path by the helper, keyed by idx
        calculate_logits_and_inmask(dataset, model, sm_metadata, sm_path, idx=i)

        # Clean up before loading the next model
        del sm_metadata
        del model
        torch.cuda.empty_cache()

        i += 1

print("\nAll shadow model logits computed and saved.")