## Import libraries and dataset

In [None]:
import numpy as np
import importlib.util
import matplotlib.pyplot as plt
import pickle

from simpleinfotheory import entropy

# from dp_laplace_mechanism import laplace_mechanism
 
# Load project helper modules straight from their source files (they are not
# installed as packages). These are machine-specific absolute paths.
REPO_ROOT = "/home/sjay9734/diff_encoder/face_privacy/face-privacy-diffae"

latent_load_module = importlib.util.spec_from_file_location(
    "load_latent", f"{REPO_ROOT}/utils/load_latent.py")
latent_load = importlib.util.module_from_spec(latent_load_module)
latent_load_module.loader.exec_module(latent_load)

# Split the latent dataset; TRAINING_AMOUNT is presumably the training fraction
# (0.9 -> 90% train) — confirm against load_latent.pre_process_celebA.
X_train, y_train, X_test, y_test = latent_load.pre_process_celebA(TRAINING_AMOUNT = 0.9)

# Register this module under its own name. It previously reused "load_latent",
# so both loaded modules reported the same __name__.
laplace_mechanism_loader = importlib.util.spec_from_file_location(
    "dp_laplace_mechanism", f"{REPO_ROOT}/privacy_mechanisms/dp_laplace_mechanism.py")
laplace = importlib.util.module_from_spec(laplace_mechanism_loader)
laplace_mechanism_loader.loader.exec_module(laplace)

## Train inference model (MLP Regression)

In [None]:
# Training-set construction for the per-axis latent "denoising" regressors.

LATENT_SIZE = 512
# Exclusive upper bound on the integer *standard deviation* (not the variance)
# of the Gaussian noise added to augmented samples: scale is drawn from
# {0, 1, ..., MAX_VARIANCE - 1}.
MAX_VARIANCE = 7


def create_dataset(n_clean=60000, source=None, max_variance=None):
    """Build an augmented (input, target) set for learning to undo latent noise.

    Rows are drawn until ``n_clean`` unmodified rows have been emitted. On each
    draw, with probability 1/2 the next row of ``source`` (taken in order) is
    emitted as-is. Otherwise a uniformly random row of ``source`` is emitted
    with i.i.d. Gaussian noise added; the noise standard deviation is a random
    integer in ``[0, max_variance)``. Every row's regression target is the
    clean row it came from. The total number of rows is therefore random,
    about ``2 * n_clean`` on average.

    Args:
        n_clean: number of clean rows to emit; must not exceed ``len(source)``.
        source: 2-D array of clean latents, one per row. Defaults to the
            global ``X_train``.
        max_variance: exclusive upper bound on the integer noise std dev.
            Defaults to the global ``MAX_VARIANCE``.

    Returns:
        ``(augmented_train_X, augmented_train_y)``. ``augmented_train_X`` has
        shape ``(n_rows, d)``. ``augmented_train_y`` has shape ``(d, n_rows)``:
        row ``axis`` holds the targets for latent coordinate ``axis``.

    Raises:
        ValueError: if ``n_clean`` exceeds the number of rows in ``source``.
    """
    data = np.asarray(X_train if source is None else source)
    if max_variance is None:
        max_variance = MAX_VARIANCE
    n_rows, dim = data.shape
    if n_clean > n_rows:
        raise ValueError(f"n_clean={n_clean} exceeds the {n_rows} available rows")

    inputs = []
    targets = []
    i = 0
    while i < n_clean:
        if np.random.randint(2) == 1:
            # Clean sample: next row in order; its target is itself.
            clean = data[i]
            inputs.append(clean)
            i += 1
        else:
            # Noisy sample: random row plus Gaussian noise; target is the clean row.
            clean = data[np.random.randint(n_rows)]
            scale = np.random.randint(max_variance)
            inputs.append(clean + np.random.normal(loc=0.0, scale=scale, size=dim))
        targets.append(clean)

    # reshape keeps the (0, d) / (d, 0) shapes meaningful when n_clean == 0.
    augmented_train_X = np.array(inputs).reshape(-1, dim)
    # One row of targets per latent axis, matching how the training loop indexes it.
    augmented_train_y = np.ascontiguousarray(
        np.array(targets, dtype=float).reshape(-1, dim).T)

    return augmented_train_X, augmented_train_y

