In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0


In [None]:
# --- Model / optimisation hyper-parameters ---
input_shape = (224, 224, 3)          # image size fed to EfficientNetB0; [:2] is used as target_size
num_classes, batch_size, epochs = 7, 16, 50
learning_rate = 0.0001

# --- Distillation settings ---
# Softmax temperature for the KD term, plus the [min, max] range that each loss
# weight is linearly ramped across during training (renormalised to sum to 1).
temperature = 3.0
alpha_min, alpha_max = 0.3, 0.5      # weight on hard-label cross-entropy
beta_min, beta_max = 0.3, 0.5        # weight on the teacher KL term
lambda_min, lambda_max = 0.1, 0.4    # weight on the relevance-map matching term

# --- Data locations ---
train_path = '/code/MyCode/AUG/HAM10000/train_dir'
teacher_probs_path = 'teacher_hybrid_probs.npy'
teacher_lrp_path = 'LRP_hybrid.npy'


In [None]:
# EfficientNet's own preprocessing keeps inputs consistent with the pretrained backbone.
datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.efficientnet.preprocess_input,
)

# shuffle=False is required: the training loop slices teacher_probs / teacher_lrp
# by sample position, so the generator must yield images in a fixed order
# (assumed to be the same order the teacher outputs were saved in — confirm).
train_gen = datagen.flow_from_directory(
    train_path,
    shuffle=False,
    batch_size=batch_size,
    target_size=input_shape[:2],
)


In [None]:
# Pre-computed teacher outputs, one row per training image. Cast to float32 so
# they combine with the student's float32 tensors: a float64 array (np.save's
# default for Python floats) would make the loss sum mix float32 and float64
# tensors, which TensorFlow rejects.
teacher_probs = np.load(teacher_probs_path).astype(np.float32)
teacher_lrp = np.load(teacher_lrp_path).astype(np.float32)

# The training loop slices these arrays by sample position, so they must cover
# exactly the images train_gen yields. Fail loudly here instead of silently
# training against misaligned teacher targets.
assert teacher_probs.shape == (train_gen.samples, num_classes), (
    f"teacher_probs has shape {teacher_probs.shape}, "
    f"expected ({train_gen.samples}, {num_classes})"
)
assert len(teacher_lrp) == train_gen.samples, (
    f"teacher_lrp has {len(teacher_lrp)} rows, expected {train_gen.samples}"
)


In [None]:
# Fix TensorFlow's global seed so the new head's initial weights and the
# dropout masks are reproducible when the notebook is re-run from scratch.
tf.random.set_seed(42)

# ImageNet-pretrained EfficientNetB0 backbone; pooling='avg' makes it emit a
# pooled feature vector per image, so the dense head can attach directly.
base = EfficientNetB0(include_top=False, weights='imagenet', input_shape=input_shape, pooling='avg')
x = Dropout(0.5)(base.output)
# Softmax head: the student outputs class *probabilities*, not logits. Any
# temperature scaling of these outputs must therefore be done in log space.
output = Dense(num_classes, activation='softmax')(x)
student = Model(inputs=base.input, outputs=output)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)


In [None]:
def compute_lrp(model, X_batch):
    """Return a per-pixel "relevance" map for ``X_batch``.

    NOTE(review): despite its name this is a placeholder, not Layer-wise
    Relevance Propagation. ``model`` is never used. The map is just the mean
    absolute value over the last axis (the channel axis for channels-last image
    batches). It depends only on the input and carries no gradient with respect
    to the model's weights.

    Args:
        model: Unused; kept so the signature matches a real LRP implementation.
        X_batch: Array-like whose last axis is averaged.

    Returns:
        np.ndarray shaped like ``X_batch`` but with a last axis of size 1.
    """
    magnitude = np.abs(X_batch)
    return np.mean(magnitude, axis=-1, keepdims=True)


In [None]:
def soften_probs(probs, T, eps=1e-7):
    """Temperature-soften a batch of probability vectors.

    For probs = softmax(logits), log(probs) differs from the logits only by a
    per-row constant, which softmax cancels. So softmax(log(probs) / T) equals
    softmax(logits / T). Dividing the probabilities themselves by T, as the
    previous code did, does not give the tempered distribution.

    Args:
        probs: (batch, num_classes) class probabilities.
        T: softmax temperature (> 0).
        eps: lower clip that keeps log() finite for zero probabilities.

    Returns:
        float32 tensor of softened probabilities, same shape as ``probs``.
    """
    log_probs = tf.math.log(tf.clip_by_value(tf.cast(probs, tf.float32), eps, 1.0))
    return tf.nn.softmax(log_probs / T, axis=-1)


kl_loss_fn = tf.keras.losses.KLDivergence()  # returns a batch-averaged scalar
steps_per_epoch = len(train_gen)
# Denominator of the linear weight ramp, chosen so the final epoch reaches *_max.
ramp_span = max(epochs - 1, 1)

