In [None]:
from models.generator import build_g

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import pickle

import gc
import os

import tensorflow as tf


from tensorflow import keras

from tensorflow.keras import optimizers, losses

from tensorflow.image import grayscale_to_rgb


from sklearn.metrics import recall_score

from sklearn.metrics import accuracy_score


from sklearn.utils import shuffle

from time import sleep

from IPython import display


# Conditional-generator settings.
# NOTE(review): `initial_shape` and `embed_label_shape` are not referenced in the
# visible cells; presumably they mirror the architecture built by
# `models.generator.build_g` -- confirm.
z_dim, num_classes = 128, 2  # latent-vector length; classes 0 = Normal, 1 = Pneumonia
initial_shape, embed_label_shape = (8, 8, 1024), (8, 8, 1)

# Name of the pretrained backbone layer inside each saved classifier
# (`get_model` freezes the layer with this name). The order matches the
# display labels used in the result plots.
models_name = [
    "vgg16",
    "resnet50v2",
    "mobilenetv2_1.00_128",
    "inception_v3",
    "convnext_tiny",
]

In [None]:
def get_model(pretrained_name, path_to_file, learning_rate=1e-4):
    """Load a saved classifier, freeze its pretrained backbone and recompile it.

    Args:
        pretrained_name: Name of the backbone layer inside the saved model
            (an entry of ``models_name``); that layer is made non-trainable.
        path_to_file: Path handed to ``keras.models.load_model``.
        learning_rate: Adam learning rate used for the recompile. Defaults to
            ``1e-4``.

    Returns:
        The loaded Keras model with ``pretrained_name`` frozen, compiled with
        Adam, binary cross-entropy loss and an ``"accuracy"`` metric.
    """
    model = keras.models.load_model(path_to_file)

    # Changing `trainable` only takes effect for training once the model is
    # compiled again, hence the recompile right after freezing.
    model.get_layer(pretrained_name).trainable = False
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss=losses.BinaryCrossentropy(),
        metrics=["accuracy"],
    )

    return model

In [None]:
def get_testset_perf(y_pred_cate, y_true):
    """Score binary (0/1) predictions against the ground-truth labels.

    Args:
        y_pred_cate: Predicted class labels.
        y_true: True class labels.

    Returns:
        Tuple ``(accuracy, sensitivity, specificity)``. Sensitivity is the
        recall of class 1 (sklearn's default ``pos_label``); specificity is the
        recall of class 0.
    """
    return (
        accuracy_score(y_true, y_pred_cate),
        recall_score(y_true, y_pred_cate),  # recall of label 1
        recall_score(y_true, y_pred_cate, pos_label=0),  # recall of label 0
    )

In [None]:
# Map each backbone name in `models_name` to the saved classifier that contains it.
# os.listdir returns entries in arbitrary order (and may include hidden entries such
# as `.ipynb_checkpoints`), so pairing files with `models_name` by position can attach
# the wrong name to a model. Instead, each saved model is opened once and identified
# by the backbone layer it contains -- the same layer `get_model` looks up and freezes.
model_dict = {}
model_folder = ""
saved_model = sorted(
    entry for entry in os.listdir(model_folder) if not entry.startswith(".")
)

found = {}
for file_name in saved_model:
    path = os.path.join(model_folder, file_name)
    candidate = keras.models.load_model(path)
    layer_names = {layer.name for layer in candidate.layers}
    matches = [name for name in models_name if name in layer_names]

    # Release the probe model before loading the next one.
    del candidate
    gc.collect()
    keras.backend.clear_session()

    if len(matches) != 1:
        raise ValueError(
            f"{path!r}: expected exactly one backbone from {models_name}, found {matches}"
        )
    if matches[0] in found:
        raise ValueError(
            f"{path!r} and {found[matches[0]]!r} both contain backbone {matches[0]!r}"
        )
    found[matches[0]] = path

# Insert in `models_name` order so later result dicts / DataFrame columns follow it.
for name in models_name:
    if name in found:
        model_dict[name] = found[name]

In [None]:
# Verify that `get_model` froze the pretrained backbone: print every top-level
# layer's name next to its `trainable` flag, one model at a time.
for name, path in model_dict.items():
    print(f"model:  {name}")
    model = get_model(name, path)
    for layer in model.layers:
        print(f"{layer.name} {layer.trainable}")
    print("-" * 20)

    # Drop the model and reset Keras' global state before loading the next one.
    del model
    gc.collect()
    keras.backend.clear_session()

