<a href="https://colab.research.google.com/github/NaHyeonMaeng/CODE_Practice/blob/main/Knowledge_Distillation_(%EC%A7%80%EC%8B%9D_%EC%A6%9D%EB%A5%98).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#라이브러리 불러오기
import tensorflow as tf
import numpy as np
from google.colab.patches import cv2_imshow
from tqdm import tqdm

In [2]:
# Main hyperparameter settings.
# The "#@param" / "#@title" trailers are Colab form annotations; they are kept
# verbatim (including the title text) so the rendered form is unchanged.

#@title 파라미터 설정
# Number of epochs passed to the teacher model's fit() call.
# NOTE: "ephoc" is a misspelling of "epoch"; the name is kept because other
# cells reference it.
t_ephoc = 10 #@param {type:"slider", min:1, max:100, step:1}
# Number of outer-loop passes in the custom student training loop.
s_ephoc = 5 #@param {type:"slider", min:1, max:100, step:1}
# Learning rate handed to every Adam optimizer created in this notebook.
learning_rate = 0.01
# Mini-batch size for the teacher's fit() and for the student's random batches.
batch_size = 64 #@param [32, 64, 128, 256] {type:"raw"}
# Divisor applied to teacher and student logits before the softmax that feeds
# the distillation loss.
temperature = 3 #@param {type:"slider", min:1, max:10, step:1}
# Weight of the hard-label student loss; the distillation loss is weighted
# by (1 - alpha).
alpha = 0.5 #@param {type:"slider", min:0.1, max:0.9, step:0.1}

In [3]:
# Load the MNIST dataset (train/test images together with their labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Cast the images to float32, divide by 255, and add a trailing channel axis so
# every sample has shape (28, 28, 1), matching the models' Input layers.
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


모델 정의

In [4]:
# Teacher model: a convolutional classifier built with the functional API.
i = tf.keras.Input(shape=(28, 28, 1))

# Layers are instantiated in the order they are applied, so Keras assigns the
# same auto-generated layer names as a step-by-step construction would.
teacher_layers = [
    tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
    tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),  # no activation: the model emits raw logits
]

# Thread the input tensor through each layer in turn.
out = i
for layer in teacher_layers:
    out = layer(out)

t_model = tf.keras.Model(inputs=[i], outputs=[out])

t_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 256)       2560      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 256)       0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 256)      0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 512)         1180160   
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                             

In [5]:
# Student model: flatten the image, then two dense layers (no activations),
# ending in 10 raw logits.
i = tf.keras.Input(shape=(28, 28, 1))

student_layers = (
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(28),
    tf.keras.layers.Dense(10),
)

# Apply the layers to the input in order.
out = i
for layer in student_layers:
    out = layer(out)

s_model_1 = tf.keras.Model(inputs=[i], outputs=[out])
# Clone the architecture so a second, structurally identical student can be
# trained without distillation for comparison.
s_model_2 = tf.keras.models.clone_model(s_model_1)

s_model_1.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 flatten_1 (Flatten)         (None, 784)               0         
                                                                 
 dense_1 (Dense)             (None, 28)                21980     
                                                                 
 dense_2 (Dense)             (None, 10)                290       
                                                                 
Total params: 22,270
Trainable params: 22,270
Non-trainable params: 0
_________________________________________________________________


모델 컴파일

In [6]:
# Compile all three models with identical settings:
#   t_model   - teacher
#   s_model_1 - student that will be trained with distillation
#   s_model_2 - comparison student trained without distillation
# Each model receives its own optimizer, loss and metric instances.
for model in (t_model, s_model_1, s_model_2):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

모델 훈련

In [7]:
# Train the teacher on the training set for t_ephoc epochs.
t_model.fit(x=x_train, y=y_train, epochs=t_ephoc, batch_size=batch_size)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f769c09ee90>

In [8]:
# Loss objects used by the custom student training loop.

# Distillation loss: KL divergence between two probability distributions.
d_loss = tf.keras.losses.KLDivergence()
# Student loss: cross-entropy between integer class labels and raw logits.
s_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [9]:
# Show the training array's shape: (number of images, 28, 28, 1).
np.shape(x_train)

(60000, 28, 28, 1)

In [10]:
# Student training.
# Two structurally identical students are trained on the same random batches:
#   s_model_1 - hard-label loss blended with a distillation loss computed
#               against the teacher's temperature-softened predictions
#   s_model_2 - hard-label loss only (baseline)
# After each outer pass both students are evaluated on the test set.
batch_count = x_train.shape[0] // batch_size

# One optimizer per student: a single shared Adam instance would advance its
# step counter twice per batch, which skews Adam's bias correction for both models.
opt = tf.keras.optimizers.legacy.Adam(learning_rate)
opt_baseline = tf.keras.optimizers.legacy.Adam(learning_rate)

for e in range(s_ephoc):
    for _ in range(batch_count):
        # Sample a random mini-batch (indices drawn with replacement).
        batch_idx = np.random.randint(0, x_train.shape[0], size=batch_size)
        x_batch = x_train[batch_idx]
        y_batch = y_train[batch_idx]

        # Teacher logits are computed outside any GradientTape, so no gradient
        # flows into the teacher. A direct call replaces predict(), which is
        # meant for large inputs and adds per-call overhead and progress output.
        t_pred = t_model(x_batch, training=False)

        # --- student with distillation ---
        with tf.GradientTape() as tape:
            s_pred_1 = s_model_1(x_batch, training=True)
            student_loss = s_loss(y_batch, s_pred_1)
            # KL divergence between the softened teacher and student distributions.
            # NOTE(review): many distillation recipes also multiply this term by
            # temperature**2 to keep gradient magnitudes comparable; it is left
            # unscaled here to preserve the existing weighting — confirm intent.
            distillation_loss = d_loss(
                tf.nn.softmax(t_pred / temperature, axis=1),
                tf.nn.softmax(s_pred_1 / temperature, axis=1),
            )
            # Weighted average of the student loss and the distillation loss.
            loss = alpha * student_loss + (1 - alpha) * distillation_loss

        distilled_vars = s_model_1.trainable_variables
        grad = tape.gradient(loss, distilled_vars)
        opt.apply_gradients(zip(grad, distilled_vars))

        # --- baseline student (no distillation) ---
        with tf.GradientTape() as tape:
            s_pred_2 = s_model_2(x_batch, training=True)
            student_loss = s_loss(y_batch, s_pred_2)

        baseline_vars = s_model_2.trainable_variables
        grad = tape.gradient(student_loss, baseline_vars)
        opt_baseline.apply_gradients(zip(grad, baseline_vars))

    # Evaluation: evaluate() only reports loss/accuracy. The previous fit() call
    # here trained both students on the test set, contaminating the comparison.
    print("에포크 {}".format(e))
    print("선생님께 배운 경우")
    s_model_1.evaluate(x_test, y_test)
    print("혼자 공부한 경우")
    s_model_2.evaluate(x_test, y_test)
    print("\n")

에포크 0
선생님께 배운 경우
혼자 공부한 경우


에포크 1
선생님께 배운 경우
혼자 공부한 경우


에포크 2
선생님께 배운 경우
혼자 공부한 경우


에포크 3
선생님께 배운 경우
혼자 공부한 경우


에포크 4
선생님께 배운 경우
혼자 공부한 경우