In [None]:
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
model_list = []
input_data, target_data_ = create_dataset()

print(np.shape(input_data), np.shape(target_data_))

# Fit one independent MLP per latent coordinate: model i maps a full
# (possibly noisy) latent vector to the clean value of coordinate i.
for i in range(LATENT_SIZE):
    # Bug fix: this previously read target_data_[1], so every model learnt axis 1.
    target_data = target_data_[i]

    model = MLPRegressor(hidden_layer_sizes=(64, 32), activation='relu', solver='adam', max_iter=10000)
    model.fit(input_data, target_data)

    model_list.append(model)
    print("Axis ", i)

### Save trained model

In [None]:
# for i,model_ in enumerate(model_list):
#     with open(f'mlp_models_new/mlp_regression_model_64_32_200000_{i}.pkl', 'wb') as file:
#         pickle.dump(model_, file)

### Load saved model

In [None]:
# Restore one pickled regressor per latent axis and make them the active
# `model_list`, replacing whatever models are currently held in it.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): confirm these pickles were trained per-axis (model i on coordinate i).
loaded_model_list = []
for axis_idx in range(LATENT_SIZE):
    pkl_path = f'mlp_models_new/mlp_regression_model_64_32_200000_{axis_idx}.pkl'
    with open(pkl_path, 'rb') as pkl_file:
        loaded_model_list.append(pickle.load(pkl_file))

model_list = loaded_model_list

## Inference: estimate the original latent from a perturbed one

In [None]:
def guess_original_latent(perturbed_latent, filtered_list=None, models=None):
    """Estimate the clean latent behind a perturbed one using per-axis regressors.

    For each selected axis ``i``, ``models[i]`` predicts coordinate ``i`` from
    the whole perturbed vector. All other axes keep their perturbed value. The
    result is the element-wise mean of that estimate and ``perturbed_latent``,
    so unselected axes come back unchanged.

    Args:
        perturbed_latent: 1-D sequence of numbers.
        filtered_list: axes to re-estimate; ``None`` or empty means all axes.
        models: per-axis regressors exposing ``predict``; defaults to the
            global ``model_list``.

    Returns:
        A new float ndarray the same length as ``perturbed_latent``. The input
        is not modified.
    """
    if models is None:
        models = model_list
    perturbed = np.asarray(perturbed_latent, dtype=float)
    guessed_latent = perturbed.copy()
    # predict() takes a 2-D (n_samples, n_features) batch; build it once.
    features = perturbed.reshape(1, -1)

    if filtered_list is not None and len(filtered_list) > 0:
        axes = filtered_list
    else:
        axes = range(len(perturbed))
    for i in axes:
        # predict() returns a 1-element array; extract the scalar explicitly.
        guessed_latent[i] = np.ravel(models[i].predict(features))[0]

    return (guessed_latent + perturbed) * 0.5

## Differentially private latent perturbation (Laplace mechanism)

In [None]:
# Per-axis parameters for the Laplace mechanism, computed from the training
# latents (one entry per latent coordinate). dp_latent reads three values from
# each entry — presumably (sensitivity, min, max); confirm against
# laplace.local_sensitivity_and_min_max.
sensitivity_arr = [
    laplace.local_sensitivity_and_min_max(X_train[:, axis], 0.1)
    for axis in range(LATENT_SIZE)
]

laplace_construct = laplace.laplace_mechanism(1)

In [None]:
# DP with laplace

