In [1]:
!pip install image_classifiers
#!pip install keras==3.1.1 tensorflow==2.16.1
import numpy as np
import tensorflow as tf
from keras.applications import ResNet50V2
from keras.datasets import cifar100
from keras import Sequential, Input
from keras.layers import Dense, Dropout, RandomFlip, RandomTranslation, RandomRotation,RandomBrightness, RandomContrast, RandomZoom, GlobalAveragePooling2D
from keras.applications.resnet_v2 import preprocess_input
from keras.models import Model
from classification_models.keras import Classifiers
from keras.optimizers import SGD,Adam
from keras.activations import linear
from keras.utils import Progbar
from keras.backend import clear_session
from tensorflow.nn import softmax_cross_entropy_with_logits
import os
import re



2024-05-03 15:30:49.704404: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-03 15:30:49.705328: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-03 15:30:49.712245: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-03 15:30:49.784885: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Training hyper-parameters and dataset constants shared by the cells below.
n_epoch = 100
batch_size = 64
taux_validation = 0.1  # validation rate; multiplied by n_images to size the validation split
num_classes = 100
n_images = 50000 # For training, plus 10000 for the test set

In [3]:
# Sizes of the validation and training parts, derived from the config constants.
validation_size = int(n_images * taux_validation)
train_size = n_images - validation_size

In [4]:
# Teacher network: ImageNet-initialised ResNet50V2 backbone with average
# pooling, followed by a regularised two-layer classification head.
model_enseignant = Sequential(
    [
        Input((224,224,3)),
        ResNet50V2(include_top=False, weights='imagenet', pooling="avg"),
        Dropout(0.25),
        Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
        Dropout(0.5),
        Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001)),
    ]
)

# Inside the backbone, keep trainable only the layers whose name ends with
# "_3_conv" or "_bn"; every other backbone layer is frozen.
for couche in model_enseignant.layers:
    if couche.name != "resnet50v2":
        continue
    for sous_couche in couche.layers:
        sous_couche.trainable = re.match("^.*(_3_conv|_bn)$", sous_couche.name) is not None

In [5]:
def preprocessing(image, label):
    """Resize one image to 224x224 and one-hot encode its label.

    The one-hot tensor is squeezed on axis 0, so ``label`` is expected to carry
    a leading dimension of size 1.

    Returns:
        Tuple ``(resized_image, one_hot_label)``.
    """
    image_224 = tf.image.resize(image, (224, 224))
    label_one_hot = tf.one_hot(label, depth = num_classes)
    return image_224, tf.squeeze(label_one_hot, axis = 0)

# Random augmentation pipeline. RandomBrightness is configured for inputs in
# the [0, 1] range, so images must be rescaled before going through it.
augmentation_donnees_keras = Sequential(
    [
        RandomFlip(mode="horizontal"),
        RandomTranslation(height_factor=0.2, width_factor=0.2),
        RandomRotation(factor=0.2),
        RandomZoom(height_factor=0.2),
        RandomContrast(factor=0.2),
        RandomBrightness(factor=0.2, value_range=(0,1)),
    ]
)

def augmentation_donnees(image, label):
    """Apply the random augmentation pipeline to ``image``; ``label`` is passed through.

    The image is divided by 255 before augmentation and multiplied by 255
    afterwards, so the pipeline operates on rescaled values.
    """
    image_unitaire = image/255.0
    image_augmentee = augmentation_donnees_keras(image_unitaire, training = True)
    return image_augmentee*255.0, label

def preprocess_resnet(image, label):
    """Run the ResNetV2 ``preprocess_input`` on ``image`` and pass ``label`` through."""
    image_pretraitee = preprocess_input(image)
    return image_pretraitee, label

