In [1]:
!pip install matplotlib image_classifiers tqdm
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.applications import ResNet50V2
from keras.datasets import cifar100
from keras import Sequential, Input
from keras.layers import Dense, UpSampling2D, Dropout, RandomFlip, RandomTranslation, RandomRotation,RandomBrightness, RandomContrast, RandomZoom, GlobalAveragePooling2D
from keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from keras.applications.resnet_v2 import preprocess_input
from keras.models import Model
from classification_models.keras import Classifiers
from keras.optimizers import Adam
from keras.activations import linear
from tqdm.notebook import tqdm
import os



2024-04-12 08:48:01.135805: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
n_epoch = 100 # Il y a le early stopping
batch_size = 100
taux_validation = 0.1
num_classes = 100
n_images = 50000 # Pour l'entrainement, et 10000 pour le test

In [3]:
!mc cp s3/afeldmann/projet_cnam/modele_enseignant.keras /home/onyxia/work/projet_distillation_cnam/sauvegardes/modele_enseignant.keras
model_enseignant = Sequential([
    Input((224,224,3)),
    ResNet50V2(include_top=False, weights='imagenet', pooling="avg"),
    Dropout(0.25),
    Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
    Dropout(0.5),
    Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001))
])
# Keras 3.0 est buggé et le chargement direct ne marche pas ici, même si les poids sont bien enregistrés
model_enseignant.load_weights("/home/onyxia/work/projet_distillation_cnam/sauvegardes/modele_enseignant.keras")

model_enseignant.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

