In [36]:
import os
import tensorflow as tf
import keras as keras
from keras import layers
from keras import ops
import numpy as np
from tensorflow.keras.models import Model
import pandas as pd
import sys

from vit_keras import vit

In [37]:
image_size = (512, 512)
batch_size = 2
# Gradient accumulation: the optimizer applies one update every
# `accumulation_steps` micro-batches, simulating an effective batch size of
# batch_size * accumulation_steps = 16 (assumes batch_size divides 16).
accumulation_steps = 16 // batch_size
input_shape = image_size + (3,)
learning_rate = 2e-4
epochs = 25
# Distillation hyperparameters forwarded to Distiller.compile(); their exact
# meaning is defined by the external `knowledge_distillation` module.
alpha = 0.75
beta = 0.1
temperature = 3.0
# In the cells shown this is only recorded in the run note written to info.txt.
spatial_alignment_layers = 2
seed = 1337

# Seed Python's `random`, NumPy and the TensorFlow backend so that head
# initialisation and other stochastic steps repeat across Restart & Run All.
keras.utils.set_random_seed(seed)

Data Loader

In [None]:
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.utils import to_categorical

def load_data(image_size, batch_size,
              data_dir='/mnt/c/Users/Ann Clarisse Salazar/Documents/project/train_zscore',
              labels_csv='/mnt/c/Users/Ann Clarisse Salazar/Documents/project/data/train_labels.csv',
              random_seed=None):
    """Create the training and validation image iterators.

    Args:
        image_size: (height, width) that every image is resized to.
        batch_size: Number of images per batch.
        data_dir: Directory holding the ``<image>.jpeg`` files. Defaults to the
            original hard-coded location, so existing calls are unaffected.
        labels_csv: CSV with an ``image`` column (file stem, no extension) and
            a ``level`` column (class label).
        random_seed: Seed passed to both iterators; ``None`` falls back to the
            notebook-level ``seed`` from the configuration cell.

    Returns:
        ``(train_ds, val_ds)``: a shuffled, augmented iterator over the
        ``'training'`` subset and an unshuffled, rescale-only iterator over the
        ``'validation'`` subset (``validation_split=0.2``).
    """
    # Resolved at call time so later changes to the global `seed` are honoured.
    shuffle_seed = seed if random_seed is None else random_seed

    labels_df = pd.read_csv(labels_csv)
    labels_df['image'] = labels_df['image'].apply(lambda x: f"{data_dir}/{x}.jpeg")
    # flow_from_dataframe's categorical mode expects string class labels.
    labels_df['level'] = labels_df['level'].astype(str)

    # Both generators use the same validation_split on the same dataframe so
    # that the 'training' and 'validation' subsets come from one partition.
    train_data_augmentation = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        rotation_range=360,  # Random rotation by any angle (up to 360 degrees)
        width_shift_range=0.2,  # Random horizontal shift
        height_shift_range=0.2,  # Random vertical shift
        zoom_range=[0.87, 1.15],
        brightness_range=[0.8, 1.2],
        horizontal_flip=True,  # Random horizontal flip
        vertical_flip=True,  # Random vertical flip
        validation_split=0.2,
        fill_mode='constant'  # fill exposed borders with a constant instead of stretching edges
    )

    # Validation images are only rescaled, never augmented.
    val_data_augmentation = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2,
    )

    # Arguments shared by both iterators; only subset/shuffle differ.
    flow_kwargs = dict(
        dataframe=labels_df,
        x_col='image',
        y_col='level',
        target_size=image_size,
        batch_size=batch_size,
        seed=shuffle_seed,
    )

    train_ds = train_data_augmentation.flow_from_dataframe(
        subset='training', shuffle=True, **flow_kwargs)
    # Unshuffled so validation order is stable between epochs.
    val_ds = val_data_augmentation.flow_from_dataframe(
        subset='validation', shuffle=False, **flow_kwargs)

    return train_ds, val_ds

# Build the augmented training iterator and the plain validation iterator
train_ds, val_ds = load_data(image_size=image_size, batch_size=batch_size)

Model Construction (ViT-B/16 student and frozen ResNet50V2 teacher)

In [None]:
# Student: ViT-B/16 with ImageNet-21k + ImageNet-2012 backbone weights and a new
# (non-pretrained) classification head. The size is passed as a single int, as
# before, so the configured image_size must be square.
assert image_size[0] == image_size[1], "vit_b16 is built here for square images only"
num_classes = 5
student = vit.vit_b16(
    weights='imagenet21k+imagenet2012',
    image_size=image_size[0],
    pretrained=True,
    pretrained_top=False,
    classes=num_classes,
)