In [None]:
# Show each model's summary in turn: clear the previous output, print the
# summary, then pause 7 seconds so it can be read before the next one replaces it.
for name, path in model_dict.items():
    display.clear_output(wait=True)

    model = get_model(name, path)
    model.summary()
    sleep(7)

    # Release this model (and Keras' global state) before loading the next.
    del model
    gc.collect()
    keras.backend.clear_session()

In [None]:
# Find the generator checkpoint with the lowest mean FID.
# NOTE(review): pickle.load can execute arbitrary code -- only load pickles
# produced by this project.
save_objects_folder = ""
file_name = os.path.join(save_objects_folder, "fid_dcgan.pickle")
with open(file_name, "rb") as f:
    fid_array = np.array(pickle.load(f))

# Average over axis 1 -> one mean FID per row; presumably one row per evaluated
# checkpoint and one column per repeated FID measurement (TODO confirm).
mean_fid = np.mean(fid_array, axis=1)
# Assumed evaluation schedule: every 5 epochs from epoch 5 to 200.
epochs = list(range(5, 201, 5))
# `epochs` is indexed by FID row below, so the two must line up; otherwise the
# reported "best epoch" would be wrong (or raise IndexError).
if len(mean_fid) != len(epochs):
    raise ValueError(
        f"expected {len(epochs)} FID rows for epochs 5..200 (step 5), got {len(mean_fid)}"
    )
print("epoch of min:", epochs[np.argmin(mean_fid)])
print("min FID:", np.min(mean_fid))

In [None]:
# Load the trained conditional generator and plot a grid of samples, each titled
# with the class label it was conditioned on.
weight_folder = ""
weight_file = ""

weight = os.path.join(weight_folder, weight_file)

# Build the generator architecture and restore its trained weights.
G = build_g()
G.load_weights(weight)

NUM_SAMPLES = 16

inputs = tf.random.normal(shape=(NUM_SAMPLES, z_dim))
# One label per sample drawn from {0, 1} (maxval is exclusive):
# 0 = Normal, 1 = Pneumonia.
labels = tf.random.uniform((NUM_SAMPLES, 1), 0, 2, dtype=tf.int32)

predictions = G.predict(x=[inputs, labels], verbose=0)
# Rescale to pixel range; presumably G outputs tanh values in [-1, 1], mapped
# here to [0, 255] (TODO confirm against models.generator).
predictions = predictions * 127.5 + 127.5

# Smallest near-square grid that holds every sample (4 x 4 for 16 samples),
# so changing NUM_SAMPLES does not overflow the subplot layout.
n_cols = int(np.ceil(np.sqrt(NUM_SAMPLES)))
n_rows = int(np.ceil(NUM_SAMPLES / n_cols))
label_values = labels.numpy()[:, 0]

fig = plt.figure(figsize=(10, 10))

for i in range(predictions.shape[0]):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(predictions[i, :, :, 0], cmap="gray")
    plt.title("Normal" if label_values[i] == 0 else "Pneumonia")
    plt.axis("off")

plt.savefig("filepath.svg", format="svg", dpi=1200)
plt.show()

# `G` is intentionally kept alive: the poisoning experiment further down calls
# `G(...)` to synthesize poison images, so deleting it here would make that
# cell fail with NameError on a top-to-bottom run.

In [None]:
X_test = np.load("")  # TODO: path to the test-image array
y_test = np.load("")  # TODO: path to the matching label array
# Images and labels must align one-to-one for `get_testset_perf`.
assert len(X_test) == len(y_test), (len(X_test), len(y_test))

In [None]:
# Label-flipping poisoning experiment.
# For every saved classifier, run 3 independent rounds. Each round reloads the
# model via `get_model` (backbone frozen) and fine-tunes it one batch at a time
# on GAN-generated images whose labels are deliberately swapped: generated
# "normal" images are labelled 1 and generated "pneumonia" images are labelled 0.
# After every batch, test-set accuracy / sensitivity / specificity are recorded,
# until `maximum_poisons` poison samples have been injected.
# NOTE(review): requires the conditional generator `G` (built and loaded in the
# sampling cell) to still exist when this cell runs.
num_poisons = 32  # generated images per class per step -> 2 * 32 = 64 per step
decreased_perf_dict = {}  # model name -> {"accuracy"/"sensitivity"/"specificity": per-round lists}
maximum_poisons = 14000  # a round stops once this many poison samples were used