...nant.keras: 135.73 MiB / 135.73 MiB ┃▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓┃ 68.18 MiB/s 1s[0;22m[0m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m

2024-04-12 08:48:11.219065: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13775 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:86:00.0, compute capability: 7.5


In [4]:
def preprocessing(image, label):
    image = tf.image.resize(image, (224, 224))
    label = tf.squeeze(tf.one_hot(label, depth = num_classes), axis = 0)
    return  image, label

augmentation_donnees_keras = Sequential([
    RandomFlip("horizontal"),
    RandomTranslation(0.2,0.2),
    RandomRotation(0.2),
    RandomZoom(0.2),
    RandomContrast(0.2),
    RandomBrightness(0.2,value_range=(0,1))
])

def augmentation_donnees(image, label):
    return augmentation_donnees_keras(image/255.0, training = True)*255.0, label

def preprocess_resnet(image, label):
    return preprocess_input(image), label

def train_val_split(train_dataset, validation_size):
    X_train, y_train = train_dataset
    indices = np.random.permutation(X_train.shape[0])
    train_idx, val_idx = indices[:train_size], indices[train_size:]
    return (X_train[train_idx,...], y_train[train_idx,...]), (X_train[val_idx,...], y_train[val_idx,...])

train_dataset, test_dataset = cifar100.load_data()

validation_size = int(n_images * taux_validation)
train_size = n_images - validation_size

train_dataset, validation_dataset = train_val_split(train_dataset, validation_size)

validation_dataset = tf.data.Dataset.from_tensor_slices(validation_dataset).map(preprocessing).batch(batch_size).map(preprocess_resnet).cache().prefetch(tf.data.AUTOTUNE)
train_dataset = tf.data.Dataset.from_tensor_slices(train_dataset).map(preprocessing).cache().repeat().shuffle(train_size).batch(batch_size).map(augmentation_donnees, num_parallel_calls = 2).map(preprocess_resnet, num_parallel_calls = 2).prefetch(2)
test_dataset = tf.data.Dataset.from_tensor_slices(test_dataset).map(preprocessing).batch(batch_size).map(preprocess_resnet, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

In [5]:
def get_modele_logits(modele):
    config = modele.layers[-1].get_config()
    config['activation'] = linear
    config['name'] = 'logits'
    res = Model(inputs=modele.inputs, outputs=[Dense(**config)(modele.layers[-2].output)])
    res.layers[-1].set_weights([x.numpy() for x in modele.layers[-1].weights])
    res.compile(metrics=['accuracy'])
    return res

@tf.function
def compte_bons(x,y):
    return tf.reduce_sum(tf.cast(tf.equal(tf.argmax(x, axis = 1), tf.argmax(y, axis = 1)), tf.float32))

@tf.function
def softmax(logits, temp):
    expo = tf.exp(logits / temp)
    return expo / tf.reduce_sum(expo, axis = 1, keepdims=True)

@tf.function
def ce(x, y, temp):
    res = - x * tf.math.log(y)
    res = tf.where(tf.math.is_nan(res), 0., res)
    res = tf.reduce_sum(res) * temp**2
    return res

def init_csv_log(fichier):
    with open(fichier,'w') as file:
        file.write("epoch, accuracy,val_accuracy\n")
def append_csv_log(fichier, epoch, accuracy,val_accuracy):
    with open(fichier,'a') as file:
        file.write(f"{epoch:d},{accuracy:.2f},{val_accuracy:.2f}\n")

def distillateur_kl(etudiant, enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch, alpha):
    etudiant_logit = get_modele_logits(etudiant)
    enseignant_logit = get_modele_logits(enseignant)
    adam = Adam(learning_rate=0.001)
    init_csv_log(f"sauvegardes/{nom_modele}.csv")
    print("C'est parti pour la distillation !\n")
    val_accuracy_max = 0
    val_loss_min = 0
    early_stop_count = 0
    train_dataset_iter = iter(train_dataset)
    for epoch in range(n_epoch):
        print(f"Époque {epoch + 1} / {n_epoch}")
        n_batch = train_size//batch_size
        barre_progression = tqdm(range(n_batch))
        bons_epoque = 0
        for i in barre_progression:
            X_batch, y_batch = next(train_dataset_iter)
            enseignant_estim_logit = enseignant_logit(X_batch, training = False)
            enseignant_estim_softmax = softmax(enseignant_estim_logit, temp)
            with tf.GradientTape() as tape:
                etudiant_estim_logit = etudiant_logit(X_batch, training = True)
                etudiant_estim_softmax = softmax(etudiant_estim_logit, temp)
                etudiant_estim_softmax_1 = softmax(etudiant_estim_logit, 1)
                perte = alpha * ce(y_batch,etudiant_estim_softmax, 1) + (1-alpha) * ce(enseignant_estim_softmax,etudiant_estim_softmax, temp)
            grads = tape.gradient(perte, etudiant_logit.trainable_variables)
            adam.apply_gradients(zip(grads, etudiant_logit.trainable_variables))
            bons_epoque += compte_bons(etudiant_estim_softmax,y_batch).numpy()
            accuracy = bons_epoque / (i * batch_size) if i != 0 else np.nan
            barre_progression.set_description(f"Accuracy {accuracy*100:.1f} %")
        val_loss, val_accuracy = etudiant.evaluate(validation_dataset)
        if val_accuracy > val_accuracy_max:
            val_accuracy_max = val_accuracy
            etudiant.save(f"sauvegardes/{nom_modele}_checkpoint.keras")
        if val_loss < val_loss_min:
            val_loss_min = val_loss
            early_stop_count = 0
        else if early_stop_count > 5:
            return
        else early_stop_count += 1
        append_csv_log(f"sauvegardes/{nom_modele}.csv", epoch, accuracy, val_accuracy)
        print(f"Accuracy (train) : {accuracy:.4f} | Accuracy (val) : {val_accuracy:.4f}")

In [6]:
def ResNet18():
    resnet18, preprocess_input = Classifiers.get('resnet18')
    resnet = resnet18((224, 224, 3), weights='imagenet', include_top=False)
    resnet_output = GlobalAveragePooling2D()(resnet.output)
    resnet = Model(inputs=resnet.input, outputs=resnet_output)
    return resnet

def new_modele_resnet():
    model = Sequential([
        Input((224,224,3)),
        ResNet18(),
        Dropout(0.25),
        Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
        Dropout(0.5),
        Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001))
    ])
    model.compile(metrics=['accuracy'])
    return model

In [7]:
def distillation_resnet18(temp, alpha):
    tf.keras.backend.clear_session()
    modele = new_modele_resnet()
    nom_modele =  f"model_etudiant_t{temp:d}_a{int(alpha*100):d}"
    distillateur_kl(modele, model_enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch,0.25)
    wd = os.getcwd()
    os.system(f"cp {wd}/sauvegardes/{nom_modele}_checkpoint.keras {wd}/sauvegardes/{nom_modele}.keras")
    os.system(f"mc cp {wd}/sauvegardes/{nom_modele}.keras s3/afeldmann/projet_cnam/{nom_modele}.keras")
    os.system(f"mc cp {wd}/sauvegardes/{nom_modele}.csv s3/afeldmann/projet_cnam/{nom_modele}.csv")
    history=np.genfromtxt(f"sauvegardes/{nom_modele}_logs.csv", delimiter=",", names = True)
    plt.plot(history['accuracy'])
    plt.plot(history['val_accuracy'])
    plt.title('Modèle enseignant')
    plt.ylabel('Exactitude')
    plt.xlabel('Époque')
    plt.axvline(x=47, color='purple', ls='--', lw=2, label='Limite réglage fin')
    plt.legend(['Entrainement', 'Validation','Limite réglage fin'], loc='best')
    plt.show()
    plt.plot(history['loss'])
    plt.plot(history['val_loss'])
    plt.title('Modèle enseignant')
    plt.ylabel('Perte')
    plt.xlabel('Époque')
    plt.axvline(x=47, color='purple', ls='--', lw=2, label='Limite réglage fin')
    plt.legend(['Entrainement', 'Validation','Limite réglage fin'], loc='best')
    plt.show()

In [8]:
distillation_resnet18(1,0.25)

C'est parti pour la distillation !

Époque 1 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

2024-04-12 08:48:32.652128: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8900
I0000 00:00:1712912230.418134   46289 service.cc:145] XLA service 0x557aa6c7d050 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1712912230.418970   46289 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
2024-04-12 08:57:11.748946: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-04-12 08:57:13.739634: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng1{k2=4,k3=0} for conv (f32[100,64,112,112]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,3,224,224]{3,2,1,0}, f32[64,3,7,7]{3,2,1,0}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_que

