In [1]:
!pip install matplotlib image_classifiers tqdm
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.applications import ResNet50V2
from keras.datasets import cifar100
from keras import Sequential, Input
from keras.layers import Dense, UpSampling2D, Dropout, RandomFlip, RandomTranslation, RandomRotation,RandomBrightness, RandomContrast, RandomZoom, GlobalAveragePooling2D
from keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from keras.applications.resnet_v2 import preprocess_input
from keras.models import Model
from classification_models.keras import Classifiers
from keras.optimizers import Adam
from keras.activations import linear
from tqdm.notebook import tqdm
from keras.utils import Progbar
from tensorflow.nn import softmax_cross_entropy_with_logits
import os



2024-04-13 18:13:39.974127: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
n_epoch = 40
batch_size = 64
taux_validation = 0.1
num_classes = 100
n_images = 50000 # Pour l'entrainement, et 10000 pour le test

In [3]:
!mc cp s3/afeldmann/projet_cnam/modele_enseignant.keras /home/onyxia/work/projet_distillation_cnam/sauvegardes/modele_enseignant.keras
model_enseignant = Sequential([
    Input((224,224,3)),
    ResNet50V2(include_top=False, weights='imagenet', pooling="avg"),
    Dropout(0.25),
    Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
    Dropout(0.5),
    Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001))
])
# Keras 3.0 est buggé et le chargement direct ne marche pas ici, même si les poids sont bien enregistrés
model_enseignant.load_weights("/home/onyxia/work/projet_distillation_cnam/sauvegardes/modele_enseignant.keras")

model_enseignant.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

