# Days 4–5: TensorProjection layer, model reduction, and fitting stability

In [1]:
import os;
# os.environ['KMP_DUPLICATE_LIB_OK']='True';

In [2]:
import tensorflow as tf;
import numpy as np;
import pandas as pd;
import matplotlib.pyplot as plt;

In [3]:
class TensorProjectionLayer(tf.keras.layers.Layer):
    """Project each non-batch mode of a 4-D tensor onto a lower-dimensional subspace.

    For an input X of shape (n, p1, p2, p3) the layer returns
        Z = X x_1 U1^T x_2 U2^T x_3 U3^T
    where x_k is the mode-k product and each Uk (pk x qk) is derived from a
    trainable weight Wk (pk x qk) by orthonormalising its columns (see WtoU).
    A mode is only projected when qk < pk; otherwise it is passed through
    unchanged and keeps its original size pk.

    Args:
        q1, q2, q3: target sizes of modes 1, 2 and 3 (cast to int).
        regularization: 'reconstruction_error' adds rate * batch-mean of the
            per-sample RMS error between X and its reconstruction from Z;
            'total_variation' adds rate * batch-mean of the per-sample RMS
            deviation of Z from its batch mean; any other value (including the
            default string 'None') adds no penalty.
        rate: coefficient multiplying the regularization penalty.
    """

    def __init__(self, q1, q2, q3, regularization='None', rate=10**-3, **kwargs):
        super(TensorProjectionLayer, self).__init__(**kwargs)
        self.q1 = int(q1)
        self.q2 = int(q2)
        self.q3 = int(q3)
        self.regularization = regularization
        self.rate = rate  # regularization coefficient

    def compute_output_shape(self, input_shape):
        """Return (batch, d1, d2, d3) where dk = qk if mode k is projected, else pk."""
        dims = tf.TensorShape(input_shape).as_list()
        out = [dims[0]]
        for q, p in zip((self.q1, self.q2, self.q3), dims[1:4]):
            # Modes with q >= p are not projected by call(), so they keep size p.
            out.append(q if p is None or q < p else p)
        return tuple(out)

    def get_config(self):
        """Serialise the constructor arguments alongside the base Layer config."""
        base_config = super(TensorProjectionLayer, self).get_config()
        base_config.update({
            'q1': self.q1,
            'q2': self.q2,
            'q3': self.q3,
            'regularization': self.regularization,
            'rate': self.rate,
        })
        return base_config

    def build(self, input_shape):
        """Create a projection weight Wk (pk x qk) for every mode with qk < pk."""
        self.p1 = int(input_shape[1])
        self.p2 = int(input_shape[2])
        self.p3 = int(input_shape[3])

        # name= is passed by keyword: newer Keras versions take `shape` as the
        # first positional argument of add_weight.
        if self.q1 < self.p1:
            self.W1 = self.add_weight(name="W1", shape=(self.p1, self.q1),
                                      initializer='normal', trainable=True)
        if self.q2 < self.p2:
            self.W2 = self.add_weight(name="W2", shape=(self.p2, self.q2),
                                      initializer='normal', trainable=True)
        if self.q3 < self.p3:
            self.W3 = self.add_weight(name="W3", shape=(self.p3, self.q3),
                                      initializer='normal', trainable=True)

        super(TensorProjectionLayer, self).build(input_shape)

    def kmode_product(self, T, A, k):
        """Mode-k product T x_k A.

        T has shape (n, t1, t2, t3) and A has shape (s, t_k); the result is T
        with mode k replaced by size s. A is shared across the batch, so it is
        contracted directly instead of being tiled n times.

        Raises:
            ValueError: if k is not 1, 2 or 3.
        """
        if k == 1:
            return tf.einsum('npqr,sp->nsqr', T, A)
        if k == 2:
            return tf.einsum('npqr,sq->npsr', T, A)
        if k == 3:
            return tf.einsum('npqr,sr->npqs', T, A)
        raise ValueError('k must be 1, 2 or 3, got {}'.format(k))

    def WtoU(self, W):
        """Return U = W (W^T W + eps*I)^{-1/2}, a (p x q) matrix with ~orthonormal columns.

        eps keeps W^T W invertible when W is (nearly) rank deficient.
        """
        eps = 1e-6
        q = tf.shape(W)[1]
        M = tf.linalg.matmul(W, W, transpose_a=True) + eps * tf.eye(q, dtype=W.dtype)
        G = tf.linalg.inv(tf.linalg.sqrtm(M))
        return tf.linalg.matmul(W, G)

    def call(self, X):
        """Project X mode by mode and optionally register a regularization loss."""
        Z = X
        U1 = U2 = U3 = None

        if self.q1 < self.p1:
            U1 = self.WtoU(self.W1)                          # p1 x q1
            Z = self.kmode_product(Z, tf.transpose(U1), 1)   # U1^T: q1 x p1
        if self.q2 < self.p2:
            U2 = self.WtoU(self.W2)                          # p2 x q2
            Z = self.kmode_product(Z, tf.transpose(U2), 2)   # U2^T: q2 x p2
        if self.q3 < self.p3:
            U3 = self.WtoU(self.W3)                          # p3 x q3
            Z = self.kmode_product(Z, tf.transpose(U3), 3)   # U3^T: q3 x p3

        if self.regularization == 'reconstruction_error':
            # Map Z back to the input space with the same Uk (unprojected modes skipped).
            X_ = Z
            for k, U in ((1, U1), (2, U2), (3, U3)):
                if U is not None:
                    X_ = self.kmode_product(X_, U, k)
            # Per-sample RMS reconstruction error, shape (n,).
            dn2 = tf.math.reduce_mean(tf.math.squared_difference(X, X_), axis=[1, 2, 3])
            dn = tf.math.pow(dn2, 0.5)
            self.add_loss(self.rate * tf.math.reduce_mean(dn))
        elif self.regularization == 'total_variation':
            # NOTE: despite the name this is the per-sample RMS deviation of Z
            # from its batch mean (broadcasting replaces an explicit tile).
            Z_ = Z - tf.reduce_mean(Z, axis=0, keepdims=True)
            v = tf.reduce_mean(tf.math.pow(Z_, 2), axis=[1, 2, 3])
            v = tf.math.pow(v, 0.5)
            self.add_loss(self.rate * tf.math.reduce_mean(v))

        return Z