for epoch in range(epochs):
    progress = epoch / ramp_span
    alpha = alpha_min + progress * (alpha_max - alpha_min)
    beta = beta_min + progress * (beta_max - beta_min)
    lambda_t = lambda_min + progress * (lambda_max - lambda_min)
    # Renormalise so the three loss weights sum to 1.
    s = alpha + beta + lambda_t
    alpha, beta, lambda_t = alpha / s, beta / s, lambda_t / s

    epoch_losses = []
    for batch_idx in range(steps_per_epoch):
        # Index the generator directly instead of calling the stateful .next().
        # With shuffle=False, batch `batch_idx` is then always samples
        # [batch_idx * B, batch_idx * B + len), the same rows sliced from the
        # teacher arrays, no matter where the iterator was left.
        X_batch, y_batch = train_gen[batch_idx]
        start_idx = batch_idx * train_gen.batch_size
        end_idx = start_idx + X_batch.shape[0]
        teacher_batch = np.asarray(teacher_probs[start_idx:end_idx], dtype=np.float32)
        teacher_lrp_batch = np.asarray(teacher_lrp[start_idx:end_idx], dtype=np.float32)

        with tf.GradientTape() as tape:
            y_pred = student(X_batch, training=True)  # softmax probabilities
            L_CE = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_batch, y_pred))
            y_pred_soft = soften_probs(y_pred, temperature)
            teacher_soft = soften_probs(teacher_batch, temperature)
            # The T^2 factor keeps the KD gradient scale comparable across temperatures.
            L_KD = kl_loss_fn(teacher_soft, y_pred_soft) * (temperature ** 2)
            # NOTE(review): compute_lrp, as currently written, ignores the student's
            # weights, so this term is constant w.r.t. trainable variables and
            # contributes no gradient.
            student_lrp_batch = compute_lrp(student, X_batch)
            L_importance = tf.cast(
                tf.reduce_mean(tf.abs(student_lrp_batch - teacher_lrp_batch)), tf.float32
            )
            L_total = alpha * L_CE + beta * L_KD + lambda_t * L_importance

        grads = tape.gradient(L_total, student.trainable_variables)
        optimizer.apply_gradients(zip(grads, student.trainable_variables))
        epoch_losses.append(float(L_total))

    print(
        f"epoch {epoch + 1}/{epochs} - loss {np.mean(epoch_losses):.4f} "
        f"(alpha={alpha:.3f}, beta={beta:.3f}, lambda={lambda_t:.3f})"
    )


In [None]:
# Persist the distilled student; in tf.keras the .hdf5 extension selects the
# HDF5 saving format.
student_save_path = "student_EfficientNetB0_KD_LRP.hdf5"
student.save(student_save_path)


In [None]:
# Held-out images, preprocessed exactly like the training set. shuffle=False
# keeps a fixed order, so per-sample outputs can be matched back by position.
test_path = '/code/MyCode/AUG/HAM10000/test_dir'
test_gen = datagen.flow_from_directory(
    test_path,
    shuffle=False,
    batch_size=batch_size,
    target_size=input_shape[:2],
)
steps_test = len(test_gen)  # number of batches per full pass over the test set


In [None]:
# Class-probability predictions for every test image, in test_gen's fixed order.
# Rewind first, as the LRP cell below does, so the pass always starts at batch 0.
# Without this, an iterator advanced by another cell (or a re-run) would produce
# rows rotated relative to test_gen.filenames / test_gen.classes.
test_gen.reset()
student_preds = np.vstack([
    student.predict_on_batch(test_gen.next()[0]) for _ in range(steps_test)
])
assert student_preds.shape == (test_gen.samples, num_classes), student_preds.shape
np.save("student_predictions.npy", student_preds)


In [None]:
# Relevance maps for every test image, in test_gen's fixed (unshuffled) order.
# Rewind the iterator so the pass starts from the first batch.
test_gen.reset()
lrp_batches = [compute_lrp(student, test_gen.next()[0]) for _ in range(steps_test)]
student_lrp_test = np.vstack(lrp_batches)
np.save("student_LRP.npy", student_lrp_test)


In [None]:
import matplotlib.pyplot as plt

def show_lrp_test(X_batch, student_lrp_batch, idx=0):
    """Plot one image next to its student relevance map.

    ``X_batch`` and ``student_lrp_batch`` must be aligned: row ``idx`` of both
    refers to the same sample.

    Args:
        X_batch: batch of images from the generator (presumably channels-last).
        student_lrp_batch: relevance maps aligned row-for-row with ``X_batch``.
        idx: row within the two arrays to display.
    """
    img = X_batch[idx]
    s_lrp = np.squeeze(student_lrp_batch[idx])

    # Min-max scale to [0, 1] for imshow; guard against a constant image (zero range).
    img_range = img.max() - img.min()
    img_disp = (img - img.min()) / img_range if img_range > 0 else np.zeros_like(img)

    fig, (ax_img, ax_lrp) = plt.subplots(1, 2, figsize=(8, 4))
    ax_img.imshow(img_disp)
    ax_img.set_title("Original Image")
    ax_img.axis('off')

    heat = ax_lrp.imshow(s_lrp, cmap='hot')
    ax_lrp.set_title("Student LRP")
    ax_lrp.axis('off')
    fig.colorbar(heat, ax=ax_lrp, fraction=0.046)
    plt.show()


# student_lrp_test covers the whole test set, but a generator batch holds only
# batch_size images. Map the global sample index to (batch, offset) and pass the
# matching slice of maps. Indexing test_gen directly avoids depending on where
# the stateful iterator currently is.
sample_idx = 0
batch_no, offset = divmod(sample_idx, test_gen.batch_size)
X_batch, _ = test_gen[batch_no]
batch_start = batch_no * test_gen.batch_size
show_lrp_test(X_batch, student_lrp_test[batch_start:batch_start + len(X_batch)], idx=offset)