def train_val_split(train_dataset, validation_size):
    """Split ``(X, y)`` arrays into a training part and a validation part.

    The last ``validation_size`` observations become the validation set and
    everything before them the training set. The split point is derived from
    the arguments rather than from a module-level constant, so the function
    works for any dataset size.

    Args:
        train_dataset: Tuple ``(X_train, y_train)`` of arrays with the same
            first dimension.
        validation_size: Number of trailing observations to hold out.

    Returns:
        ``((X_train, y_train), (X_val, y_val))``.

    Raises:
        ValueError: If ``validation_size`` is negative or larger than the
            number of observations.
    """
    X_train, y_train = train_dataset
    n_total = len(X_train)
    if not 0 <= validation_size <= n_total:
        raise ValueError(f"validation_size must be between 0 and {n_total}, got {validation_size}")
    # Explicit split index: a negative-index slice would misbehave for validation_size == 0.
    coupure = n_total - validation_size
    return (X_train[:coupure,...], y_train[:coupure,...]), (X_train[coupure:,...], y_train[coupure:,...])

def load_cifar_train_val():
    """Build the CIFAR-100 training and validation ``tf.data`` pipelines.

    The CIFAR-100 training arrays are split with ``train_val_split``; the
    test arrays returned by ``cifar100.load_data()`` are ignored.

    Training pipeline: infinite repeat -> shuffle -> resize/one-hot -> batch
    -> augmentation -> ResNet preprocessing -> prefetch. Shuffling happens on
    the raw CIFAR images *before* they are resized: resizing to 224x224
    float32 first and then caching/shuffling the full training set would keep
    roughly 224*224*3*4 bytes per image (about 27 GB for 45,000 images) in
    memory. The raw arrays already live in memory, so no cache is needed.

    Validation pipeline: resize/one-hot -> batch -> ResNet preprocessing ->
    cache -> prefetch. The last batch is not dropped, so it may be smaller
    than ``batch_size``.

    Returns:
        Tuple ``(train_dataset, validation_dataset)``; ``train_dataset`` is
        infinite.
    """
    donnees_train, _ = cifar100.load_data()

    donnees_train, donnees_validation = train_val_split(donnees_train, validation_size)
    # Shuffle buffer covers the whole training split, whatever its actual size.
    n_train = len(donnees_train[0])

    validation_dataset = (
        tf.data.Dataset.from_tensor_slices(donnees_validation)
        .map(preprocessing, num_parallel_calls = tf.data.AUTOTUNE)
        .batch(batch_size)
        .map(preprocess_resnet, num_parallel_calls = tf.data.AUTOTUNE)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(donnees_train)
        .repeat()
        .shuffle(n_train)
        .map(preprocessing, num_parallel_calls = tf.data.AUTOTUNE)
        .batch(batch_size)
        .map(augmentation_donnees, num_parallel_calls = tf.data.AUTOTUNE)
        .map(preprocess_resnet, num_parallel_calls = tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_dataset, validation_dataset

In [6]:
def get_modele_logits(modele):
    """Return a model that outputs the logits of ``modele`` instead of its final activation.

    A new Dense layer named ``'logits'`` is created from the configuration of
    the last layer of ``modele`` with a linear activation, plugged onto the
    output of the second-to-last layer, and initialised with a copy of the
    last layer's weights.

    Gotcha: the ``'logits'`` layer is a *new* layer holding *copied* values.
    Training the returned model does not update the last layer of ``modele``,
    and vice versa; all earlier layers are shared.

    Returns:
        The compiled logits model (``accuracy`` metric only).
    """
    derniere_couche = modele.layers[-1]
    config = derniere_couche.get_config()
    config['activation'] = linear
    config['name'] = 'logits'
    couche_logits = Dense(**config)
    sortie_logits = couche_logits(modele.layers[-2].output)
    res = Model(inputs=modele.inputs, outputs=[sortie_logits])
    poids_initiaux = [poids.numpy() for poids in derniere_couche.weights]
    res.layers[-1].set_weights(poids_initiaux)
    res.compile(metrics=['accuracy'])
    return res

def get_regularisation(model):
    """Collect the ``(regularizer, variable)`` pairs declared on the layers of ``model``.

    ``kernel_regularizer`` is paired with the layer's ``kernel`` and
    ``bias_regularizer`` with the layer's ``bias``, so that a bias penalty is
    never applied to the kernel by mistake. Layers without the attribute,
    with a ``None`` regularizer, or with a missing/``None`` variable (e.g. a
    layer built without bias) are skipped.

    Args:
        model: Object exposing a ``layers`` iterable.

    Returns:
        List of ``(regularizer, variable)`` tuples, in layer order.
    """
    correspondances = (("kernel_regularizer", "kernel"), ("bias_regularizer", "bias"))
    paires = []
    for layer in model.layers:
        for nom_regularizer, nom_variable in correspondances:
            regularizer = getattr(layer, nom_regularizer, None)
            if regularizer is None:
                continue
            variable = getattr(layer, nom_variable, None)
            # Identity test on purpose: framework variables must not be truth-tested.
            if variable is None:
                continue
            paires.append((regularizer, variable))
    return paires

@tf.function
def perte_regularisation(regularisation):
    """Sum the regularisation penalties of all ``(regularizer, variable)`` pairs.

    Args:
        regularisation: List of ``(regularizer, variable)`` tuples, as
            returned by ``get_regularisation``.

    Returns:
        Scalar tensor with the total penalty; ``0.0`` when the list is empty
        (``tf.add_n`` rejects an empty list).
    """
    if not regularisation:
        return tf.constant(0.0)
    return tf.add_n([tf.reduce_sum(reg(kernel)) for reg, kernel in regularisation])

@tf.function
def compte_bons(x,y):
    """Count the rows on which ``x`` and ``y`` have the same argmax along axis 1.

    Returns:
        Scalar float32 tensor holding the number of matching rows.
    """
    classes_x = tf.argmax(x, axis = 1)
    classes_y = tf.argmax(y, axis = 1)
    accords = tf.cast(tf.equal(classes_x, classes_y), tf.float32)
    return tf.reduce_sum(accords)

@tf.function
def softmax(logits, temp):
    """Temperature-scaled softmax of ``logits`` along axis 1.

    Uses ``tf.nn.softmax`` rather than a hand-written ``exp`` / sum, which
    overflows to ``inf`` (and then ``nan``) for large logits.

    Args:
        logits: Tensor of logits, classes on axis 1.
        temp: Temperature the logits are divided by.

    Returns:
        Tensor of probabilities, same shape as ``logits``.
    """
    return tf.nn.softmax(logits / temp, axis = 1)

@tf.function
def ce(x, y_logits, temp):
    """Cross-entropy between targets ``x`` and ``y_logits`` softened by ``temp``.

    ``y_logits`` is divided by ``temp`` before the softmax cross-entropy, and
    the result is multiplied by ``temp**2``.
    """
    logits_adoucis = y_logits / temp
    entropie_croisee = softmax_cross_entropy_with_logits(labels = x, logits = logits_adoucis)
    return entropie_croisee * temp**2

def init_csv_log(fichier):
    """Create (or truncate) the CSV log ``fichier`` and write its header line."""
    entete = "epoch, accuracy_etu, accuracy_ens, accuracy_dis, val_accuracy_etu, val_accuracy_ens, val_accuracy_dis\n"
    with open(fichier,'w') as file:
        file.write(entete)
def append_csv_log(fichier, epoch, accuracy_etu, accuracy_ens, accuracy_dis, val_accuracy_etu, val_accuracy_ens, val_accuracy_dis):
    """Append one row to the CSV log ``fichier``.

    ``epoch`` is written as an integer and the six accuracies with two
    decimals, comma-separated, in the order of the arguments.
    """
    with open(fichier,'a') as file:
        mesures = (accuracy_etu, accuracy_ens, accuracy_dis, val_accuracy_etu, val_accuracy_ens, val_accuracy_dis)
        champs = [f"{epoch:d}"]
        for mesure in mesures:
            champs.append(f"{mesure:.2f}")
        file.write(",".join(champs) + "\n")

def forward_backward_pass_impl(train_dataset_iter, enseignant_logit_model, etudiant_logit_model, alpha, temp, optim_enseignant, optim_etudiant, regularisation_etudiant, regularisation_enseignant):
    """Run one joint training step of the teacher and the student on the next batch.

    The teacher is trained on the true labels only. The student is trained on
    a mix of the true labels (weight ``alpha``) and the teacher's
    temperature-softened predictions (weight ``1 - alpha``).

    Note the parameter order: the *teacher* optimizer comes before the
    *student* optimizer, while the *student* regularisation comes before the
    *teacher* one. Passing these by keyword avoids mix-ups.

    Args:
        train_dataset_iter: Iterator yielding ``(X_batch, y_batch)``.
        enseignant_logit_model: Teacher model producing logits.
        etudiant_logit_model: Student model producing logits.
        alpha: Weight of the true-label loss in the student loss.
        temp: Distillation temperature.
        optim_enseignant: Optimizer applied to the teacher's variables.
        optim_etudiant: Optimizer applied to the student's variables.
        regularisation_etudiant: Input of ``perte_regularisation`` for the student.
        regularisation_enseignant: Input of ``perte_regularisation`` for the teacher.

    Returns:
        Three counts from ``compte_bons``: student vs. labels, teacher vs.
        labels, and student vs. teacher (agreement).
    """
    X_batch, y_batch = next(train_dataset_iter)
    with tf.GradientTape() as tape_ens:
        enseignant_estim_logit = enseignant_logit_model(X_batch, training = True)
        # NOTE(review): the cross-entropy appears to be a per-example vector, not a
        # batch mean; the scalar penalty is broadcast onto it and tape.gradient sums
        # over the elements -- confirm this scaling is intended.
        perte_ens = softmax_cross_entropy_with_logits(y_batch,enseignant_estim_logit) + perte_regularisation(regularisation_enseignant)
    # Soft targets are computed outside both tape contexts; the teacher's
    # gradients below are taken from perte_ens only.
    enseignant_estim_softmax = softmax(enseignant_estim_logit, temp)
    with tf.GradientTape() as tape_etu:
        etudiant_estim_logit = etudiant_logit_model(X_batch, training = True)
        perte_etu = alpha * softmax_cross_entropy_with_logits(y_batch,etudiant_estim_logit) + (1-alpha) * ce(enseignant_estim_softmax,etudiant_estim_logit, temp) + perte_regularisation(regularisation_etudiant)
    # Both gradients are computed before either optimizer updates any variable.
    grad_etudiant = tape_etu.gradient(perte_etu, etudiant_logit_model.trainable_variables)
    grad_enseignant = tape_ens.gradient(perte_ens, enseignant_logit_model.trainable_variables)
    optim_etudiant.apply_gradients(zip(grad_etudiant, etudiant_logit_model.trainable_variables))
    optim_enseignant.apply_gradients(zip(grad_enseignant, enseignant_logit_model.trainable_variables))
    # The counts use the logits computed before the weight updates above.
    return compte_bons(etudiant_estim_logit,y_batch), compte_bons(enseignant_estim_logit,y_batch), compte_bons(etudiant_estim_logit,enseignant_estim_logit)

@tf.function
def val_accuracies(etudiant, enseignant, validation_dataset):
    """Compute validation accuracies of the student and teacher, and their agreement rate.

    The number of observations is counted while iterating instead of being
    derived from ``(validation_size // batch_size) * batch_size``: the last
    batch may be partial, and that formula then under-counts the denominator
    and inflates every accuracy.

    Args:
        etudiant: Student model; only the argmax of its output is used.
        enseignant: Teacher model; only the argmax of its output is used.
        validation_dataset: Dataset yielding ``(X_batch, y_batch)`` with
            one-hot labels.

    Returns:
        Tuple ``(val_accuracy_etu, val_accuracy_ens, val_accuracy_dis)`` of
        scalar tensors: student accuracy, teacher accuracy, and the fraction
        of observations on which student and teacher agree.
    """
    val_bons_etu, val_bons_ens, val_bons_dis = 0.0, 0.0, 0.0
    nb_obs_val = 0.0
    for X_batch, y_batch in validation_dataset:
        etu_prev = etudiant(X_batch, training = False)
        ens_prev = enseignant(X_batch, training = False)
        val_bons_etu += compte_bons(etu_prev, y_batch)
        val_bons_ens += compte_bons(ens_prev, y_batch)
        val_bons_dis += compte_bons(etu_prev, ens_prev)
        # Actual size of this batch (the last one can be smaller than batch_size).
        nb_obs_val += tf.cast(tf.shape(y_batch)[0], tf.float32)
    val_accuracy_etu = val_bons_etu / nb_obs_val
    val_accuracy_ens = val_bons_ens / nb_obs_val
    val_accuracy_dis = val_bons_dis / nb_obs_val
    return val_accuracy_etu, val_accuracy_ens, val_accuracy_dis

def _synchroniser_tete(modele, modele_logits):
    """Copy the weights of the last layer of ``modele_logits`` into the last layer of ``modele``."""
    modele.layers[-1].set_weights(modele_logits.layers[-1].get_weights())


def distillateur_kl_en_ligne(etudiant, enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch, alpha):
    """Train the teacher and the student jointly (online distillation) and log the results.

    Logit versions of both models are built with ``get_modele_logits`` and
    trained step by step with ``forward_backward_pass_impl``. Because the
    logit models own a *separate* final layer, this function:

    * evaluates validation accuracy on the logit models (same argmax as the
      softmax output), not on the original models whose last layer is never
      updated by training;
    * collects the regularisers from the logit models, so the penalty applies
      to the variables that are actually trained;
    * copies the trained head back into ``etudiant`` before each checkpoint,
      and into both ``etudiant`` and ``enseignant`` when training ends.

    The optimizers and regularisers are passed to the step function by
    keyword, so the student really uses SGD and the teacher Adam regardless
    of the step function's parameter order.

    Side effects: creates ``sauvegardes/`` if needed, writes
    ``sauvegardes/{nom_modele}_logs.csv`` and, whenever the student's
    validation accuracy improves, ``sauvegardes/{nom_modele}_checkpoint.keras``.

    Args:
        etudiant: Student model (softmax output).
        enseignant: Teacher model (softmax output).
        train_dataset: Infinite dataset of training batches.
        validation_dataset: Finite dataset of validation batches.
        temp: Distillation temperature.
        nom_modele: Base name of the output files.
        n_epoch: Number of epochs.
        alpha: Weight of the true-label loss in the student loss.
    """
    etudiant_logit_model = get_modele_logits(etudiant)
    enseignant_logit_model = get_modele_logits(enseignant)
    optim_etudiant = SGD(learning_rate=0.001)
    optim_enseignant = Adam(learning_rate=0.001)
    os.makedirs("sauvegardes", exist_ok=True)
    fichier_logs = f"sauvegardes/{nom_modele}_logs.csv"
    init_csv_log(fichier_logs)
    print("C'est parti pour la distillation !\n")
    val_accuracy_etu_max, val_accuracy_ens_max, val_accuracy_dis_max  = 0.0, 0.0, 0.0
    regularisation_etudiant = get_regularisation(etudiant_logit_model)
    regularisation_enseignant = get_regularisation(enseignant_logit_model)
    # The tf.function has to be local: its graph depends on etudiant_logit_model,
    # and TensorFlow raises an error when a shared one is reused for a second run.
    forward_backward_pass = tf.function(forward_backward_pass_impl)
    train_dataset_iter = iter(train_dataset)
    n_batch = train_size//batch_size
    for epoch in range(n_epoch):
        print(f"Époque {epoch + 1} / {n_epoch}")
        barre_progression = Progbar(n_batch, stateful_metrics = ["acc (etu, train)", "acc (ens, train)", "acc (dis, train)"])
        bons_epoque_etu, bons_epoque_ens, bons_epoque_dis = 0, 0, 0
        accuracy_etu, accuracy_ens, accuracy_dis = 0.0, 0.0, 0.0
        for i in range(n_batch):
            bons_etu, bons_ens, bons_dis = forward_backward_pass(
                train_dataset_iter,
                enseignant_logit_model,
                etudiant_logit_model,
                alpha,
                temp,
                optim_enseignant = optim_enseignant,
                optim_etudiant = optim_etudiant,
                regularisation_etudiant = regularisation_etudiant,
                regularisation_enseignant = regularisation_enseignant,
            )
            bons_epoque_etu += bons_etu.numpy()
            bons_epoque_ens += bons_ens.numpy()
            bons_epoque_dis += bons_dis.numpy()
            n_observ = (i+1) * batch_size
            accuracy_etu, accuracy_ens, accuracy_dis = bons_epoque_etu / n_observ, bons_epoque_ens / n_observ, bons_epoque_dis / n_observ
            barre_progression.update(i + 1, values = [("acc (etu, train)", accuracy_etu), ("acc (ens, train)", accuracy_ens), ("acc (dis, train)", accuracy_dis)])
        # Evaluate the models that are actually being trained; plain floats keep
        # the comparisons and the formatting below framework-independent.
        val_accuracy_etu, val_accuracy_ens, val_accuracy_dis = (
            float(valeur) for valeur in val_accuracies(etudiant_logit_model, enseignant_logit_model, validation_dataset)
        )
        if val_accuracy_etu > val_accuracy_etu_max:
            val_accuracy_etu_max = val_accuracy_etu
            # Bring the trained head into the softmax model before saving it.
            _synchroniser_tete(etudiant, etudiant_logit_model)
            etudiant.save(f"sauvegardes/{nom_modele}_checkpoint.keras")
        if val_accuracy_ens > val_accuracy_ens_max:
            val_accuracy_ens_max = val_accuracy_ens
        if val_accuracy_dis > val_accuracy_dis_max:
            val_accuracy_dis_max = val_accuracy_dis
        if epoch + 1 == 70:
            optim_etudiant.learning_rate.assign(0.0001)
            optim_enseignant.learning_rate.assign(0.0001)
            # NOTE(review): the step function was traced with the teacher's earlier
            # list of trainable variables, so this unfreezing presumably has no
            # effect on the following steps -- verify, and rebuild the step function
            # and the teacher optimizer if full fine-tuning is intended.
            enseignant.trainable = True
        if epoch + 1 in [10,20,50,70,100]:
            print(f"---> Epoque {epoch + 1:d} - Max val accuracy -> etu : {val_accuracy_etu_max:.4f} | ens : {val_accuracy_ens_max:.4f} | dis : {val_accuracy_dis_max:.4f}")
        append_csv_log(fichier_logs, epoch, accuracy_etu, accuracy_ens, accuracy_dis, val_accuracy_etu, val_accuracy_ens, val_accuracy_dis)
        print(f"Accuracy (etu, val) : {val_accuracy_etu:.4f} | Accuracy (ens, val) : {val_accuracy_ens:.4f} | Accuracy (dis, val) : {val_accuracy_dis:.4f}")
    # Leave the caller's softmax models in sync with what was trained.
    _synchroniser_tete(etudiant, etudiant_logit_model)
    _synchroniser_tete(enseignant, enseignant_logit_model)

In [7]:
def distillation_resnet18_en_ligne(temp, alpha):
    """Distil the global teacher into a fresh ResNet18 student, then publish the results.

    After training, the best checkpoint is copied to
    ``sauvegardes/{nom_modele}.keras`` and that file plus the CSV log are
    uploaded with the ``mc`` command-line client. The copy and the uploads
    stay best-effort (a failure is reported, not raised), but they no longer
    go through a shell: paths are passed as arguments, so a working directory
    containing spaces or shell metacharacters cannot break the commands.

    Args:
        temp: Distillation temperature; must be an integer because it is
            formatted with ``:d`` in the model name.
        alpha: Weight of the true-label loss in the student loss.
    """
    import shutil
    import subprocess

    # NOTE(review): the teacher is the module-level model_enseignant, created before
    # this clear_session() call and trained in place -- calling this function twice
    # continues from the already-trained teacher. Confirm this is intended.
    clear_session()
    train_dataset, validation_dataset = load_cifar_train_val()
    modele = new_modele_resnet()
    nom_modele =  f"model_etudiant_ligne_t{temp:d}"
    distillateur_kl_en_ligne(modele, model_enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch, alpha)

    dossier = os.path.join(os.getcwd(), "sauvegardes")
    chemin_checkpoint = os.path.join(dossier, f"{nom_modele}_checkpoint.keras")
    chemin_modele = os.path.join(dossier, f"{nom_modele}.keras")
    chemin_logs = os.path.join(dossier, f"{nom_modele}_logs.csv")

    if os.path.exists(chemin_checkpoint):
        shutil.copyfile(chemin_checkpoint, chemin_modele)
    else:
        print(f"Checkpoint not found, nothing copied: {chemin_checkpoint}")

    for chemin_local in (chemin_modele, chemin_logs):
        destination = f"s3/afeldmann/projet_cnam/{os.path.basename(chemin_local)}"
        try:
            resultat = subprocess.run(["mc", "cp", chemin_local, destination], check=False)
        except OSError as erreur:
            # Typically: the mc client is not installed on this machine.
            print(f"Upload skipped for {chemin_local}: {erreur}")
            continue
        if resultat.returncode != 0:
            print(f"Upload failed for {chemin_local} (exit code {resultat.returncode})")

In [8]:
def ResNet18():
    """Build an ImageNet-initialised ResNet18 feature extractor with global average pooling.

    The backbone comes from ``Classifiers.get('resnet18')``; the preprocessing
    function returned alongside it is not used here.

    Returns:
        A model mapping ``(224, 224, 3)`` images to pooled feature vectors.
    """
    constructeur_resnet18, _ = Classifiers.get('resnet18')
    backbone = constructeur_resnet18((224, 224, 3), weights='imagenet', include_top=False)
    caracteristiques = GlobalAveragePooling2D()(backbone.output)
    return Model(inputs=backbone.input, outputs=caracteristiques)

def new_modele_resnet():
    """Create a fresh student: ResNet18 backbone followed by the regularised classification head.

    Returns:
        The compiled model (``accuracy`` metric only), with a softmax output
        over ``num_classes`` classes.
    """
    couches = [
        Input((224,224,3)),
        ResNet18(),
        Dropout(0.25),
        Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
        Dropout(0.5),
        Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001)),
    ]
    etudiant = Sequential(couches)
    etudiant.compile(metrics=['accuracy'])
    return etudiant

In [10]:
# Make sure the output directory exists (no shell involved, so it also works
# when the working directory contains spaces), then launch the distillation
# with temperature 3 and alpha 0.5.
wd = os.getcwd()
os.makedirs(os.path.join(wd, "sauvegardes"), exist_ok=True)
distillation_resnet18_en_ligne(3, 0.5)

C'est parti pour la distillation !

Époque 1 / 100


KeyboardInterrupt: 