data_dir = '/mnt/c/Users/Ann Clarisse Salazar/Documents/project'
model_folder = "FinalModels/teacher_resnet"
config_json = f'{data_dir}/{model_folder}/config.json'  # not read in this cell
model_path = f'{data_dir}/{model_folder}/model.weights.h5'

# Teacher: ResNet50V2 backbone -> global average pooling -> Dense(num_classes)
# with no activation, so the teacher outputs raw logits. This topology must
# match the model the weights in `model_path` were saved from.
teacher_backbone = tf.keras.applications.ResNet50V2(
    include_top=False,   # headless backbone; `classes` would be ignored here
    weights=None,        # weights are loaded from model_path below
    input_shape=input_shape,
)
x = teacher_backbone.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes)(x)

teacher = Model(teacher_backbone.input, x)
teacher.load_weights(model_path)
# Frozen so its weights are not updated during distillation.
teacher.trainable = False

In [None]:
from knowledge_distillation import Distiller

# The teacher is frozen; compiling it is presumably only needed if it is
# evaluated on its own. Its final Dense layer has no activation, so its outputs
# are logits and the loss must be told so.
teacher.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Batches per epoch taken from the iterator rather than a hard-coded count
# (previously 14051), so the decay length follows changes to batch_size or to
# the dataset. Assumes the schedule is stepped once per batch, as the original
# per-epoch constant implied.
steps_per_epoch = len(train_ds)
total_steps = epochs * steps_per_epoch

# Cosine decay from `learning_rate` down to alpha * learning_rate over training.
cosine_decay_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=learning_rate, decay_steps=total_steps, alpha=2e-6
)
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=cosine_decay_fn, gradient_accumulation_steps=accumulation_steps),
    metrics=[
        'accuracy',
    ],
    # NOTE(review): the teacher emits logits and vit_b16's head is presumably
    # linear as well; confirm the Distiller applies softmax before this loss,
    # otherwise it should be CategoricalCrossentropy(from_logits=True).
    student_loss_fn=keras.losses.CategoricalCrossentropy(),
    logit_loss_fn=keras.losses.KLDivergence(),
    feature_loss_fn=keras.losses.CosineSimilarity(),
    alpha=alpha,
    beta=beta,
    temperature=temperature,
)

In [None]:
import datetime
import io
import json
import os
import shutil
import sys
from tensorflow.keras.callbacks import Callback, BackupAndRestore

# Unique run identifier: fixed prefix + distiller name + launch timestamp
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model_id = f"GAMEKD{distiller.name}_{run_stamp}"
print(f"Model ID: {model_id}")

# TensorBoard logs for this run go under records_kd/<model_id>/logs
log_dir = f"/mnt/c/Users/Ann Clarisse Salazar/Documents/project/records_kd/{model_id}/logs"
os.makedirs(log_dir, exist_ok=True)

In [44]:
total_run_time = None

# Resume support: reuse the epoch log lines saved under this model_id, if any.
# NOTE(review): model_id embeds the launch timestamp, so a fresh kernel gets a
# new file name and nothing is found to resume.
log_file = f'resume_epoch_logs_{model_id}.json'
array = []
if os.path.exists(log_file):
    with open(log_file, 'r') as f:
        array = json.load(f)
    print(array)
        