def dp_latent(latent, eps, mechanism=None, sensitivities=None):
    """Perturb each coordinate of ``latent`` independently with the Laplace mechanism.

    Coordinate ``i`` becomes ``mechanism.gen_random_output(latent[i], eps, s[0],
    s[1], s[2])``, where ``s = sensitivities[i]``.

    NOTE(review): ``eps`` is spent once per coordinate. By sequential
    composition, the whole vector is then presumably ``len(latent) * eps``-DP.
    Confirm this is the intended budget.

    Args:
        latent: 1-D sequence of clean latent values.
        eps: privacy parameter passed to the mechanism for every coordinate.
        mechanism: object with ``gen_random_output``; defaults to the global
            ``laplace_construct``.
        sensitivities: per-axis 3-element entries (presumably sensitivity,
            min, max); defaults to the global ``sensitivity_arr``.

    Returns:
        Float ndarray with one perturbed value per coordinate of ``latent``.
    """
    if mechanism is None:
        mechanism = laplace_construct
    if sensitivities is None:
        sensitivities = sensitivity_arr

    # Sized from the input rather than a hard-coded 512.
    perturbed_latent = np.zeros(len(latent))
    for i, value in enumerate(latent):
        params = sensitivities[i]
        perturbed_latent[i] = mechanism.gen_random_output(value, eps,
                                                          params[0],
                                                          params[1],
                                                          params[2])

    return perturbed_latent

### DP test

In [None]:
EPS = 10  # earlier observation: EPS = 0.1 gave mse > 12

# Perturb held-out latents (create_dataset trains only on X_train) and check
# how much closer the regressor-based guess gets to the true latent.
for i in range(10):
    latent_original = X_test[i]
    perturbed_latent = dp_latent(latent_original, EPS)
    guess_latent = guess_original_latent(perturbed_latent)

    # Bug fix: the latent being perturbed (X_train[i]) was compared against an
    # unrelated sample (X_test[i]). Both now refer to the same latent.
    perturbed_l2 = np.linalg.norm(perturbed_latent - latent_original)
    guessed_l2 = np.linalg.norm(guess_latent - latent_original)

    print("perturbed_l2 ", perturbed_l2, " guessed_l2 ", guessed_l2)


## Load diffusion autoencoder and attribute classifier

In [None]:
from templates import *
from templates_cls import *
from experiment_classifier import ClsModel

In [None]:
# Build the diffusion autoencoder and the classifier, restore their weights from
# local checkpoints (loaded onto CPU first), then move them to `device`.
# strict=False: missing/unexpected state-dict keys are ignored instead of raising.
device = 'cuda:1'
ckpt_root = '/home/sjay9734/diff_encoder/face_privacy/face-privacy-diffae/checkpoints'

conf = ffhq256_autoenc()
model = LitModel(conf)
state = torch.load(f'{ckpt_root}/{conf.name}/last.ckpt', map_location='cpu')
model.load_state_dict(state['state_dict'], strict=False)
model.ema_model.eval()
model.ema_model.to(device)

cls_conf = ffhq256_autoenc_cls()
cls_model = ClsModel(cls_conf)
state = torch.load(f'{ckpt_root}/{cls_conf.name}/last.ckpt', map_location='cpu')
print('latent step:', state['global_step'])
cls_model.load_state_dict(state['state_dict'], strict=False)
cls_model.to(device);

### Load images

In [None]:
# Folder of face images to encode (the original comment names celebA_hq/image
# as an alternative source). Augmentation is off so items are deterministic.
IMAGE_PATH = "imgs_align"

data = ImageDataset(
    IMAGE_PATH,
    image_size=conf.img_size,
    exts=['jpg', 'JPG', 'png'],
    do_augment=False,
)  # celebA_hq/image


### Encode and display a test image

In [None]:
import os

img_index = 0

# Add a leading batch dimension to the single dataset image.
batch = data[img_index]['img'][None]
batch_device = batch.to(device)
# Reuse the same batch for both encoders instead of re-reading the dataset item.
cond = model.encode(batch_device)
# T=250 is presumably the number of inversion steps — confirm in LitModel.
xT = model.encode_stochastic(batch_device, cond, T=250)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Map to [0, 1] for display; assumes the dataset yields images in [-1, 1] — confirm.
ori = (batch + 1) / 2
ax[0].imshow(ori[0].permute(1, 2, 0).cpu())
# NOTE(review): xT is not rescaled, so imshow may clip values outside [0, 1].
ax[1].imshow(xT[0].permute(1, 2, 0).cpu())

# plt.imsave does not create missing parent directories.
os.makedirs("perturbed_imgs", exist_ok=True)
plt.imsave("perturbed_imgs/1.png", ori[0].permute(1, 2, 0).cpu().numpy())