In [None]:
import os
import sys
sys.path.append("..")

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import utils as utils
import CST as CST
from metrics import recall_m, precision_m, f1_m, auc_m

In [None]:
# Load CIFAR-10 as (images, integer labels) for the train and test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Number of distinct labels present in the training split.
n_classes = np.unique(y_train).size

# The distortions expect normalized, zero-centred images, so both splits go through
# the same normalization helper.
x_train, x_test = utils.normalize_image(x_train), utils.normalize_image(x_test)

# One-hot encode the labels for categorical cross-entropy.
y_train, y_test = (
    tf.keras.utils.to_categorical(y_train, num_classes = n_classes),
    tf.keras.utils.to_categorical(y_test, num_classes = n_classes),
)

print(x_train.shape, y_train.shape, sep="\n")

### Train with CST (`alpha = 1`)

In [None]:
# Build the classifier and wrap it for Contrastive Stability Training.
# This run uses alpha = 1; the "Train without CST" section repeats this with alpha = 0.

# Fix the global RNG seeds so the randomly initialised Dense head (and any seeded
# random ops) are reproducible. Using the same seed in both runs keeps them comparable.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

tile_size = 32

alpha = 1

# Distortion settings handed to CST.ContrastiveStabilityTraining.
# NOTE(review): exact parameter semantics live in CST.py — confirm there.
dist_params = {
    "contrast": {"lower": 0.8, "upper": 1.2},
    "color": {"factor": [20,0,20]},
    "blur": {"kernel_size": 1, "sigma": 3.},  # kernel size is 'kernel_size * 2 + 1'
    "brightness": {"max_delta":0.3}
}

model = tf.keras.Sequential([
    # ImageNet-pretrained backbone; input_shape must match the 32x32 RGB CIFAR-10 images.
    tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(32,32,3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(n_classes, activation='softmax')
])


cst = CST.ContrastiveStabilityTraining(
    model=model,
    tile_size=tile_size,
    dist_params=dist_params,
    alpha=alpha
)
print(cst)

In [None]:
# Optimizer, metrics and loss for the CST run.
# Use `learning_rate`, not the deprecated `lr`: depending on the TF/Keras version, `lr`
# is either dropped with only a warning (silently training at the default rate) or rejected.
opt = tf.keras.optimizers.Adam(learning_rate=1e-4, amsgrad=True)
metrics = ["categorical_crossentropy", recall_m, precision_m, f1_m, auc_m]
loss = tf.keras.losses.categorical_crossentropy

cst.compile_cst(optimizer=opt, metrics=metrics, loss=loss)

In [None]:
# Training configuration for the CST run.
save_all_epochs = True
model_save_path = "../models"
model_name = "aaa_cst_cifar_10"
save_metrics = True
epochs = 1

# Fit on the training split and validate on the held-out test split.
# Training and validating on the same (test) data makes the validation metrics meaningless.
cst.train_cst(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    save_all_epochs=save_all_epochs,
    model_save_path=model_save_path,
    model_name=model_name,
    save_metrics=save_metrics,
    epochs=epochs,
    class_weight=None
)

### Train without CST (baseline, `alpha = 0`)

In [None]:
# Build the same classifier for the baseline run: identical to the CST section
# except alpha = 0.

# Same seed as the CST run so both models start from the same random head initialization.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

tile_size = 32

alpha = 0

# Distortion settings handed to CST.ContrastiveStabilityTraining (kept identical to the CST run).
# NOTE(review): exact parameter semantics live in CST.py — confirm there.
dist_params = {
    "contrast": {"lower": 0.8, "upper": 1.2},
    "color": {"factor": [20,0,20]},
    "blur": {"kernel_size": 1, "sigma": 3.},  # kernel size is 'kernel_size * 2 + 1'
    "brightness": {"max_delta":0.3}
}

model = tf.keras.Sequential([
    # ImageNet-pretrained backbone; input_shape must match the 32x32 RGB CIFAR-10 images.
    tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(32,32,3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(n_classes, activation='softmax')
])


cst = CST.ContrastiveStabilityTraining(
    model=model,
    tile_size=tile_size,
    dist_params=dist_params,
    alpha=alpha
)
print(cst)

In [None]:
# Optimizer, metrics and loss for the baseline run. Everything is defined here so the
# cell does not depend on variables left over from the CST section.
params = {
    "class_name": "Adam",
    "config":
    {
        # "learning_rate", not "lr": depending on the TF/Keras version, "lr" is either
        # dropped with only a warning (silently using the default rate) or rejected.
        "learning_rate": 0.0001,
        "amsgrad": True,
        # NOTE(review): the CST run uses Adam's default epsilon; 0.1 here means the two
        # runs differ in more than alpha — confirm this is intended.
        "epsilon": 0.1
    }
}
opt = tf.keras.optimizers.get(params)

metrics = ["categorical_crossentropy", recall_m, precision_m, f1_m, auc_m]
loss = tf.keras.losses.categorical_crossentropy

cst.compile_cst(optimizer=opt, metrics=metrics, loss=loss)

In [None]:
# Training configuration for the baseline (no-CST) run.
save_all_epochs = True
model_save_path = "../models"
model_name = "bbb_no_cst_cifar_10"
save_metrics = True
epochs = 1

# Fit on the training split and validate on the held-out test split.
# Training and validating on the same (test) data makes the validation metrics meaningless.
cst.train_cst(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    save_all_epochs=save_all_epochs,
    model_save_path=model_save_path,
    model_name=model_name,
    save_metrics=save_metrics,
    epochs=epochs,
    class_weight=None
)

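#### Load CIFAR-10-C
**TODO:** run the download commands below from Python (e.g. via `subprocess`) instead of copying them into a separate shell — see https://stackoverflow.com/questions/4256107/running-bash-commands-in-python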

In [None]:
# Shell commands to fetch CIFAR-10-C, run from this notebook's directory.
# The archive downloaded and the archive extracted must be the same file (CIFAR-10-C, not
# CIFAR-100-C). The target directory must match `cifar_c_path` ("../data/CIFAR-10-C") in the next cell.
# NOTE(review): assumes the tar unpacks into a top-level CIFAR-10-C/ folder — confirm.
# mkdir -p ../data
# curl -L -o ../data/CIFAR-10-C.tar https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
# tar -xvf ../data/CIFAR-10-C.tar -C ../data/

In [None]:
cifar_c_path = "../data/CIFAR-10-C"

# Resolve the brightness-corruption images and the shared label file, then load them in that order.
brightness_file, labels_file = (
    os.path.join(cifar_c_path, fname) for fname in ("brightness.npy", "labels.npy")
)
img_corr = np.load(brightness_file)
labels = np.load(labels_file)

print(img_corr.shape)