In [None]:
# !pip install tensorflow-model-optimization
# !pip install kornia
# !pip install hls4ml

In [4]:
# Dataset locations, relative to the notebook's working directory.
PATH_TO_TRAIN_DIR, PATH_TO_TEST_DIR = "./dataset/train", "./dataset/test"
# The SpineNet checkpoint directory is configured on the first line of
# weights.py, not here -- change it there if needed:
# PATH_TO_CKPT_DIR = "./models/SpineNet-49S"
PATH_TO_MOD_DIR = "./models"  # where the .h5 model files are saved / loaded
# NOTE(review): ALPHA is not referenced by any visible cell; presumably it was
# meant to be passed as the Distiller's `alpha`.
ALPHA = 0.5

In [1]:
import os
import tensorflow as tf
import keras
# import tensorflow.compat.v1 as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Flatten, Dense, UpSampling2D, BatchNormalization, Activation, Add, Concatenate
from tensorflow.keras.models import Model, load_model
import numpy as np

# import tensorflow_model_optimization as tfmot
# Reset Keras' global state so re-running the notebook starts clean.
keras.backend.clear_session()
# import hls4ml

# Put the Vivado binaries on PATH (presumably for hls4ml synthesis, whose
# import is currently commented out). Honour $XILINX_VIVADO when it is set
# and fall back to the previously hard-coded install location otherwise.
VIVADO_DIR = os.environ.get('XILINX_VIVADO', '/tools/Xilinx/Vivado/2019.1')
_vivado_bin = os.path.join(VIVADO_DIR, 'bin')
_path = os.environ.get('PATH', '')
# Only prepend when it is not already the first entry, so re-running this
# cell does not keep growing PATH. Skip empty parts so an empty PATH does not
# leave a trailing separator (which would put the CWD on PATH).
if _path.split(os.pathsep)[0] != _vivado_bin:
    os.environ['PATH'] = os.pathsep.join(p for p in (_vivado_bin, _path) if p)

import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as transforms
import gc
gc.collect()


from weights import *
from spinenet_functions import *
from student_teacher_functions import *
from spinenet import SpineNet
from student_teacher import Student, Teacher

In [None]:
##FOR VIEWING CKPT FILE ##

# for key,vale in shape.items():
#     # print(key,":",vale)
#     if key.startswith('retinanet'):
#         print(key,":",vale)

## SpineNet

In [3]:
# Build SpineNet-49S from scratch and cache it as HDF5, so later sessions can
# use the "RUN IF SPINENET49S ALREADY SAVED" cell instead of rebuilding.
keras.backend.clear_session()
model = SpineNet()
model.compile()  # no loss/optimizer given: Keras defaults are used
# Writing the .h5 file fails if the target folder does not exist yet.
os.makedirs(PATH_TO_MOD_DIR, exist_ok=True)
model.save(PATH_TO_MOD_DIR+'/SpineNet49S.h5', save_format = 'h5')

before max pool 1 (None, 640, 640, 64)
after max pool 2 (None, 640, 640, 64)
after stem 0 (None, 320, 320, 164)
after stem 1 (None, 320, 320, 164)
after block 2 (None, 320, 320, 164)
after block 3 (None, 80, 80, 166)
after block 4 (None, 160, 160, 332)
after block 5 (None, 80, 80, 664)
after block 6 (None, 20, 20, 166)
after block 7 (None, 80, 80, 664)
after block 8 (None, 40, 40, 166)
after block 9 (None, 10, 10, 166)
after block 10 (None, 40, 40, 664)
after block 11 (None, 40, 40, 664)
after block 12 (None, 80, 80, 664)
output 12 (None, 80, 80, 256)
after block 13 (None, 160, 160, 332)
output 13 (None, 160, 160, 256)
after block 14 (None, 40, 40, 664)
output 14 (None, 40, 40, 256)
after block 15 (None, 10, 10, 664)
output 15 (None, 10, 10, 256)
after block 16 (None, 20, 20, 664)
output 16 (None, 20, 20, 256)


In [None]:
#### RUN IF SPINENET49S ALREADY SAVED ####
# Restore the cached SpineNet-49S instead of rebuilding it.
model = keras.models.load_model(f"{PATH_TO_MOD_DIR}/SpineNet49S.h5")

## STUDENT & TEACHER

In [None]:
# Build freshly constructed Student and Teacher networks and cache each as
# HDF5, so later sessions can use the "ALREADY SAVED" cell instead.
# Writing the .h5 files fails if the target folder does not exist yet.
os.makedirs(PATH_TO_MOD_DIR, exist_ok=True)

keras.backend.clear_session()
student = Student()
student.compile()
student.save(PATH_TO_MOD_DIR+'/Student.h5', save_format = 'h5')

keras.backend.clear_session()
teacher = Teacher()
teacher.compile()
teacher.save(PATH_TO_MOD_DIR+'/Teacher.h5', save_format = 'h5')

In [None]:
#### RUN IF STUDENT+TEACHER ALREADY SAVED ####
# Reload both cached networks (Student first, then Teacher).
student, teacher = (
    keras.models.load_model(f"{PATH_TO_MOD_DIR}/{name}.h5")
    for name in ("Student", "Teacher")
)

## DATA

In [None]:
# Load every readable image under PATH_TO_TRAIN_DIR as an RGB uint8 array
# resized to 640x640. Pixel values stay in 0-255 (no normalisation).
x_train = []