for name, path in model_dict.items():
    decreased_perf_dict[name] = {}
    # One inner list of per-step scores per round.
    ACCURACY = []
    SENSITIVITY = []
    SPECIFICITY = []
    for i in range(3):
        display.clear_output(wait=True)
        print(f"get model {name}, round = {i + 1}")
        # Reloaded from disk every round, so each round starts from the saved weights.
        victim = get_model(name, path)
        count_num_poisons = 0
        accuracy = []
        sensitivity = []
        specificity = []
        while count_num_poisons < maximum_poisons:
            # Generate `num_poisons` images conditioned on label 0 (normal lung).
            inputs = tf.random.normal((num_poisons, z_dim))
            normal_labels = tf.zeros((num_poisons, 1), dtype=tf.int32)
            normal_lung_images = G([inputs, normal_labels], training=False)
            # Rescale to [0, 255]; presumably G outputs values in [-1, 1] (TODO confirm).
            normal_lung_images = normal_lung_images * 127.5 + 127.5
            # Grayscale -> 3 channels; presumably the classifiers expect RGB input.
            normal_lung_images = grayscale_to_rgb(normal_lung_images)

            # Generate `num_poisons` images conditioned on label 1 (pneumonia).
            inputs = tf.random.normal((num_poisons, z_dim))
            penumonia_labels = tf.ones((num_poisons, 1), dtype=tf.int32)
            penumonia_lung_images = G([inputs, penumonia_labels], training=False)
            penumonia_lung_images = penumonia_lung_images * 127.5 + 127.5
            penumonia_lung_images = grayscale_to_rgb(penumonia_lung_images)

            # Build the poisoned batch with FLIPPED labels:
            # normal images -> 1, pneumonia images -> 0.
            xp = tf.concat([normal_lung_images, penumonia_lung_images], axis=0).numpy()
            yp = tf.concat(
                [tf.ones_like(normal_labels), tf.zeros_like(penumonia_labels)], axis=0
            ).numpy()

            # Mix the two classes within the batch (sklearn's shuffle keeps xp/yp aligned).
            xp, yp = shuffle(xp, yp)

            # One gradient step on the poisoned batch.
            _ = victim.train_on_batch(x=xp, y=yp)
            # Evaluate on the full test set after this step, thresholding at 0.5.
            y_pred_prob = victim.predict(X_test, verbose=0)
            y_pred_cate = np.where(y_pred_prob < 0.5, 0, 1)
            accu, sens, spec = get_testset_perf(y_pred_cate, y_test)
            accuracy.append(accu)
            sensitivity.append(sens)
            specificity.append(spec)

            # Both classes were injected this step.
            count_num_poisons += num_poisons * 2

            print(f"acc = {accu:.5f}, sens = {sens:.5f}, speci = {spec:.5f}")
            print(f"num poisons so far = {count_num_poisons}")
            print("-" * 20)

            # Free per-step tensors/arrays to keep memory flat across many steps.
            del (
                normal_labels,
                normal_lung_images,
                penumonia_labels,
                penumonia_lung_images,
                xp,
                yp,
                y_pred_prob,
                y_pred_cate,
            )
            gc.collect()

        ACCURACY.append(accuracy)
        SENSITIVITY.append(sensitivity)
        SPECIFICITY.append(specificity)

        del accuracy, sensitivity, specificity
        gc.collect()
        keras.backend.clear_session()

    decreased_perf_dict[name]["accuracy"] = ACCURACY
    decreased_perf_dict[name]["sensitivity"] = SENSITIVITY
    decreased_perf_dict[name]["specificity"] = SPECIFICITY

    print(f"finish collection results for {name}!")
    print("start next model!")
    sleep(3)

display.clear_output()
print("done...")

In [None]:
# Persist the poisoning results. The file name matches the one the reload cell
# further down reads, so the analysis cells can be re-run later without
# repeating the long poisoning experiment (an empty name made `open` fail).
folder = ""
file = "decreased_test_set_performance_overall.pickle"
file_name = os.path.join(folder, file)

with open(file_name, "wb") as f:

    pickle.dump(decreased_perf_dict, f)