In [4]:
# load data
# Training images are randomly augmented; validation and test images are only
# rescaled to [0, 1].
# NOTE(review): Keras' ImageNet-pretrained VGG16 normally expects inputs prepared
# with tf.keras.applications.vgg16.preprocess_input rather than 1/255 rescaling —
# confirm the plain rescaling is intentional.
DATA_DIR = './data/chest_xray'
IMAGE_SIZE = (224, 224)  # must agree with the VGG16 input shape (224, 224, 3) used for the model
SEED = 42  # seeds the training iterator so its shuffling/augmentation is repeatable

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    width_shift_range=0.05,
    height_shift_range=0.15,
    rotation_range=10,
    zoom_range=0.1,
)
validation_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    DATA_DIR + '/train',
    target_size=IMAGE_SIZE,
    batch_size=100,
    class_mode="binary",
    shuffle=True,
    seed=SEED,
)

# Validation/test iterators keep a fixed order (shuffle=False), one image per batch.
validation_generator = validation_datagen.flow_from_directory(
    DATA_DIR + '/val',
    target_size=IMAGE_SIZE,
    class_mode="binary",
    shuffle=False,
    batch_size=1,
)
test_generator = test_datagen.flow_from_directory(
    DATA_DIR + '/test',
    target_size=IMAGE_SIZE,
    class_mode="binary",
    shuffle=False,
    batch_size=1,
)

# Number of classes (binary task; the model uses a single sigmoid output).
num_classes = 2

Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