...nant.keras: 135.73 MiB / 135.73 MiB ┃▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓┃ 43.06 MiB/s 3s[0;22m[0m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m[m[32;1m

2024-04-13 18:13:51.011935: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13775 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:3b:00.0, compute capability: 7.5


In [4]:
def preprocessing(image, label):
    image = tf.image.resize(image, (224, 224))
    label = tf.squeeze(tf.one_hot(label, depth = num_classes), axis = 0)
    return  image, label

augmentation_donnees_keras = Sequential([
    RandomFlip("horizontal"),
    RandomTranslation(0.2,0.2),
    RandomRotation(0.2),
    RandomZoom(0.2),
    RandomContrast(0.2),
    RandomBrightness(0.2,value_range=(0,1))
])

def augmentation_donnees(image, label):
    return augmentation_donnees_keras(image/255.0, training = True)*255.0, label

def preprocess_resnet(image, label):
    return preprocess_input(image), label

def train_val_split(train_dataset, validation_size):
    X_train, y_train = train_dataset
    indices = np.random.permutation(X_train.shape[0])
    train_idx, val_idx = indices[:train_size], indices[train_size:]
    return (X_train[train_idx,...], y_train[train_idx,...]), (X_train[val_idx,...], y_train[val_idx,...])

train_dataset, test_dataset = cifar100.load_data()

validation_size = int(n_images * taux_validation)
train_size = n_images - validation_size

train_dataset, validation_dataset = train_val_split(train_dataset, validation_size)

validation_dataset = tf.data.Dataset.from_tensor_slices(validation_dataset).map(preprocessing).batch(batch_size).map(preprocess_resnet).cache().prefetch(tf.data.AUTOTUNE)
train_dataset = tf.data.Dataset.from_tensor_slices(train_dataset).map(preprocessing).cache().repeat().shuffle(train_size).batch(batch_size).map(augmentation_donnees, num_parallel_calls = tf.data.AUTOTUNE).map(preprocess_resnet, num_parallel_calls = tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_dataset).map(preprocessing).batch(batch_size).map(preprocess_resnet, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

In [11]:
def get_modele_logits(modele):
    config = modele.layers[-1].get_config()
    config['activation'] = linear
    config['name'] = 'logits'
    res = Model(inputs=modele.inputs, outputs=[Dense(**config)(modele.layers[-2].output)])
    res.layers[-1].set_weights([x.numpy() for x in modele.layers[-1].weights])
    res.compile(metrics=['accuracy'])
    return res

@tf.function
def compte_bons(x,y):
    return tf.reduce_sum(tf.cast(tf.equal(tf.argmax(x, axis = 1), tf.argmax(y, axis = 1)), tf.float32))

@tf.function
def softmax(logits, temp):
    expo = tf.exp(logits / temp)
    return expo / tf.reduce_sum(expo, axis = 1, keepdims=True)

@tf.function
def ce(x, y_logits, temp):
    return softmax_cross_entropy_with_logits(x, y_logits / temp) * temp**2

def init_csv_log(fichier):
    with open(fichier,'w') as file:
        file.write("epoch, accuracy,val_accuracy\n")
def append_csv_log(fichier, epoch, accuracy,val_accuracy):
    with open(fichier,'a') as file:
        file.write(f"{epoch:d},{accuracy:.2f},{val_accuracy:.2f}\n")

@tf.function
def forward_backward_pass(train_dataset_iter, etudiant_logit_model, alpha, temp, adam):
    X_batch, y_batch, enseignant_estim_softmax = next(train_dataset_iter)
    with tf.GradientTape() as tape:
        etudiant_estim_logit = etudiant_logit_model(X_batch, training = True)
        perte = alpha * softmax_cross_entropy_with_logits(y_batch,etudiant_estim_logit) + (1-alpha) * ce(enseignant_estim_softmax,etudiant_estim_logit, temp)
    grads = tape.gradient(perte, etudiant_logit_model.trainable_variables)
    adam.apply_gradients(zip(grads, etudiant_logit_model.trainable_variables))
    return compte_bons(etudiant_estim_logit,y_batch)

def distillateur_kl(etudiant, enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch, alpha):
    etudiant_logit_model = get_modele_logits(etudiant)
    enseignant_logit_model = get_modele_logits(enseignant)
    adam = Adam(learning_rate=0.001)
    init_csv_log(f"sauvegardes/{nom_modele}_logs.csv")
    print("C'est parti pour la distillation !\n")
    val_accuracy_max = 0
    train_dataset_iter = iter(
        train_dataset
        .map(lambda images, label: (images, label, softmax(enseignant_logit_model(images, training = False), temp)), num_parallel_calls = tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )
    for epoch in range(n_epoch):
        print(f"Époque {epoch + 1} / {n_epoch}")
        n_batch = train_size//batch_size
        barre_progression = Progbar(n_batch, stateful_metrics = ["acc"])
        bons_epoque = 0
        for i in range(n_batch):
            bons_epoque += forward_backward_pass(train_dataset_iter, etudiant_logit_model, alpha, temp, adam).numpy()
            accuracy = bons_epoque / ((i+1) * batch_size)
            barre_progression.update(i, values = [("acc",accuracy)])
        _, val_accuracy = etudiant.evaluate(validation_dataset, verbose = 0)
        if val_accuracy > val_accuracy_max:
            val_accuracy_max = val_accuracy
            etudiant.save(f"sauvegardes/{nom_modele}_checkpoint.keras")
        append_csv_log(f"sauvegardes/{nom_modele}_logs.csv", epoch, accuracy, val_accuracy)
        print(f"Accuracy (train) : {accuracy:.4f} | Accuracy (val) : {val_accuracy:.4f}")

In [12]:
def ResNet18():
    resnet18, preprocess_input = Classifiers.get('resnet18')
    resnet = resnet18((224, 224, 3), weights='imagenet', include_top=False)
    resnet_output = GlobalAveragePooling2D()(resnet.output)
    resnet = Model(inputs=resnet.input, outputs=resnet_output)
    return resnet

def new_modele_resnet():
    model = Sequential([
        Input((224,224,3)),
        ResNet18(),
        Dropout(0.25),
        Dense(256, activation="sigmoid", kernel_regularizer = tf.keras.regularizers.L1(0.001)),
        Dropout(0.5),
        Dense(num_classes, activation="softmax", kernel_regularizer = tf.keras.regularizers.L2(0.001))
    ])
    model.compile(metrics=['accuracy'])
    return model

In [13]:
def distillation_resnet18(temp, alpha):
    tf.keras.backend.clear_session()
    modele = new_modele_resnet()
    nom_modele =  f"model_etudiant_t{temp:d}_a{int(alpha*100):d}"
    distillateur_kl(modele, model_enseignant, train_dataset, validation_dataset, temp, nom_modele, n_epoch,0.25)
    wd = os.getcwd()
    os.system(f"cp {wd}/sauvegardes/{nom_modele}_checkpoint.keras {wd}/sauvegardes/{nom_modele}.keras")
    os.system(f"mc cp {wd}/sauvegardes/{nom_modele}.keras s3/afeldmann/projet_cnam/{nom_modele}.keras")
    os.system(f"mc cp {wd}/sauvegardes/{nom_modele}_logs.csv s3/afeldmann/projet_cnam/{nom_modele}_logs.csv")

In [14]:
distillation_resnet18(1,1) # témoin

C'est parti pour la distillation !

Époque 1 / 40


2024-04-13 18:17:32.596543: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8900


[1m196/703[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m5:28[0m 648ms/step - acc: 0.0567

KeyboardInterrupt: 

In [None]:
distillation_resnet18(1,0.25)

C'est parti pour la distillation !

Époque 1 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

2024-04-13 07:51:40.492310: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8900
I0000 00:00:1712995075.180074   24606 service.cc:145] XLA service 0x7f879400bd60 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1712995075.181323   24606 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
2024-04-13 07:57:55.339686: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1712995081.704745   24606 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Accuracy (train) : 0.1365 | Accuracy (val) : 0.0680
Époque 2 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.2850 | Accuracy (val) : 0.3194
Époque 3 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.3601 | Accuracy (val) : 0.3578
Époque 4 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4172 | Accuracy (val) : 0.4044
Époque 5 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4514 | Accuracy (val) : 0.4178
Époque 6 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4776 | Accuracy (val) : 0.4166
Époque 7 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4993 | Accuracy (val) : 0.4724
Époque 8 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5265 | Accuracy (val) : 0.5054
Époque 9 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5394 | Accuracy (val) : 0.4740
Époque 10 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5544 | Accuracy (val) : 0.4984
Époque 11 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5673 | Accuracy (val) : 0.5084
Époque 12 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5739 | Accuracy (val) : 0.5508
Époque 13 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5925 | Accuracy (val) : 0.5366
Époque 14 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6020 | Accuracy (val) : 0.5122
Époque 15 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6131 | Accuracy (val) : 0.5500
Époque 16 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6221 | Accuracy (val) : 0.5508
Époque 17 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6325 | Accuracy (val) : 0.5384
Époque 18 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

In [9]:
distillation_resnet18(3,0.25)

C'est parti pour la distillation !

Époque 1 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.1568 | Accuracy (val) : 0.1498
Époque 2 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.3091 | Accuracy (val) : 0.3588
Époque 3 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.3872 | Accuracy (val) : 0.3830
Époque 4 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4420 | Accuracy (val) : 0.4290
Époque 5 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4730 | Accuracy (val) : 0.4632
Époque 6 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4961 | Accuracy (val) : 0.4596
Époque 7 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5237 | Accuracy (val) : 0.4908
Époque 8 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5483 | Accuracy (val) : 0.4782
Époque 9 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5606 | Accuracy (val) : 0.5162
Époque 10 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5756 | Accuracy (val) : 0.5044
Époque 11 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.5877 | Accuracy (val) : 0.5406
Époque 12 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6017 | Accuracy (val) : 0.5138
Époque 13 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6099 | Accuracy (val) : 0.5474
Époque 14 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6219 | Accuracy (val) : 0.5526
Époque 15 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6306 | Accuracy (val) : 0.5468
Époque 16 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6405 | Accuracy (val) : 0.5604
Époque 17 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6462 | Accuracy (val) : 0.5458
Époque 18 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6537 | Accuracy (val) : 0.5760
Époque 19 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6628 | Accuracy (val) : 0.5510
Époque 20 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6687 | Accuracy (val) : 0.5654
Époque 21 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6765 | Accuracy (val) : 0.5712
Époque 22 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6797 | Accuracy (val) : 0.5874
Époque 23 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6879 | Accuracy (val) : 0.5782
Époque 24 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6938 | Accuracy (val) : 0.5796
Époque 25 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.6993 | Accuracy (val) : 0.5828
Époque 26 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7029 | Accuracy (val) : 0.6058
Époque 27 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7064 | Accuracy (val) : 0.5962
Époque 28 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7109 | Accuracy (val) : 0.6032
Époque 29 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7169 | Accuracy (val) : 0.5948
Époque 30 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7234 | Accuracy (val) : 0.5918
Époque 31 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7277 | Accuracy (val) : 0.5946
Époque 32 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.7324 | Accuracy (val) : 0.5914
Époque 33 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

`/home/onyxia/work/projet_distillation_cnam/sauvegardes/model_etudiant_t3_a25.keras` -> `s3/afeldmann/projet_cnam/model_etudiant_t3_a25.keras`
Total: 43.54 MiB, Transferred: 43.54 MiB, Speed: 32.19 MiB/s
`/home/onyxia/work/projet_distillation_cnam/sauvegardes/model_etudiant_t3_a25_logs.csv` -> `s3/afeldmann/projet_cnam/model_etudiant_t3_a25_logs.csv`
Total: 435 B, Transferred: 435 B, Speed: 2.45 KiB/s


In [10]:
distillation_resnet18(8,0.25)

C'est parti pour la distillation !

Époque 1 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.1656 | Accuracy (val) : 0.1028
Époque 2 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.3126 | Accuracy (val) : 0.3260
Époque 3 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.3862 | Accuracy (val) : 0.3714
Époque 4 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

Accuracy (train) : 0.4367 | Accuracy (val) : 0.3942
Époque 5 / 100


  0%|          | 0/703 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
distillation_resnet18(1,0.5)

In [None]:
distillation_resnet18(3,0.5)

In [None]:
distillation_resnet18(8,0.5)

In [5]:
def graphiques_accuracy(temp, alpha):
    nom_modele =  f"model_etudiant_t{temp:d}_a{int(alpha*100):d}"
    history=np.genfromtxt(f"sauvegardes/{nom_modele}_logs.csv", delimiter=",", names = True)
    plt.plot(history['accuracy'])
    plt.plot(history['val_accuracy'])
    plt.title('Modèle enseignant')
    plt.ylabel('Exactitude')
    plt.xlabel('Époque')
    plt.axvline(x=47, color='purple', ls='--', lw=2, label='Limite réglage fin')
    plt.legend(['Entrainement', 'Validation','Limite réglage fin'], loc='best')
    plt.show()