In [None]:
# Load test-set metrics saved by an earlier run (presumably the un-poisoned
# baseline, judging by the file name -- confirm).
# NOTE(review): `perf_dict` is not referenced in the visible cells below.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
folder, file = "", "test_set_performance.pickle"
file_name = os.path.join(folder, file)

with open(file_name, "rb") as f:
    perf_dict = pickle.load(f)

In [None]:
# Reload the poisoning results (this replaces any `decreased_perf_dict` still
# in memory from the experiment cell).
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
folder, file = "", "decreased_test_set_performance_overall.pickle"
file_name = os.path.join(folder, file)

with open(file_name, "rb") as f:
    decreased_perf_dict = pickle.load(f)

In [None]:
# Poison samples injected per training step: `num_poisons` (32) images of each of
# the two classes. Kept literal so this cell also works after reloading results
# from the pickle without re-running the experiment cell.
POISONS_PER_STEP = 64


def _metric_frame(metric):
    """Average one metric over the repeated rounds for every model.

    Args:
        metric: Key inside each ``decreased_perf_dict[model]`` entry, e.g.
            ``"accuracy"``.

    Returns:
        ``(per_model, frame)``: ``per_model`` maps model name to the per-step
        mean across rounds; ``frame`` is a DataFrame with one column per model
        plus ``"npoison"``, the cumulative number of poison samples at each step.
    """
    # axis=0 averages across rounds, leaving one value per poisoning step;
    # assumes every round recorded the same number of steps.
    per_model = {
        model_name: np.mean(results[metric], axis=0)
        for model_name, results in decreased_perf_dict.items()
    }
    frame = pd.DataFrame(per_model)
    frame["npoison"] = POISONS_PER_STEP * np.arange(1, frame.shape[0] + 1)
    return per_model, frame


accu, accu_df = _metric_frame("accuracy")
sens, sens_df = _metric_frame("sensitivity")
spec, spec_df = _metric_frame("specificity")

In [None]:
x_label: str = "the number of poison samples"  # shared x-axis label for the metric plots

In [None]:
# Mean test accuracy vs. number of injected poison samples, one line per model.
# Columns are renamed to display labels so each legend entry stays attached to
# its own line; a positional legend list mislabels lines whenever the column
# order differs from the assumed model order.
MODEL_LABELS = {
    "vgg16": "VGG16",
    "resnet50v2": "ResNet50V2",
    "mobilenetv2_1.00_128": "MobileNetV2",
    "inception_v3": "InceptionV3",
    "convnext_tiny": "ConvNeXt-Tiny",
}
ax = accu_df.rename(columns=MODEL_LABELS).plot(
    figsize=(10, 6), x="npoison", xlabel=x_label, ylabel="accuracy score"
)

ax.set_ylim(0, 1)

plt.savefig("accuracy.svg", format="svg", dpi=1200, bbox_inches="tight")

plt.show()

In [None]:
# Mean test sensitivity vs. number of injected poison samples, one line per model.
# Columns are renamed to display labels so each legend entry stays attached to
# its own line instead of relying on a positional label list.
MODEL_LABELS = {
    "vgg16": "VGG16",
    "resnet50v2": "ResNet50V2",
    "mobilenetv2_1.00_128": "MobileNetV2",
    "inception_v3": "InceptionV3",
    "convnext_tiny": "ConvNeXt-Tiny",
}
ax = sens_df.rename(columns=MODEL_LABELS).plot(
    figsize=(10, 6), x="npoison", xlabel=x_label, ylabel="sensitivity score"
)
ax.set_ylim(0, 1)
plt.savefig("sensitivity.svg", format="svg", dpi=1200, bbox_inches="tight")
plt.show()

In [None]:
# Mean test specificity vs. number of injected poison samples, one line per model.
# Columns are renamed to display labels so each legend entry stays attached to
# its own line instead of relying on a positional label list.
MODEL_LABELS = {
    "vgg16": "VGG16",
    "resnet50v2": "ResNet50V2",
    "mobilenetv2_1.00_128": "MobileNetV2",
    "inception_v3": "InceptionV3",
    "convnext_tiny": "ConvNeXt-Tiny",
}
ax = spec_df.rename(columns=MODEL_LABELS).plot(
    figsize=(10, 6), x="npoison", xlabel=x_label, ylabel="specificity score"
)
ax.set_ylim(0, 1)
plt.savefig("specificity.svg", format="svg", dpi=1200, bbox_inches="tight")
plt.show()