In [None]:
import os
import yaml
import optuna
import json
import torch
import numpy as np
import pickle

from src.save_load import buildTargetMetadata, saveTarget, saveTargetSignals
from src.train_models import trainTargetModel, trainFbDTargetModel
from src.visualize_model import VisualizeModel
from src.utils import print_yaml, calculate_logits_and_inmask, rescale_logits
from src.models.resnet18_model import ResNet18
from src.models.wideresnet28_model import WideResNet
from src.cifar_handler import CifarInputHandler
from src.dataset_handler import processDataset, loadDataset, get_dataloaders

from LeakPro.leakpro.attacks.mia_attacks.rmia import rmia_get_gtlprobs

from torch.utils.data import DataLoader

from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

In [None]:
# ---------------------- #
#   Load training yaml   #
# ---------------------- #
# Parse ./train.yaml and keep handles on the two sections that the following
# cells read: the training settings and the data settings.
config = None
with open("./train.yaml") as cfg_stream:
    config = yaml.safe_load(cfg_stream)

train_cfg, data_cfg = config["train"], config["data"]

print(f"Initial training config: {train_cfg}")

In [None]:
# ------------------------------ #
#  Load study + best trial info  #
# ------------------------------ #
study_folder = "cinic10-resnet-baseline-6d2602e58b"
study_path = os.path.join("study", study_folder)

# Metadata
# Read first: the journal branch below takes the study name from this file.
metadata_path = os.path.join(study_path, "metadata.json")
with open(metadata_path, "r") as f:
    study_metadata = json.load(f)

# Pick the storage backend and the study name that goes with it.
journal_storage = False
if journal_storage:
    # Journal + storage: the study lives in a journal file inside the study folder
    journal_path = os.path.join(study_path, "journal.log")
    storage = JournalStorage(JournalFileBackend(journal_path))
    study_name = study_metadata["study"]["study_name"]
else:
    # SQLite storage: the study is looked up by its folder name in the shared database
    db_path = os.path.join("study", "baseline_study.db")
    storage = f"sqlite:///{db_path}"
    study_name = study_folder

# Load study
study = optuna.load_study(study_name=study_name, storage=storage)
print(f"Quantity of trials: {len(study.trials)}")

# Best trial
best_trial = study.best_trial
print(f"Best trial value: {best_trial.value}")
print("Best trial parameters:")
for k, v in best_trial.params.items():
    print(f"  {k}: {v}")

# ------------------------------ #
#  Update training configuration #
# ------------------------------ #

# Overwrite study-optimized params with the best trial's values.
# "T_max" and "drop_rate" become None, and "optimizer" becomes "SGD", when the
# best trial has no entry for them.
best_params = best_trial.params
train_cfg.update(
    {
        "batch_size": best_params["batch_size"],
        "learning_rate": best_params["lr"],
        "momentum": best_params["momentum"],
        "weight_decay": best_params["weight_decay"],
        "t_max": best_params.get("T_max"),
        "model": study_metadata["study"]["model"],  # Model architecture from metadata
        "drop_rate": best_params.get("drop_rate"),
        "optimizer": best_params.get("optimizer", "SGD"),
    }
)
config["train"] = train_cfg

print("Modified train_cfg with optimal hyperparameters")

# ------------------- #
#  Save the metadata  #
# ------------------- #
# Build the target metadata from the updated training config and the data
# config, then save it; the returned id and directory are reported here.
train_metadata = buildTargetMetadata(train_cfg, data_cfg)
hash_id, save_dir = saveTarget(train_metadata)
print(f"Saved training metadata with {hash_id} at {save_dir}")

# Show the full configuration after the overrides
print("-------------------- New config --------------------")
print_yaml(config)

In [None]:
# ------------------- #
#   Prepare dataset   #
# ------------------- #
print("-------------------- Data_cfg --------------------")
print_yaml(data_cfg)

# Load the dataset, then hand it to processDataset, which returns the
# train/test datasets together with the train/test indices.
trainset, testset, full_dataset = loadDataset(data_cfg)
(
    train_dataset,
    test_dataset,
    train_indices,
    test_indices,
) = processDataset(data_cfg, trainset, testset, dataset=full_dataset)

