In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras import regularizers
import random
import copy
import packages.baseline_model as baselibe_model
import packages.common_functions as common_functions
from skimage import io
import cv2
import packages.trans_in_rgb as trans_in_rgb
import matplotlib.pyplot as plt
import packages.CL_model as CL_model
import time
from tqdm import tqdm
from jpeg2dct.numpy import load, loads

# Ask TensorFlow to grow GPU memory on demand for every visible GPU instead of
# reserving (nearly) all device memory up front. TensorFlow requires this to be
# configured before the GPUs are initialised, hence it runs right after the imports.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

In [None]:
class RGBProjectionHead(tf.keras.Model):
    """Projection head used for contrastive pre-training.

    ``inputs`` is a 2-item sequence whose elements are summed element-wise and then
    passed through Dense(512) -> BatchNorm -> ReLU -> Dense(256). In this notebook the
    training loop feeds the encoder feature vector as ``inputs[0]`` and the tiled
    per-subject identity embedding as ``inputs[1]``.

    Note: ``bn2`` is constructed but never applied in ``call``.
    """

    def __init__(self):
        super().__init__()

        def _regularized_dense(units):
            # Both dense layers share the same bias init and L1/L2 penalties.
            return tf.keras.layers.Dense(
                units,
                bias_initializer=tf.keras.initializers.constant(0.01),
                kernel_regularizer=regularizers.l1_l2(l1=1e-5, l2=1e-4),
                bias_regularizer=regularizers.l2(1e-4),
            )

        # Creation order and attribute names are kept so saved weights still map.
        self.dense1 = _regularized_dense(512)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.dense2 = _regularized_dense(256)
        self.bn2 = tf.keras.layers.BatchNormalization()  # unused in call()

    def call(self, inputs, training=None):
        fused = inputs[0] + inputs[1]
        hidden = tf.nn.relu(self.bn1(self.dense1(fused), training=training))
        return self.dense2(hidden)

In [None]:
class Resnet18_BaseEncoder(tf.keras.Model):
    """Two-stream residual encoder over selected JPEG DCT coefficients.

    Despite its name this is a shallow ResNet-style network (four BasicBlocks in
    total), not a full ResNet-18.

    ``inputs`` is a 2-item sequence of channels-last tensors:
        inputs[0]: luma (Y) DCT coefficients.
        inputs[1]: Cb/Cr DCT coefficients at half the Y spatial size, so that the
            stride-2 luma stream and the stride-1 chroma stream line up for the
            channel concatenation.

    Pipeline: per-stream BatchNorm; Y -> 64-filter block (1x1 projection shortcut)
    -> 128-filter stride-2 block; Cb/Cr -> 1x1 conv (128) -> BN -> ReLU; concat
    (256 channels) -> two 512-filter blocks (first one stride 2) -> global average
    pooling -> 512-d feature vector (the trailing Flatten is a no-op on 2-D input).

    Args:
        layer_params: accepted for compatibility; not used to build the network.
        method: stored as ``self.method``; not read by ``call``.
    """

    def __init__(self, layer_params=None, method="late_concate"):
        super(Resnet18_BaseEncoder, self).__init__()

        self.method = method
        # Default kept for compatibility only; the value is not used below.
        layer_params = [2, 2, 2, 2] if layer_params is None else layer_params

        # Input normalisation, one per stream.
        self.y_input_bn = tf.keras.layers.BatchNormalization()
        self.cbcr_input_bn = tf.keras.layers.BatchNormalization()

        # Luma stream.
        self.layer1_y = make_basic_block_layer(filter_num=64, blocks=1, dimen_match=True)
        self.cb2_y = make_basic_block_layer(filter_num=128, blocks=1, stride=2)

        # Chroma stream: 1x1 projection to 128 channels.
        self.cb2_cbcr = tf.keras.layers.Conv2D(filters=128, kernel_size=(1, 1),
                                               strides=1, padding="same")
        self.cb_bn = tf.keras.layers.BatchNormalization()

        # Fused trunk and pooling.
        self.layer4 = make_basic_block_layer(filter_num=512, blocks=2, stride=2)
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.flatten = tf.keras.layers.Flatten()

    def call(self, inputs, training=None, mask=None):
        luma = self.y_input_bn(inputs[0], training=training)
        chroma = self.cbcr_input_bn(inputs[1], training=training)

        # Luma: full-resolution block, then stride-2 block -> 128 channels.
        for stage in (self.layer1_y, self.cb2_y):
            luma = stage(luma, training=training)

        # Chroma: 1x1 conv -> BN -> ReLU -> 128 channels.
        chroma = self.cb2_cbcr(chroma)
        chroma = tf.nn.relu(self.cb_bn(chroma, training=training))

        merged = tf.concat((luma, chroma), axis=3)
        merged = self.layer4(merged, training=training)
        return self.flatten(self.avgpool(merged))


def make_basic_block_layer(filter_num, blocks, k_size=(3,3), stride=1, dimen_match=False):
    """Stack ``blocks`` BasicBlocks into a ``tf.keras.Sequential`` stage.

    Only the first block may change resolution or width (``stride`` /
    ``dimen_match``); the remaining ``blocks - 1`` blocks use stride 1 and an
    identity shortcut. Every block uses the same kernel size ``k_size``.

    Note: one block is always created, even when ``blocks`` < 1.

    Returns:
        tf.keras.Sequential containing the blocks.
    """
    res_block = tf.keras.Sequential()
    res_block.add(BasicBlock(filter_num, k_size=k_size, stride=stride, dimen_match=dimen_match))

    for _ in range(1, blocks):
        # Propagate the stage's kernel size; these blocks used to fall back to 3x3.
        res_block.add(BasicBlock(filter_num, k_size=k_size, stride=1))

    return res_block


class BasicBlock(tf.keras.layers.Layer):
    """ResNet-style basic residual block.

    Computes ``relu(shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))))``. Both
    convolutions use ``filter_num`` filters, kernel ``k_size`` and "same" padding;
    only ``conv1`` is strided.

    Args:
        filter_num: number of output channels.
        k_size: kernel size of both convolutions.
        stride: stride of ``conv1`` and of the projection shortcut.
        dimen_match: when truthy, use a 1x1 conv + BN projection shortcut even if
            ``stride == 1`` (e.g. when the input channel count differs from
            ``filter_num``). Otherwise the shortcut is the identity unless
            ``stride != 1``.
    """

    def __init__(self, filter_num, k_size=(3,3), stride=1, dimen_match=False):
        super(BasicBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=k_size,
                                            strides=stride,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=k_size,
                                            strides=1,
                                            padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()
        if stride != 1 or dimen_match:
            # Projection shortcut so the residual matches the main path's shape.
            self.downsample = tf.keras.Sequential()
            self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num,
                                                       kernel_size=(1, 1),
                                                       strides=stride))
            self.downsample.add(tf.keras.layers.BatchNormalization())
        else:
            # Identity shortcut; accepts `training` so call() treats both cases alike.
            self.downsample = lambda x, training=None: x

    def call(self, inputs, training=None, **kwargs):
        # Pass `training` explicitly so the shortcut's BatchNormalization runs in the
        # same mode as bn1/bn2 instead of relying on implicit call-context propagation.
        residual = self.downsample(inputs, training=training)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)

        # Plain tensor addition: `tf.keras.layers.add` instantiates a new Add layer
        # on every call.
        return tf.nn.relu(residual + x)

In [None]:
# Define hyper parameters
batch_size = 128   # images drawn per subject for each contrastive step
tau = 0.1          # temperature dividing the similarity matrix in the loss
# Per-pair loss weights; the training loop overwrites entry j for every batch.
weights_array = np.ones((batch_size,), dtype=np.float32)

# Define base encoder and projection head
feature_extractor = Resnet18_BaseEncoder()
projection_head = RGBProjectionHead()

# Per-subject range from which image ids are sampled (indexed by subject id).
pic_num_list = [3000, 2993, 3000, 3000, 3000, 3000, 3000, 2999, 3000, 3000, 2999, 3000, 3000, 2994, 3000]
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Reusable input buffers for the two augmented views ("query" / "positive"):
# 28x28 Y planes and 14x14 Cb/Cr planes, six selected coefficients each.
query_batch_y = np.zeros((batch_size, 28, 28, 6))
positive_batch_y = np.zeros((batch_size, 28, 28, 6))
query_batch_cbcr = np.zeros((batch_size, 14, 14, 6))
positive_batch_cbcr = np.zeros((batch_size, 14, 14, 6))

subject_num = 10

# One loss history per subject, each seeded with a 0.0 placeholder
# (identical to np.zeros((subject_num, 1)).tolist()).
error_list = [[0.0] for _ in range(subject_num)]

# Learnable 1x512 identity embedding per subject; the projection head adds it
# to the encoder output.
identity_representations_list = [tf.Variable(tf.zeros((1, 512)), trainable=True)
                                  for _ in range(subject_num)]

In [None]:
# Scratch JPEGs: each augmented view is re-encoded to disk because jpeg2dct's
# `load` takes a file path.
_NOISED_JPEG_1 = "/home/walter/CVPR23/MPIIFaceGaze/temp_pic/noised_image1.jpg"
_NOISED_JPEG_2 = "/home/walter/CVPR23/MPIIFaceGaze/temp_pic/noised_image2.jpg"


def _load_selected_dct(jpeg_path):
    """Read a JPEG's DCT coefficients and keep the channel subset fed to the encoder.

    Keeps Y coefficient channels 0, 1, 2, 8, 9, 16 (6 channels) and, for each of
    Cb and Cr, channels 0, 1, 8; the Cb and Cr selections are concatenated (Cb
    first) into 6 channels.

    Returns:
        (y_selected, cbcr_selected), with the spatial size of the decoded planes.
    """
    dct_y, dct_cb, dct_cr = load(jpeg_path)
    y_selected = np.concatenate((dct_y[:, :, 0:3], dct_y[:, :, 8:10], dct_y[:, :, 16:17]), axis=2)
    cb_selected = np.concatenate((dct_cb[:, :, 0:2], dct_cb[:, :, 8:9]), axis=2)
    cr_selected = np.concatenate((dct_cr[:, :, 0:2], dct_cr[:, :, 8:9]), axis=2)
    return y_selected, np.concatenate([cb_selected, cr_selected], 2)


def _to_uint8_image(image):
    """Convert a float image in [0, 1] to uint8, clipping out-of-range values first.

    Values outside [0, 1] (e.g. from augmentation) are not representable after
    scaling to uint8 and would otherwise produce wrapped/garbage pixels.
    """
    return np.round(np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)


for epoch in tqdm(range(5000)):

    # Visit every training subject once per epoch, in random order.
    subject_index = np.random.choice(subject_num, subject_num, replace=False)
    for sub in subject_index:
        subject_dir = "/home/walter/MPIIFaceGaze_png/" + str(sub) + "/"
        left_landmarks = np.load(subject_dir + "left_landmarks.npy")
        right_landmarks = np.load(subject_dir + "right_landmarks.npy")

        index = np.random.choice(pic_num_list[sub], batch_size, replace=False)

        for j, i in enumerate(index):
            img = io.imread(subject_dir + str(i) + '.png') / 255.

            # Two independent augmentations of the same image form the positive pair.
            noised_image1, angle_1 = trans_in_rgb.data_transformation(copy.deepcopy(img), left_landmarks[i], right_landmarks[i])
            noised_image2, angle_2 = trans_in_rgb.data_transformation(copy.deepcopy(img), left_landmarks[i], right_landmarks[i])

            # Pair weight decreases linearly with the angle difference of the two views.
            # NOTE(review): this is <= 0 once |angle_1 - angle_2| >= 30, which makes the
            # log below NaN/-inf — confirm the angle range of data_transformation.
            weights_array[j] = (3 - np.abs(angle_1 - angle_2) / 10.) / 2

            io.imsave(_NOISED_JPEG_1, _to_uint8_image(cv2.resize(noised_image1, (224, 224))))
            io.imsave(_NOISED_JPEG_2, _to_uint8_image(cv2.resize(noised_image2, (224, 224))))

            query_batch_y[j], query_batch_cbcr[j] = _load_selected_dct(_NOISED_JPEG_1)
            positive_batch_y[j], positive_batch_cbcr[j] = _load_selected_dct(_NOISED_JPEG_2)

        with tf.GradientTape() as tape:
            # The subject's identity embedding is broadcast to every sample (inside the
            # tape so it receives gradients).
            identity_batch = tf.tile(identity_representations_list[sub], [batch_size, 1])

            x1_feature = tf.math.l2_normalize(
                projection_head([feature_extractor([query_batch_y, query_batch_cbcr], training=True), identity_batch], training=True), axis=1)
            x2_feature = tf.math.l2_normalize(
                projection_head([feature_extractor([positive_batch_y, positive_batch_cbcr], training=True), identity_batch], training=True), axis=1)

            # Row-wise softmax over cosine similarities / tau; diagonal entries are the
            # matching (positive) pairs.
            x1_x2_mat = tf.exp(tf.matmul(x1_feature, tf.transpose(x2_feature)) / tau)
            prob = x1_x2_mat / tf.reduce_sum(x1_x2_mat, 1, keepdims=True)
            prob = tf.linalg.diag_part(prob)
            weighted_prob = prob * weights_array

            # NOTE(review): log(prob * w) = log(prob) + log(w) and w is a constant, so the
            # weights only shift the loss value and do not change the gradients. If the
            # intent is to down-weight pairs, -mean(w * log(prob)) would do that.
            loss = -tf.reduce_mean(tf.math.log(weighted_prob))

        trainable_vars = (feature_extractor.trainable_variables
                          + projection_head.trainable_variables
                          + [identity_representations_list[sub]])
        grads = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(grads_and_vars=zip(grads, trainable_vars))
        error_list[sub].append(loss.numpy())

    # Reports the loss of the last subject processed in this epoch.
    print("current epoch: ", epoch, "current loss", loss.numpy())

In [None]:
# Per-subject pre-training loss curves (one point per epoch, since each subject is
# visited once per epoch). Element 0 of every history is the 0.0 placeholder it was
# seeded with, so it is skipped; the last element is a real loss and is kept.
# The lists are plotted directly rather than stacked with np.array, which fails
# when an interrupted run leaves subjects with different numbers of entries.
fig, ax = plt.subplots()
for sub_fig in range(subject_num):
    ax.plot(error_list[sub_fig][1:], label="subject " + str(sub_fig))
ax.set(title="Contrastive pre-training loss per subject", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

In [None]:
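# Save the pre-trained encoder and projection head. The personalization cell below
# reloads the encoder from ".../TOSN_rotation_multitask/base_1", so that checkpoint must
# already exist (uncomment these lines after pre-training to create it).
# feature_extractor.save_weights("/home/walter/CVPR23/MPIIFaceGaze/TOSN_rotation_multitask/base_1")
# projection_head.save_weights("/home/walter/CVPR23/MPIIFaceGaze/TOSN_rotation_multitask/pro_1")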

In [None]:
def read_dct_coefficients_batch(local_index, subject_name):
    """Load the selected DCT coefficients for a batch of normalized images.

    Each image '/home/walter/MPIIFaceGaze_normalized/<subject_name>/<i>.jpg' is decoded
    directly with jpeg2dct ``load`` (no resizing). The Y planes must be 56x56 and the
    Cb/Cr planes 28x28, matching the pre-allocated output buffers.

    Args:
        local_index: 1-D array of image ids; ``shape[0]`` is the batch size.
        subject_name: subject folder name, e.g. 'p10'.

    Returns:
        (batch_y, batch_cbcr): float64 arrays of shape (N, 56, 56, 6) and (N, 28, 28, 6).
    """
    count = local_index.shape[0]
    batch_y = np.zeros((count, 56, 56, 6))
    batch_cbcr = np.zeros((count, 28, 28, 6))

    for slot, image_id in enumerate(local_index):
        jpeg_path = '/home/walter/MPIIFaceGaze_normalized/' + subject_name + '/' + str(image_id) + '.jpg'
        dct_y, dct_cb, dct_cr = load(jpeg_path)

        # Keep Y coefficient channels 0,1,2,8,9,16 and Cb/Cr channels 0,1,8 each.
        batch_y[slot] = np.concatenate(
            (dct_y[:, :, 0:3], dct_y[:, :, 8:10], np.reshape(dct_y[:, :, 16], (56, 56, 1))), axis=2)
        chroma_cb = np.concatenate((dct_cb[:, :, 0:2], np.reshape(dct_cb[:, :, 8], (28, 28, 1))), axis=2)
        chroma_cr = np.concatenate((dct_cr[:, :, 0:2], np.reshape(dct_cr[:, :, 8], (28, 28, 1))), axis=2)
        batch_cbcr[slot] = np.concatenate([chroma_cb, chroma_cr], 2)

    return batch_y, batch_cbcr

In [None]:
def read_dct_coefficients_batch_224(local_index, subject_name):
    """Like ``read_dct_coefficients_batch`` but resizes each image to 224x224 first.

    The resized image is re-encoded to a scratch JPEG (jpeg2dct ``load`` takes a file
    path) and its DCT coefficients are read back; the Y planes are then 28x28 and
    the Cb/Cr planes 14x14.

    Args:
        local_index: 1-D array of image ids; ``shape[0]`` is the batch size.
        subject_name: subject folder name, e.g. 'p10'.

    Returns:
        (batch_y, batch_cbcr): float64 arrays of shape (N, 28, 28, 6) and (N, 14, 14, 6).
    """
    scratch_path = "/home/walter/CVPR23/MPIIFaceGaze/temp_pic/resized_image.jpg"
    count = local_index.shape[0]
    batch_y = np.zeros((count, 28, 28, 6))
    batch_cbcr = np.zeros((count, 14, 14, 6))

    for slot, image_id in enumerate(local_index):
        source = io.imread('/home/walter/MPIIFaceGaze_normalized/' + subject_name + '/' + str(image_id) + '.jpg')
        io.imsave(scratch_path, cv2.resize(source, (224, 224)))
        dct_y, dct_cb, dct_cr = load(scratch_path)

        # Keep Y coefficient channels 0,1,2,8,9,16 and Cb/Cr channels 0,1,8 each.
        batch_y[slot] = np.concatenate(
            (dct_y[:, :, 0:3], dct_y[:, :, 8:10], np.reshape(dct_y[:, :, 16], (28, 28, 1))), axis=2)
        chroma_cb = np.concatenate((dct_cb[:, :, 0:2], np.reshape(dct_cb[:, :, 8], (14, 14, 1))), axis=2)
        chroma_cr = np.concatenate((dct_cr[:, :, 0:2], np.reshape(dct_cr[:, :, 8], (14, 14, 1))), axis=2)
        batch_cbcr[slot] = np.concatenate([chroma_cb, chroma_cr], 2)

    return batch_y, batch_cbcr

In [None]:
def train(subject_name):
    """Fine-tune ``feature_extractor`` + ``GazeEstimation`` on the current split.

    Reads data and state from notebook globals set by the driver cell:
    ``train_mini_batch_y/cbcr``, ``val_mini_batch_y/cbcr``, ``labels_per``,
    ``train_index``, ``val_index``, ``GazeEstimation`` and ``optimizer_per``.
    Appends one angular error per epoch to ``train_error_per`` / ``val_error_per``.

    Runs 500 full-batch MSE steps and saves both models whenever the validation
    angular error drops below the best value seen so far (initially 100).
    """
    checkpoint_root = '/home/walter/CL_gaze_project/MPIIFaceGaze/experimental_model/' + subject_name + '/dct_ours/'
    best_val = 100.

    for _ in tqdm(range(500)):
        # NOTE(review): feature_extractor is called without training=True, so its
        # BatchNorm layers presumably run in inference mode here — confirm intended.
        with tf.GradientTape() as tape:
            train_pred = GazeEstimation(feature_extractor([train_mini_batch_y, train_mini_batch_cbcr]))
            loss = tf.reduce_mean(tf.square(labels_per[train_index] - train_pred))

        params = GazeEstimation.trainable_variables + feature_extractor.trainable_variables
        optimizer_per.apply_gradients(grads_and_vars=zip(tape.gradient(loss, params), params))

        # Training error of the pre-update predictions.
        train_error_per.append(
            common_functions.avg_angle_error(train_pred, labels_per[train_index]).numpy())

        val_pred = GazeEstimation(feature_extractor([val_mini_batch_y, val_mini_batch_cbcr]))
        val_error = common_functions.avg_angle_error(val_pred, labels_per[val_index])
        val_error_per.append(val_error.numpy())

        if val_error < best_val:
            best_val = val_error
            feature_extractor.save_weights(checkpoint_root + 'base')
            GazeEstimation.save_weights(checkpoint_root + 'gaze')

In [None]:
def test(subject_name, eval_batch_size=100):
    """Evaluate the best checkpoint of the current split on the global ``test_index``.

    Restores the weights saved by ``train(subject_name)`` and runs the model over
    ``test_index`` in chunks of ``eval_batch_size`` images. Also reads the globals
    ``labels_per``, ``feature_extractor`` and ``GazeEstimation``.

    Args:
        subject_name: subject folder name, e.g. 'p10'.
        eval_batch_size: images per forward pass (default 100, as before).

    Returns:
        Sample-weighted mean of ``common_functions.avg_angle_error`` over all test
        images, or NaN when ``test_index`` is empty.
    """
    checkpoint_root = '/home/walter/CL_gaze_project/MPIIFaceGaze/experimental_model/' + subject_name + '/dct_ours/'
    feature_extractor.load_weights(checkpoint_root + 'base')
    GazeEstimation.load_weights(checkpoint_root + 'gaze')

    def _load_selected_dct(jpeg_path):
        # Keep Y coefficient channels 0,1,2,8,9,16 and Cb/Cr channels 0,1,8 each.
        dct_y, dct_cb, dct_cr = load(jpeg_path)
        y_sel = np.concatenate((dct_y[:, :, 0:3], dct_y[:, :, 8:10], dct_y[:, :, 16:17]), axis=2)
        cb_sel = np.concatenate((dct_cb[:, :, 0:2], dct_cb[:, :, 8:9]), axis=2)
        cr_sel = np.concatenate((dct_cr[:, :, 0:2], dct_cr[:, :, 8:9]), axis=2)
        return y_sel, np.concatenate([cb_sel, cr_sel], 2)

    num_test = test_index.shape[0]
    if num_test == 0:
        # Previously crashed on `.numpy()` of a float / division by zero.
        return float('nan')

    test_error = 0.
    # Chunk boundaries come from test_index itself instead of the global
    # batch_number, so they cannot drift out of sync with it.
    for start in range(0, num_test, eval_batch_size):
        chunk = test_index[start: start + eval_batch_size]
        test_batch_y = np.zeros((chunk.shape[0], 56, 56, 6))
        test_batch_cbcr = np.zeros((chunk.shape[0], 28, 28, 6))
        for j, i in enumerate(chunk):
            jpeg_file = '/home/walter/MPIIFaceGaze_normalized/' + subject_name + '/' + str(i) + '.jpg'
            test_batch_y[j], test_batch_cbcr[j] = _load_selected_dct(jpeg_file)

        prediction = GazeEstimation(feature_extractor([test_batch_y, test_batch_cbcr]))
        # Weight each chunk's mean error by its size to get a per-image average overall.
        test_error += common_functions.avg_angle_error(prediction, labels_per[chunk]) * chunk.shape[0]

    return test_error.numpy() / num_test

In [None]:
# Each episode uses per_size consecutive entries of a subject's random_index:
# the first per_size-25 for training, the next 25 for validation, the rest for test.
per_size = 100
print("current size for personalization is", per_size-25)
name_list = ['p10', 'p11', 'p12', 'p13', 'p14']
avg_list = []

for name in name_list:

    avg_error_list = []
    labels_per = np.load("/home/walter/MPIIFaceGaze_normalized/" + name + "/labels.npy")
    data_size = labels_per.shape[0]
    eposide = 6
    print(name, " start to train.")
    total_index = np.load("/home/walter/MPIIFaceGaze_normalized/" + name + "/random_index.npy")

    for i in range(eposide):
        # Every episode restarts from the pre-trained encoder with a fresh head and optimizer.
        feature_extractor.load_weights("/home/walter/CVPR23/MPIIFaceGaze/TOSN_rotation_multitask/base_1")
        GazeEstimation = CL_model.GazeEstimationHead()
        optimizer_per = tf.keras.optimizers.Adam(learning_rate=.001)
        train_error_per, val_error_per = [], []

        train_start = int(per_size * i)
        train_end = train_start + per_size - 25
        train_index = total_index[train_start: train_end]
        val_index = total_index[train_end: train_end + 25]

        train_mini_batch_y, train_mini_batch_cbcr = read_dct_coefficients_batch(train_index, name)
        val_mini_batch_y, val_mini_batch_cbcr = read_dct_coefficients_batch(val_index, name)

        # Everything outside this episode's window is test data; for i == 0 the
        # leading slice is empty, so a single concatenate covers every episode.
        test_index = np.concatenate((total_index[0: train_start], total_index[train_end + 25: data_size]), 0)
        # ceil(len(test_index) / 100) chunks of at most 100 images.
        batch_number = -(-test_index.shape[0] // 100)

        print(" current eposide: ", i+1, " start to train from ", train_start, "to ", train_end,
              "train samples:", train_index.shape[0], "val samples:", val_index.shape[0], "test samples:", test_index.shape[0])
        train(name)
        test_error = test(name)
        avg_error_list.append(test_error)

        val_error_per = np.array(val_error_per)
        plt.ylim(2,7)
        plt.plot(val_error_per,label='val')
        plt.legend()
        plt.show()

        min_val_index = np.argmin(val_error_per)
        print(name, ":", train_start, "to ", train_end, " achieve best accuarcy ", test_error," at the ", min_val_index, "th iteration")
        print(" ")
        print(" ")
        print("--------------------------------------------------------------------------------------------")

    avg_list.append(np.mean(np.array(avg_error_list)))
    print(name, "average angular error over six trails", avg_list[-1])