In [None]:
import os
import yaml
import optuna
import json
import numpy as np
import pickle
import torch

from src.train_models import create_shadow_models_parallel
from src.utils import print_yaml, calculate_logits_and_inmask
from src.models.resnet18_model import ResNet18
from src.dataset_handler import processDataset, loadDataset

In [None]:
# -------------------------------- #
#  Load target model and metadata  #
# -------------------------------- #
# Every artifact of the target model lives under target/<target_folder>/.
target_folder = "resnet-cinic10-c51d329813"
target_path = os.path.join("target", target_folder)
metadata_pkl_path = os.path.join(target_path, "model_metadata.pkl")  # pickled metadata
metadata_path = os.path.join(target_path, "metadata.json")           # JSON metadata

# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(metadata_pkl_path, "rb") as f:
    metadata_pkl = pickle.load(f)

# The JSON metadata's ["train"] section is read by the train-config cell below.
with open(metadata_path, "r") as f:
    metadata = json.load(f)

print("All metadata loaded")

In [None]:
# ------------------- #
#   Load Train yaml   #
# ------------------- #
# Training hyper-parameters copied verbatim from the target model's metadata, so the
# config passed to the shadow-model trainer uses the target's settings.
TARGET_TRAIN_KEYS = (
    "epochs",
    "batch_size",
    "learning_rate",
    "momentum",
    "t_max",
    "drop_rate",
    "model",
    "optimizer",
)

with open("./train.yaml", encoding="utf-8") as file:
    train_config = yaml.safe_load(file)

# An empty file loads as None; fail with a clear message instead of a TypeError below.
if not isinstance(train_config, dict) or not isinstance(train_config.get("train"), dict):
    raise ValueError("./train.yaml must contain a 'train' mapping")

print("\n-------------- Train config --------------")
print_yaml(train_config)

# Update the train config with target metadata (learning rate, batch size, etc).
# A key missing from the target metadata raises KeyError, as before.
for key in TARGET_TRAIN_KEYS:
    train_config['train'][key] = metadata['train'][key]

print("\n-------------- Updated train config --------------")
print_yaml(train_config)

In [None]:
#-------------------#
#  Prepare dataset  #
#-------------------#
data_cfg = train_config['data']

dataset_name = data_cfg["dataset"]
# The data directory may be configured under either "root" or "data_dir".
root = data_cfg.get("root", data_cfg.get("data_dir"))
if root is None:
    # Previously this surfaced as an opaque TypeError from os.path.join(None, ...).
    raise KeyError("train.yaml 'data' section must define 'root' or 'data_dir'")
dataset_path = os.path.join(root, dataset_name + ".pkl")
print(f"dataset path: {dataset_path}")

trainset, testset, full_dataset = loadDataset(data_cfg)

# Will split the dataset to use the same indices as the baseline target model
train_dataset, test_dataset, train_indices, test_indices = processDataset(data_cfg, trainset, testset, dataset=full_dataset)

# Rebind full_dataset to the dataset underlying the train split; this is the object
# handed to the shadow-model trainer in the next cell.
full_dataset = train_dataset.dataset
labels = full_dataset.targets
print(f"Length of dataset targets/labels: {len(labels)}")
print(f"First 10 targets/labels: {labels[:10]}")

In [None]:
# --------------------------- #
#   Train the Shadow models   #
# --------------------------- #
# GPU indices 0..6 are handed to the parallel trainer.
gpu_ids = list(range(7))
# NOTE(review): passed positionally to create_shadow_models_parallel; presumably the
# number of shadow models to train — confirm against its signature.
num_shadow_models = 256
results = create_shadow_models_parallel(
    train_config,
    num_shadow_models,
    gpu_ids,
    full_dataset,
    target_folder,
)

In [None]:
# ------------------------------------------------------------------------------------------------- #
#   STANDALONE Shadow Model, Metadata  and Dataset Loader used for calculating Logits and in_mask   #
# ------------------------------------------------------------------------------------------------- #
# NOTE(review): not fully standalone. `target_folder` (first cell) must already be defined,
# and `CifarInputHandler` is used below but is not imported anywhere in this notebook.

# ---------------------------- #
#   Load Dataset from Pickle   #
# ---------------------------- #
# TODO: read the dataset name from the target metadata to automate this.
# NOTE(review): hard-coded "cifar10" while target_folder names "cinic10"; confirm this is
# the dataset the shadow models were trained on.
dataset_name = "cifar10"
print(f"Dataset used by target model: {dataset_name}")

# When False, `dataset` must already exist in the kernel from a previous run.
reload_dataset = True
if reload_dataset:
    data_path = "data"
    dataset_pkl_path = os.path.join(data_path, dataset_name + ".pkl")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(dataset_pkl_path, "rb") as f:
        dataset = pickle.load(f)

    # Wrap a two-element (data, targets) pickle in the handler's dataset type if needed.
    if not isinstance(dataset, CifarInputHandler.UserDataset):
        data_tensor, target_tensor = dataset
        dataset = CifarInputHandler.UserDataset(data_tensor, target_tensor)

# ------------------------------------------------------- #
#   Load Shadow Models and Calculate Logits and in_mask   #
# ------------------------------------------------------- #
# Path to the raw shadow models
shadow_models_path = "shadow_models/attack_objects/shadow_model"

# Path to processed shadow models
sm_path = os.path.join("processed_shadow_models", target_folder)
os.makedirs(sm_path, exist_ok=True)

reload_shadow_models = True
if reload_shadow_models:
    # NOTE(review): hard-coded; adapt if CIFAR-100. The shadow metadata reportedly exposes
    # model_class, online and init_params["num_classes"]; consider reading it from there.
    num_classes = 10

    # Shadow models are stored as consecutive shadow_model_{i}.pkl / metadata_{i}.pkl
    # pairs; iteration stops at the first index where either file is missing.
    shadow_idx = 0
    while True:
        shadow_model_path = os.path.join(shadow_models_path, f"shadow_model_{shadow_idx}.pkl")
        shadow_metadata_path = os.path.join(shadow_models_path, f"metadata_{shadow_idx}.pkl")
        if not (os.path.exists(shadow_model_path) and os.path.exists(shadow_metadata_path)):
            break  # stop when no more models

        print(f"Loading shadow model {shadow_idx}")

        # Reinstantiate the architecture and load the saved weights on CPU.
        state_dict = torch.load(shadow_model_path, map_location="cpu")
        shadow_model = ResNet18(num_classes=num_classes)
        shadow_model.load_state_dict(state_dict)
        shadow_model.eval()

        # Use distinct names: the previous version reused the global `metadata` /
        # `metadata_pkl` from the first cell, then `del metadata` destroyed the target
        # metadata and broke re-running the train-config cell.
        with open(shadow_metadata_path, "rb") as f:
            shadow_metadata = pickle.load(f)

        calculate_logits_and_inmask(dataset, shadow_model, shadow_metadata, sm_path, idx=shadow_idx)

        # Release this model before loading the next one.
        del shadow_metadata, shadow_model, state_dict
        torch.cuda.empty_cache()

        shadow_idx += 1

    if shadow_idx == 0:
        print(f"\nWARNING: no shadow models found in {shadow_models_path}; nothing was computed.")
    else:
        print(f"\nAll {shadow_idx} shadow model logits computed and saved.")
else:
    print("\nreload_shadow_models is False; shadow model logits were not recomputed.")