for dirpath, dirnames, filenames in os.walk(PATH_TO_TRAIN_DIR):
    # os.walk yields entries in filesystem order; sort in place so the
    # resulting array order is deterministic across machines.
    dirnames.sort()
    for file in sorted(filenames):
        path = os.path.join(dirpath, file)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for files it cannot decode (e.g. hidden
            # files such as .DS_Store); skip them instead of crashing in
            # cvtColor.
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (640, 640))
        x_train.append(img)

# A wrong path or an empty folder would otherwise produce an empty array and
# only fail later, inside fit().
if not x_train:
    raise ValueError(f"No readable images found under {PATH_TO_TRAIN_DIR!r}")

x_train = np.array(x_train)

In [None]:
# Load every readable image under PATH_TO_TEST_DIR as an RGB uint8 array
# resized to 640x640. Pixel values stay in 0-255 (no normalisation).
x_test = []

for dirpath, dirnames, filenames in os.walk(PATH_TO_TEST_DIR):
    # os.walk yields entries in filesystem order; sort in place so the
    # resulting array order is deterministic across machines.
    dirnames.sort()
    for file in sorted(filenames):
        path = os.path.join(dirpath, file)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for files it cannot decode (e.g. hidden
            # files such as .DS_Store); skip them instead of crashing in
            # cvtColor.
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (640, 640))
        x_test.append(img)

# A wrong path or an empty folder would otherwise produce an empty array and
# only fail later, inside predict().
if not x_test:
    raise ValueError(f"No readable images found under {PATH_TO_TEST_DIR!r}")

x_test = np.array(x_test)

## KNOWLEDGE DISTILLATION

In [None]:
from kornia.losses import ssim_loss
class Distiller(keras.Model):
    """Distil a frozen teacher network into a student network.

    Training is label-free: ``fit`` is called with inputs only. The loss
    compares the student's output with the teacher's output on the same batch
    through ``distillation_loss_fn``. Inference (``call``) runs only the
    student.

    Args:
        student: Keras model being trained.
        teacher: Keras model whose outputs are the targets. The constructor
            freezes it (``trainable = False``).
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student
        # Both networks are tracked as sub-layers of this model. Without this
        # the teacher's weights would be in ``self.trainable_variables``, and
        # Keras' default train_step would update them along with the
        # student's. Freeze the teacher so only the student learns.
        # NOTE: this mutates the caller's teacher model.
        self.teacher.trainable = False

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure training.

        Args:
            optimizer: Keras optimizer (only the student's weights are
                trainable).
            metrics: Passed through to ``keras.Model.compile``. ``fit`` gets
                no targets here, so ``None`` is usually appropriate.
            student_loss_fn: Stored, but not used by ``compute_loss`` (there
                are no ground-truth labels).
            distillation_loss_fn: Callable ``(teacher_pred, student_pred,
                max_val=...)`` returning a per-sample *similarity*, where 1.0
                means identical (e.g. ``tf.image.ssim``).
            alpha: Stored, but currently unused by ``compute_loss``.
            temperature: Stored, but currently unused by ``compute_loss``.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        """Return the scalar distillation loss ``1 - mean(similarity)``.

        ``y``, ``sample_weight`` and ``allow_empty`` are accepted for
        compatibility with Keras' ``compute_loss`` signature and are ignored.
        The teacher's prediction on ``x`` is the target for ``y_pred``.
        """
        teacher_pred = self.teacher(x, training=False)
        # distillation_loss_fn (tf.image.ssim in this notebook) yields one
        # similarity per sample, where higher is better. Minimising the
        # similarity itself would push the student *away* from the teacher,
        # so optimise 1 - similarity. Average over the batch so Keras
        # receives a scalar.
        # max_val=255 assumes the model outputs span 0-255 -- TODO confirm
        # for the Student/Teacher output heads.
        similarity = self.distillation_loss_fn(
            teacher_pred,
            y_pred,
            max_val=255,
        )
        return 1.0 - tf.reduce_mean(similarity)

    def call(self, x):
        """Run the student only; the teacher is used just for the loss."""
        return self.student(x)

In [None]:
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    # No metrics: fit() receives no targets (y is None), so target-based
    # metrics such as 'mse' cannot be computed. Keras 3 also rejects a bare
    # string here (metrics must be a list, tuple or dict).
    metrics=None,
    # Required by Distiller.compile but not used by its compute_loss.
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.image.ssim
)

In [None]:
# Label-free distillation: only inputs are given; targets come from the teacher.
distiller.fit(x=x_train, batch_size=1, epochs=10)

In [None]:
# Persist only the distilled student (the teacher is unchanged).
distiller.student.save(f"{PATH_TO_MOD_DIR}/Student_trained.h5", save_format='h5')

In [None]:
#### RUN IF STUDENT IS ALREADY TRAINED ####
# Reload the distilled student written by the previous cell.
student_trained = keras.models.load_model(f"{PATH_TO_MOD_DIR}/Student_trained.h5")

In [None]:
# Reference (teacher) and distilled (student) outputs on the held-out images;
# the teacher is evaluated first.
y_true, y_pred = teacher.predict(x_test), student_trained.predict(x_test)

In [None]:
### tf.image.ssim returns one SSIM *similarity* score per test image:     ###
### 1.0 means the student's output matches the teacher's; higher is better. ###
### These are not losses -- the matching per-image loss would be 1 - SSIM.  ###
### max_val=255 assumes the model outputs span 0-255 (TODO confirm).        ###
loss = tf.image.ssim(y_true, y_pred, max_val=255)
loss