# Callback that records every epoch's metrics (not only the last one) and the
# cumulative wall-clock training time, persisting both as JSON.
class FinalEpochLogger(Callback):
    """Log per-epoch metrics and cumulative training time to JSON files.

    Args:
        log_file: JSON file that receives the list of epoch log lines. The
            elapsed time is stored alongside it in ``*_elapsed_time.json``.

    Attributes:
        total_time: Training wall-clock seconds, starting from the value saved
            in the elapsed-time file (0 if that file does not exist).

    Epoch lines are appended to the module-level ``array`` list (initialised
    earlier in this cell), and that whole list is written to ``log_file``.
    """

    def __init__(self, log_file=f'resume_epoch_logs_{model_id}.json'):
        super().__init__()
        self.log_file = log_file
        self.total_time = self.load_elapsed_time()
        # Set in on_train_begin; None means timing has not started yet.
        self.time_start = None

    def on_train_begin(self, logs=None):
        self.time_start = datetime.datetime.now()

    def on_train_end(self, logs=None):
        self.update_elapsed_time()
        self.save_elapsed_time()

    def on_train_batch_end(self, batch, logs=None):
        # Fold the time since the last checkpoint into total_time and restart the clock.
        self.update_elapsed_time()
        self.time_start = datetime.datetime.now()

    def update_elapsed_time(self):
        """Add the seconds elapsed since ``time_start`` to ``total_time``."""
        if self.time_start is None:  # fit() has not started; nothing to add
            return
        self.time_end = datetime.datetime.now()
        elapsed_time = (self.time_end - self.time_start).total_seconds()
        self.total_time += elapsed_time

    def _elapsed_time_file(self):
        """Path of the JSON file that stores ``total_time``."""
        return self.log_file.replace('.json', '_elapsed_time.json')

    def save_elapsed_time(self):
        with open(self._elapsed_time_file(), 'w') as f:
            json.dump(self.total_time, f)

    def load_elapsed_time(self):
        elapsed_time_file = self._elapsed_time_file()
        if os.path.exists(elapsed_time_file):
            with open(elapsed_time_file, 'r') as f:
                return json.load(f)
        return 0

    def on_epoch_end(self, epoch, logs=None):
        # `array` is the module-level epoch log; rewrite the file after each epoch.
        array.append(f"Epoch {epoch+1}: {logs}\n")
        self.save_logs()
        # Persist the running time every epoch as well, so an interrupted fit()
        # keeps the time accumulated so far instead of saving only at train end.
        self.update_elapsed_time()
        self.time_start = datetime.datetime.now()
        self.save_elapsed_time()

    def save_logs(self):
        with open(self.log_file, 'w') as f:
            json.dump(array, f)

    def load_logs(self):
        """Return the epoch log list stored in ``log_file``, or [] if absent."""
        if os.path.exists(self.log_file):
            with open(self.log_file, 'r') as f:
                return json.load(f)
        return []
backup_dir = './backup'
# Training state is checkpointed into backup_dir so an interrupted fit() can resume.
backup_callback = BackupAndRestore(backup_dir=backup_dir)
final_epoch_logger = FinalEpochLogger()

Fitting

In [None]:
distiller.build(input_shape)

fit_callbacks = [
    # Stop after 5 epochs without val_accuracy improvement; roll back to the best weights
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=5, restore_best_weights=True, mode='max'),
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
    final_epoch_logger,
    backup_callback,
]

distiller.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=fit_callbacks,
)

In [None]:
# Total training wall-clock seconds accumulated by the FinalEpochLogger callback.
total_runtime_seconds = final_epoch_logger.total_time
# Round to whole seconds so the string is H:MM:SS (timedelta would otherwise
# append microseconds); runs longer than a day render as "N day(s), H:MM:SS".
total_runtime = str(datetime.timedelta(seconds=round(total_runtime_seconds)))
# The formatted string already encodes its units, so no "hours" suffix.
print(f"Total runtime: {total_runtime}")
print(log_dir)

# Save the trained student under records_kd/<model_id>/
run_dir = f"/mnt/c/Users/Ann Clarisse Salazar/Documents/project/records_kd/{model_id}"
os.makedirs(run_dir, exist_ok=True)
distiller.student.save(f'{run_dir}/model_{model_id}.keras')

In [None]:
import contextlib
import io

note = f"NOTE: cam_kd, {spatial_alignment_layers} spatial alignment layers, limited set, {beta} beta, use base vit with pretraining, no cams, base KD"

# Capture the run report in memory. redirect_stdout restores the previous
# sys.stdout on exit, even if an exception is raised. Do not reset with
# `sys.stdout = sys.__stdout__`: in Jupyter that is the kernel process's own
# stdout, not the notebook's output stream, so later prints would vanish.
output_capture = io.StringIO()
with contextlib.redirect_stdout(output_capture):
    print(note + "\n\n")

    # print training info
    for element in array:
        print(element)

    print("\n\n")

    # print model info
    print("HYPERPARAMETERS")
    print(f"Model Name: {distiller.name}")
    print(f"Epochs: {epochs}")
    print(f"Batch Size: {batch_size}")
    print(f"Image Size: {image_size}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Alpha: {alpha}")
    print(f"Beta: {beta}")
    print(f"Temperature: {temperature}")
    print(f"Total runtime: {total_runtime}")
    print(f"Tensorflow Version: {tf.__version__}")
    print(f"Keras Version: {keras.__version__}")

captured_output = output_capture.getvalue()

# Write the report next to the saved model: records_kd/<model_id>/info.txt
file_path = f'/mnt/c/Users/Ann Clarisse Salazar/Documents/project/records_kd/{model_id}/info.txt'
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as file:
    file.write(captured_output)