In [5]:
# Pick the first unused run directory ./1, ./2, ... so each run's artefacts
# (weights, history, plots) are stored separately.
# os.path.exists (not isdir) is used so that a regular file named e.g. "./3"
# is skipped instead of making os.mkdir fail with FileExistsError.
i = 1
while os.path.exists('./{}'.format(i)):
    i += 1
savedir = './{}'.format(i)
os.mkdir(savedir)

In [6]:
### TensorProjection Layer ###
# Seed TensorFlow so weight initialisation and dropout are reproducible.
SEED = 42
tf.random.set_seed(SEED)

# Experiment hyper-parameters.
PROJECTED_SHAPE = (4, 4, 16)  # (q1, q2, q3) passed to TensorProjectionLayer
EPOCHS = 5

model = tf.keras.Sequential()
pmodel = tf.keras.applications.vgg16.VGG16(
    weights='imagenet',
    include_top=False,
    input_tensor=tf.keras.layers.Input(shape=(224, 224, 3)),
)
pmodel.trainable = False  # freeze the pretrained backbone; only the head is trained
model.add(pmodel)
model.add(TensorProjectionLayer(*PROJECTED_SHAPE))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(
    loss=tf.keras.losses.binary_crossentropy,
    optimizer=tf.keras.optimizers.RMSprop(),
    # Cross-entropy is also tracked as a metric so any regularization penalty
    # added by the layer can be recovered as loss - binary_crossentropy.
    metrics=[tf.keras.metrics.binary_crossentropy, 'accuracy'],
)

# NOTE: the test set is passed as validation_data (the val split holds only 16
# images). It is only monitored here — no callbacks select a model from it.
history = model.fit(train_generator, epochs=EPOCHS, validation_data=test_generator)

# score = [total loss, binary cross-entropy, accuracy], in compile() metrics order.
score = model.evaluate(test_generator, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[2])
model.save_weights(savedir + '/tensor-projection-layer_weights.h5')
# To reuse trained weights, rebuild the model as above and call
# model.load_weights(<path to tensor-projection-layer_weights.h5>)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Test loss: 0.2985365092754364
Test accuracy: 0.8814102411270142


In [None]:
# Evaluation results labelled by metric name: 'loss' is the total loss
# (cross-entropy + any regularization penalty), followed by the compiled metrics.
dict(zip(model.metrics_names, score))

In [None]:
# save history
history_json = pd.DataFrame(history.history)
history_json.to_json(savedir + '/history.json')


def _plot_train_test(train_values, test_values, title, ylabel, filename):
    """Plot per-epoch train/test curves and save them to savedir/filename.

    The figure is closed afterwards so no blank figure is left open in the notebook.
    """
    fig, ax = plt.subplots()
    ax.plot(train_values)
    ax.plot(test_values)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Test'], loc='upper left')
    fig.savefig(savedir + '/' + filename)
    plt.close(fig)


hist = history.history

# Training & validation accuracy
_plot_train_test(hist['accuracy'], hist['val_accuracy'],
                 'Model accuracy', 'Accuracy', 'accuracy.png')

# Training & validation total loss
_plot_train_test(hist['loss'], hist['val_loss'],
                 'Model loss', 'Loss', 'loss.png')

# Training & validation loss without the regularization term
_plot_train_test(hist['binary_crossentropy'], hist['val_binary_crossentropy'],
                 'Model loss without regularization term',
                 'Loss without Regularization Term',
                 'loss-without-regularization.png')

# Regularization penalty = total loss - cross-entropy (≈0 when the layer uses no regularization)
_plot_train_test(np.array(hist['loss']) - np.array(hist['binary_crossentropy']),
                 np.array(hist['val_loss']) - np.array(hist['val_binary_crossentropy']),
                 'Regularization penalty', 'Penalty', 'regularization.png')


In [None]:
# Layer-by-layer overview of the model: VGG16 backbone followed by the TensorProjection head.
model.summary()