# Prepare loaders with the batch size currently stored in train_cfg
batch_size = train_cfg["batch_size"]
train_loader, test_loader = get_dataloaders(batch_size, train_dataset, test_dataset)

In [None]:
# ------------------------------- #
#   Train Baseline Target model   #
# ------------------------------- #
# Train the target model with the current config; results are written under save_dir.
train_result, test_result = trainTargetModel(config, train_loader, test_loader, train_indices, test_indices, save_dir)

# Visualize only when both results are present. Identity comparison is used
# because `!= None` dispatches to __ne__, which result objects may override.
if train_result is not None and test_result is not None:
    VisualizeModel().visualize(train_result, test_result)

In [None]:
# -------------------------------- #
#   Load saved model and dataset   #
# -------------------------------- #

# Load the target
reload_target = True
if reload_target:

    target = "resnet-cinic10-c51d329813"
    target_path = os.path.join("target", target)

    # Load metadata.json stored next to the saved target
    metadata_path = os.path.join(target_path, "metadata.json")
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    # Load the saved weights onto the CPU; the model is moved to the selected device below.
    # NOTE(review): torch.load and pickle.load can execute arbitrary code while
    # deserializing - only point this cell at trusted target folders.
    target_model_pkl_path = os.path.join(target_path, "target_model.pkl")
    state_dict = torch.load(target_model_pkl_path, map_location="cpu")

    # cifar10 and cinic10 use 10 classes; every other dataset name falls back to 100 (cifar100)
    if metadata["data"]["dataset"] in ("cifar10", "cinic10"):
        num_classes = 10
    else:
        num_classes = 100

    # Recreate the architecture recorded in the metadata
    model_name = metadata["train"]["model"]
    if model_name == "resnet":
        model = ResNet18(num_classes=num_classes)
    elif model_name == "wideresnet":
        model = WideResNet(depth=28, num_classes=num_classes, widen_factor=10)
    else:
        # Fail here with a clear message instead of a NameError on `model` below
        raise ValueError(f"Unsupported model architecture in metadata: {model_name!r}")

    model.load_state_dict(state_dict)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the pickled model metadata (provides train_indices / test_indices used later)
    metadata_model_pkl_path = os.path.join(target_path, "model_metadata.pkl")
    with open(metadata_model_pkl_path, "rb") as f:
        metadata_pkl = pickle.load(f)

# ---------------------------- #
#   Load Dataset from Pickle   #
# ---------------------------- #
reload_dataset = True
if reload_dataset:
    # The pickle file is named after the dataset recorded in the target metadata
    data_path = "data"
    dataset_name = metadata["data"]["dataset"] + ".pkl"
    dataset_pkl_path = os.path.join(data_path, dataset_name)

    # NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
    with open(dataset_pkl_path, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)

    # Anything that is not already a UserDataset is unpacked as a
    # (data, targets) pair and wrapped in one.
    already_wrapped = isinstance(dataset, CifarInputHandler.UserDataset)
    if not already_wrapped:
        data_tensor, target_tensor = dataset
        dataset = CifarInputHandler.UserDataset(data_tensor, target_tensor)

    labels = dataset.targets

# Train/test indices come from the pickled model metadata, not from the dataset file
train_indices, test_indices = metadata_pkl.train_indices, metadata_pkl.test_indices
print(f"First 10 labels: {labels[:10]}, length: {len(labels)}")
print(f"Train_indices length: {len(train_indices)}, test_indices length: {len(test_indices)}")


In [None]:
# ----------------------------------------- #
#   Calculate and save Logits and in_mask   #
# ----------------------------------------- #
# save=False: the helper is asked not to save; saving happens in saveTargetSignals below.
logits, in_mask = calculate_logits_and_inmask(dataset, model, metadata_pkl, target_path, idx=None, save=False)

# The label-dependent signals default to None so the save call below never hits
# an unbound name when labels are unavailable.
# NOTE(review): assumes saveTargetSignals accepts None for these two - confirm.
resc_logits = None
gtl_probs = None
if labels is not None:
    resc_logits = rescale_logits(logits, labels)
    gtl_probs = rmia_get_gtlprobs(logits, labels)

saveTargetSignals(logits, in_mask, target_path, resc_logits, gtl_probs)