[1m 2/50[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4s[0m 97ms/step - accuracy: 0.0525 - loss: 6.6025  

I0000 00:00:1712912244.098426   46289 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 100ms/step - accuracy: 0.0309 - loss: 6.6025
Accuracy (train) : 0.1822 | Accuracy (val) : 0.0276
Époque 2 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.3007 - loss: 7.0919
Accuracy (train) : 0.3419 | Accuracy (val) : 0.2976
Époque 3 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.4171 - loss: 7.4762
Accuracy (train) : 0.4082 | Accuracy (val) : 0.4148
Époque 4 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 106ms/step - accuracy: 0.4343 - loss: 7.8241
Accuracy (train) : 0.4590 | Accuracy (val) : 0.4352
Époque 5 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 101ms/step - accuracy: 0.4560 - loss: 8.1359
Accuracy (train) : 0.4894 | Accuracy (val) : 0.4460
Époque 6 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 105ms/step - accuracy: 0.5124 - loss: 8.4170
Accuracy (train) : 0.5212 | Accuracy (val) : 0.5098
Époque 7 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.5059 - loss: 8.6802
Accuracy (train) : 0.5414 | Accuracy (val) : 0.5070
Époque 8 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.5141 - loss: 8.9262
Accuracy (train) : 0.5632 | Accuracy (val) : 0.5116
Époque 9 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.5420 - loss: 9.1738
Accuracy (train) : 0.5789 | Accuracy (val) : 0.5404
Époque 10 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 108ms/step - accuracy: 0.5488 - loss: 9.4018
Accuracy (train) : 0.5953 | Accuracy (val) : 0.5428
Époque 11 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 106ms/step - accuracy: 0.5673 - loss: 9.6228
Accuracy (train) : 0.6019 | Accuracy (val) : 0.5570
Époque 12 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 100ms/step - accuracy: 0.5759 - loss: 9.8313
Accuracy (train) : 0.6191 | Accuracy (val) : 0.5760
Époque 13 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.5841 - loss: 10.0385
Accuracy (train) : 0.6287 | Accuracy (val) : 0.5820
Époque 14 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.5807 - loss: 10.2348
Accuracy (train) : 0.6441 | Accuracy (val) : 0.5730
Époque 15 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.5960 - loss: 10.4337
Accuracy (train) : 0.6485 | Accuracy (val) : 0.5946
Époque 16 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 100ms/step - accuracy: 0.5919 - loss: 10.6193
Accuracy (train) : 0.6582 | Accuracy (val) : 0.5838
Époque 17 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 106ms/step - accuracy: 0.5900 - loss: 10.8066
Accuracy (train) : 0.6634 | Accuracy (val) : 0.5828
Époque 18 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.6174 - loss: 10.9907
Accuracy (train) : 0.6695 | Accuracy (val) : 0.6082
Époque 19 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 105ms/step - accuracy: 0.6108 - loss: 11.1731
Accuracy (train) : 0.6803 | Accuracy (val) : 0.6064
Époque 20 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.6251 - loss: 11.3448
Accuracy (train) : 0.6880 | Accuracy (val) : 0.6166
Époque 21 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 108ms/step - accuracy: 0.6169 - loss: 11.5197
Accuracy (train) : 0.6907 | Accuracy (val) : 0.6062
Époque 22 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 103ms/step - accuracy: 0.6402 - loss: 11.6961
Accuracy (train) : 0.6976 | Accuracy (val) : 0.6312
Époque 23 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 101ms/step - accuracy: 0.6259 - loss: 11.8716
Accuracy (train) : 0.7020 | Accuracy (val) : 0.6200
Époque 24 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 105ms/step - accuracy: 0.5935 - loss: 12.0298
Accuracy (train) : 0.7047 | Accuracy (val) : 0.5862
Époque 25 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 100ms/step - accuracy: 0.6182 - loss: 12.1864
Accuracy (train) : 0.7163 | Accuracy (val) : 0.6096
Époque 26 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 98ms/step - accuracy: 0.6244 - loss: 12.3447
Accuracy (train) : 0.7191 | Accuracy (val) : 0.6124
Époque 27 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.6301 - loss: 12.4946
Accuracy (train) : 0.7268 | Accuracy (val) : 0.6310
Époque 28 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.6430 - loss: 12.6510
Accuracy (train) : 0.7247 | Accuracy (val) : 0.6322
Époque 29 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 98ms/step - accuracy: 0.6496 - loss: 12.8055
Accuracy (train) : 0.7348 | Accuracy (val) : 0.6430
Époque 30 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 100ms/step - accuracy: 0.6572 - loss: 12.9665
Accuracy (train) : 0.7352 | Accuracy (val) : 0.6458
Époque 31 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 102ms/step - accuracy: 0.6310 - loss: 13.1204
Accuracy (train) : 0.7410 | Accuracy (val) : 0.6204
Époque 32 / 100


  0%|          | 0/450 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
distillation_resnet18(3,0.25)

In [None]:
distillation_resnet18(8,0.25)

In [None]:
distillation_resnet18(1,0.5)

In [None]:
distillation_resnet18(3,0.5)

In [None]:
distillation_resnet18(8,0.5)