diff --git a/.gitignore b/.gitignore index 41b74da..e15584f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ *.jl.*.cov *.jl.mem data/ +*.jld +*.pyc +.vscode/ diff --git a/benchmarks/CIFAR10/profile_cifar10_512_64.jl b/benchmarks/CIFAR10/profile_cifar10_512_64.jl index 2a1dd1d..553e42f 100644 --- a/benchmarks/CIFAR10/profile_cifar10_512_64.jl +++ b/benchmarks/CIFAR10/profile_cifar10_512_64.jl @@ -1,10 +1,10 @@ using MAT, Meganet, BenchmarkTools, Compat, JLD, ProfileView # Macro Benchmark on CIFAR10 -n = 32 +n = 64 miniBatchSize = 32 -path2data = "/home/klensink/Documents/cifar-10-batches-mat/" +path2data = Pkg.dir("Meganet")*"/data/CIFAR10/" history = Pkg.dir("Meganet")*"/benchmarks/CIFAR10/cifar10_512_64.jld" Y_train,C_train,Y_test,C_test = getCIFAR10(n, path2data) @@ -80,7 +80,3 @@ Profile.clear() Profile.init(100000000, 0.001) @profile solve(opt,objFun::dnnObjFctn,[vec(theta);vec(W)],Y_train,C_train,Y_test,C_test) ProfileView.view() - -if true - Juno.profiler() -end diff --git a/benchmarks/CIFAR10/tf_benchmarks/Resnet.py b/benchmarks/CIFAR10/tf_benchmarks/Resnet.py new file mode 100644 index 0000000..17b8dac --- /dev/null +++ b/benchmarks/CIFAR10/tf_benchmarks/Resnet.py @@ -0,0 +1,284 @@ +"""Contains resnet model +""" + +import tensorflow as tf +import numpy as np +import time +import math +from utils_input import * +import os + + +class ResnetModel(object): + + def __init__(self, config): + self.channels = config['channels'] + self.k_size = config['kSize'] + self.batchsize = config['batchSize'] + self.epochs = config['maxEpochs'] + self.h = config['h'] + self.num_units = config['numUnits'] + self.num_blocks = config['numBlocksPerUnit'] + self.xtrain = config['xTrain'] + self.xval = config['xValid'] + self.ytrain = config['yTrain'] + self.yval = config['yValid'] + + self.initial_channels = 3 + self.num_train = self.ytrain.shape[0] + self.num_val = self.yval.shape[0] + self.n_class = 10 + + self.strides = [1,1,1,1] + self.activation_func = tf.tanh + self.epochs_completed = 0 + self.weight_decay_rate = .0002 + self.lr = .01 + self.momentum = .9 + self.shuffle = True + + def train(self): + with tf.device("/cpu:0"): + self.X = tf.placeholder(tf.float32, [None, self.xtrain.shape[1]]) + self.y = tf.placeholder(tf.float32, [None, self.n_class]) + + self.logits = self.build_model() + self.loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.logits)) + self.loss += weight_decay(self.weight_decay_rate, keyword=r'K_') + + self.train_op = tf.train.MomentumOptimizer(self.lr, self.momentum).minimize(self.loss) + + # Create init op + init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) + + # Only run on single thread + session_conf = tf.ConfigProto( + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1) + + with tf.Session(config=session_conf) as self.sess: + + # Initialize variables + self.sess.run(init_op) + + self.train_steps_per_epoch = int(math.ceil(self.num_train / float(self.batchsize))) + self.test_steps_per_epoch = self.num_val // self.batchsize + + # Time it! 
+ start_time = time.time() + + train_index = 0 + for ep in range(self.epochs): + + for s in range(self.train_steps_per_epoch): + train_batch, train_index = self.next_train_batch(train_index) + train_data = train_batch[0] + train_labels = train_batch[1] + _, loss = self.sess.run([self.train_op, self.loss], feed_dict={self.X: train_data, self.y: train_labels}) + print(s) + print(loss) + + self.validate_model() + self.epochs_completed += 1 + + end_time = time.time() + print("Total time: %f" % (end_time - start_time)) + + def validate_model(self): + # Run training + avg_train = 0 + n_ex = min(2**12, self.num_train) + n_steps = int(math.ceil(n_ex / float(self.batchsize))) + train_index = 0 + for s in range(n_steps): + train_batch, train_index = self.next_train_batch(train_index) + train_data = train_batch[0] + train_labels = train_batch[1] + _, loss = self.sess.run([self.train_op, self.loss], feed_dict={self.X: train_data, self.y: train_labels}) + avg_train += loss + avg_train /= n_steps + print("overall train:") + print(avg_train) + + # Run Validation + val_index = 0 + avg_loss = 0 + for i in range(self.test_steps_per_epoch): + val_batch, val_index = self.next_val_batch(val_index) + val_data = val_batch[0] + val_labels = val_batch[1] + + preds, loss = self.sess.run([self.logits, self.loss], feed_dict={self.X: val_data, + self.y: val_labels}) + + avg_loss += loss + + avg_loss /= self.test_steps_per_epoch + print("val: %f" % avg_loss) + + def build_model(self): + with tf.device("/cpu:0"): + # Reshape + X = tf.reshape(self.X, [-1, 32, 32, self.initial_channels]) + # Opening Convolution + K_opening = weight_variable([self.k_size, self.k_size, self.initial_channels, self.channels[0]], 'K_Opening') + X = tf.nn.conv2d(X, K_opening, strides=self.strides, padding='SAME') + + # Opening Batch Norm + X = tf.contrib.layers.batch_norm(X, fused=True, is_training=True, scope='Opening') + + #Opening Activation + X = self.activation_func(X) + + # Build Units + for unit_num in range(self.num_units): + X = self.unit(X, unit_num) + + # Average Pooling Layer + X = tf.reduce_mean(X, [1,2]) + X = tf.reshape(X, [self.batchsize, -1]) + + # Fully Connected Layer + K_FC = weight_variable([self.channels[-1], self.n_class], 'K_FC') + B_FC = bias_variable([self.n_class], 'B_FC') + + # Return Logits + logits = tf.nn.xw_plus_b(X, K_FC, B_FC) + return logits + + def unit(self, X, unit_num): + with tf.variable_scope("unit_%d" % unit_num): + for block_num in range(self.num_blocks[unit_num] - 1): + X = self.block(X, unit_num, block_num) + + # Create Connector Block + X = self.connector_block(X, unit_num) + + return X + + def block(self, X0, unit_num, block_num): + with tf.variable_scope("block_%d" % block_num): + # First Convolution + K1 = weight_variable([self.k_size, self.k_size, self.channels[unit_num], self.channels[unit_num]], 'K_1') + X1 = tf.nn.conv2d(X0, K1, strides=self.strides, padding='SAME') + + # Batch Norm + X1 = tf.contrib.layers.batch_norm(X1, fused=True, is_training=True, scope='bn_1') + + # Activation + X1 = self.activation_func(X1) + + # Second Convolution + X1 = -tf.nn.conv2d_transpose(X1, K1, + output_shape=tf.concat([tf.shape(X1)[0:3], [tf.shape(K1)[2]]], axis=0), + strides=self.strides, + padding='SAME') + X1 = tf.reshape(X1, tf.shape(X0)) + + # Second Batch Norm + X1 = tf.contrib.layers.batch_norm(X1, fused=True, is_training=True, scope='bn_2') + + # Compute return value of resnet block + X = X0 + self.h[unit_num] * X1 + + return X + + def connector_block(self, X, unit_num): + with 
tf.variable_scope("connector_block"): + # Convolution increasing channels + K_Conn = weight_variable([1, 1, self.channels[unit_num], self.channels[unit_num + 1]], 'K_Conn') + X = tf.nn.conv2d(X, K_Conn, strides=self.strides, padding='SAME') + + # Batch Norm + X = tf.contrib.layers.batch_norm(X, fused=True, is_training=True, scope='connector') + + # Activation + X = self.activation_func(X) + + X = tf.nn.avg_pool(X, ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], padding='VALID') + + return X + + def next_train_batch(self, index): + """Gets the next batch of data for training + + Arguments: + index {int} -- Current index in the training data + + Returns: + batch {tuple of numpy arrays} -- First np array is the batch of data and second is batch of labels + new_index {int} -- The new index in the training data after creating batch + """ + + start_index = index + end_index = index + self.batchsize + final_index = self.num_train - 1 + + # First batch + if self.shuffle and self.epochs_completed == 0 and start_index == 0: + self.shuffle_data() + + # If we are reaching the end of an epoch + if final_index < end_index: + + # Get remaining data in epoch + remainder = end_index - final_index + x_remaining = self.xtrain[start_index:final_index] + y_remaining = self.ytrain[start_index:final_index] + + if self.shuffle: + self.shuffle_data() + + # Get new data in next epoch + x_extra = self.xtrain[0:remainder] + y_extra = self.ytrain[0:remainder] + + # Combine data into single batch + x_batch = np.concatenate((x_remaining, x_extra), axis=0) + y_batch = np.concatenate((y_remaining, y_extra), axis=0) + batch = (x_batch, y_batch) + + # # Increment epochs_completed + # self.epochs_completed += 1 + + new_index = remainder + else: + batch = (self.xtrain[start_index:end_index], self.ytrain[start_index:end_index]) + new_index = end_index + + return batch, new_index + + def shuffle_data(self): + """Shuffles the training data + """ + perm = np.arange(self.num_train) + np.random.shuffle(perm) + self.xtrain = self.xtrain[perm] + self.ytrain = self.ytrain[perm] + + def next_val_batch(self, index): + """Gets the next batch of data for validation + + Arguments: + index {int} -- Current index in the validation data + + Returns: + batch {tuple of numpy arrays} -- First np array is the batch of data and second is batch of labels + new_index {int} -- The new index in the validation data after creating batch + """ + if self.epochs_completed == 0 and index == 0: + # Shuffle validation data at start + perm = np.arange(self.num_val) + np.random.shuffle(perm) + self.xval = self.xval[perm] + self.yval = self.yval[perm] + #TODO: Ensure we are validating on all of the validation data + start_index = index + end_index = index + self.batchsize + + batch = (self.xval[start_index:end_index], self.yval[start_index:end_index]) + new_index = end_index + + return batch, new_index \ No newline at end of file diff --git a/benchmarks/CIFAR10/tf_benchmarks/cifar10_512_64.py b/benchmarks/CIFAR10/tf_benchmarks/cifar10_512_64.py new file mode 100644 index 0000000..9b73bc2 --- /dev/null +++ b/benchmarks/CIFAR10/tf_benchmarks/cifar10_512_64.py @@ -0,0 +1,32 @@ +"""Runs benchmark using tensorflow and 1 cpu thread + Runs 1 epoch with 512/103 Train/Val split in 8.56s +""" +from utils_input import * +from Resnet import ResnetModel + +DATA_DIR = "path/to/where/you/want/your/data/stored" + +if __name__ == '__main__': + # Load Data + ntrain = 512 + nval = 103 + train_data, valid_data = load_cifar10(DATA_DIR, ntrain, nval) # Note this does not evenly 
distribute classes + + # Set config + modelConfig = { + "channels": [16,32,64,64], + "batchSize": 64, + "numUnits": 3, + "numBlocksPerUnit": [3,3,3], + "h": [1.,1.,1.], + "maxEpochs": 1, + "kSize": 3, + "xTrain": train_data[0], + "xValid": valid_data[0], + "yTrain": train_data[1], + "yValid": valid_data[1], + } + + # Train model + model = ResnetModel(modelConfig) + model.train() diff --git a/benchmarks/CIFAR10/tf_benchmarks/utils_input.py b/benchmarks/CIFAR10/tf_benchmarks/utils_input.py new file mode 100644 index 0000000..6bde335 --- /dev/null +++ b/benchmarks/CIFAR10/tf_benchmarks/utils_input.py @@ -0,0 +1,145 @@ +"""Contains helper functions for running tensorflow benchmark +""" + +import numpy as np +import os +import sys +from six.moves import urllib +from six.moves import cPickle +import tarfile +import tensorflow as tf + +def one_hot(labels, n_class): + """ Return one hot encoding of labels + + Args: + labels: a vector of length n, with values in [0, n_class-1] + n_class: number of classes, typically 10 + Returns: + One hot encoding, a (n, n_class) np array. + """ + + return np.eye(n_class)[labels].reshape((labels.shape[0], n_class)) + +def maybe_download_and_extract(DATA_URL, dest_directory, extracted_filepath=None): + """Download and extract the tarball from Alex's website.""" + # https://github.com/tensorflow/models/blob/dac6755b121f1446ec857cd05c2ff53b2fd26b90/tutorials/image/cifar10/cifar10.py + + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') + if extracted_filepath: + extracted_dir_path = os.path.join(dest_directory, extracted_filepath) + if not os.path.exists(extracted_dir_path): + tarfile.open(filepath, 'r:gz').extractall(dest_directory) + +def load_batch(fpath, label_key='labels'): + """Internal utility for parsing CIFAR data. + # Arguments + fpath: path the file to parse. + label_key: key for label data in the retrieve + dictionary. + # Returns + A tuple `(data, labels)`. + """ + # https://github.com/fchollet/keras/blob/3e933ca0ed1c526c0a9b8643ca84129db96ecc17/keras/datasets/cifar.py + + f = open(fpath, 'rb') + if sys.version_info < (3,): + d = cPickle.load(f) + else: + d = cPickle.load(f, encoding='bytes') + # decode utf8 + d_decoded = {} + for k, v in d.items(): + d_decoded[k.decode('utf8')] = v + d = d_decoded + f.close() + data = d['data'] + labels = d[label_key] + + data = data.reshape(data.shape[0], 3, 32, 32).transpose([0, 2, 3, 1]) + return data, labels + +def load_cifar10(dirname, ntrain, nval): + """Loads CIFAR10 dataset. + # Returns + Tuple of np arrays: `(x_train, y_train), (x_test, y_test)`. 
+ """ + CIFAR10_DATA_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + + maybe_download_and_extract( + CIFAR10_DATA_URL, dirname, extracted_filepath='cifar-10-batches-py') + path = os.path.join(dirname, 'cifar-10-batches-py') + + nb_train_samples = 50000 + + x_train = np.zeros((nb_train_samples, 32, 32, 3), dtype='uint8') + y_train = np.zeros((nb_train_samples,), dtype='uint8') + + for i in range(1, 6): + fpath = os.path.join(path, 'data_batch_' + str(i)) + data, labels = load_batch(fpath) + x_train[(i - 1) * 10000: i * 10000, :, :, :] = data + y_train[(i - 1) * 10000: i * 10000] = labels + + fpath = os.path.join(path, 'test_batch') + x_test, y_test = load_batch(fpath) + + y_train = np.reshape(y_train, (len(y_train), 1)) + y_test = np.reshape(y_test, (len(y_test), 1)) + + x_train = x_train.reshape((x_train.shape[0], 32 * 32 * 3)) + x_test = x_test.reshape((x_test.shape[0], 32 * 32 * 3)) + + y_train_one_hot = one_hot(y_train, 10) + y_test_one_hot = one_hot(y_test, 10) + + return (x_train[:ntrain], y_train_one_hot[:ntrain]), (x_test[:nval], y_test_one_hot[:nval]) + +def weight_decay(weight_decay_rate, keyword): + """Adds weight decay using l2 loss to all variables with names matching keyword + + Arguments: + weight_decay_rate {float} -- The decay rate factor + keyword {string or regex} -- The keyword to match variable names to + + Returns: + [Tensor] -- A Tensor containing the total loss from weight decay + """ + costs = [] + for var in tf.trainable_variables(): + if var.op.name.find(keyword) > 0: + costs.append(tf.nn.l2_loss(var) / 2) + + if len(costs) == 0: + # No variables match the keyword + return 0 + else: + return tf.multiply(weight_decay_rate, tf.add_n(costs)) + +def weight_variable(shape, name): + """Creates a tensorflow variable with given shape and name initialized using Xavier Initialization + Arguments: + shape {array} -- Array with the dimensions needed for the variable + name {string} -- Name for the variable + + Returns: + [tf.Variable] -- A trainable tensorflow variable of type float32 + """ + return tf.get_variable(name, shape=shape, dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) + +def bias_variable(shape, name): + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.1, shape=shape) + return tf.get_variable(name, dtype=tf.float32, initializer=initial) \ No newline at end of file diff --git a/benchmarks/micro/bm_batchnorm.jl b/benchmarks/micro/bm_batchnorm.jl index 1c044e2..e38f39d 100644 --- a/benchmarks/micro/bm_batchnorm.jl +++ b/benchmarks/micro/bm_batchnorm.jl @@ -1,23 +1,55 @@ using Meganet, BenchmarkTools -history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_batchnorm.jld" +const history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_batchnorm.jld" -TYPE = Float64 +const TYPE = Float64 npixel = 500 nex = 1000 nchannel = 3 -L = getNormLayer(TYPE,[npixel,nchannel,nex],3) -theta = initTheta(L) -Y = randn(TYPE,nFeatIn(L),nex) +const L = getNormLayer(TYPE,[npixel,nchannel,nex],3) +const theta = initTheta(L) +const Y = randn(TYPE,nFeatIn(L),nex) -Yout2,Yout2,tmp2 = apply(L,theta,Y,true) -@code_warntype apply(L,theta,Y,true) +function benchmarkJYTmv() + funcName = "JYTmv" #TODO: pass funcName to history instead of calling it "hist" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) -trial = @benchmark apply(L,theta,Y,true); + #Warmup + Z1 = JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + trial = @benchmark JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, 
trial, "hist") + hist = JLD.load(history, "hist") + judge(hist) +end -Meganet.updatehistory!(history, trial) -hist = JLD.load(history, "hist") -judge(hist) \ No newline at end of file + +function benchmarkJYmv() + funcName = "JYmv" + _,_,tmp = apply(L,theta,Y,true) + dY = randn(TYPE,nFeatIn(L),nex) + #Warmup + dZ, dZ = JYmv(L,dY,theta,Y,tmp) + # @code_warntype JYmv(L, dY, theta, Y, tmp) + trial = @benchmark JYmv(L, dY, theta, Y, tmp) + Meganet.updatehistory!(history, trial, "hist") + hist = JLD.load(history, "hist") + judge(hist) +end + + +function benchmarkApply(L, theta, Y, history) + funcName = "apply" + Yout2,Yout2,tmp2 = apply(L,theta,Y,true) + + @code_warntype apply(L,theta,Y,true) + + trial = @benchmark apply(L,theta,Y,true) + + Meganet.updatehistory!(history, trial, "hist") + hist = JLD.load(history, "hist") + judge(hist) +end diff --git a/benchmarks/micro/bm_connector.jl b/benchmarks/micro/bm_connector.jl new file mode 100644 index 0000000..05e0022 --- /dev/null +++ b/benchmarks/micro/bm_connector.jl @@ -0,0 +1,26 @@ +using Meganet, BenchmarkTools + +history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_connector.jld" + +TYPE = Float64 + +npixel = 500 +nex = 1000 + +K = randn(TYPE,10,5) +L = getConnector(TYPE,K,outTimes=1) + +theta = initTheta(L) + +Y = randn(TYPE,nFeatIn(L),nex) + +# Warmup +Yout,Yout,tmp = apply(L,theta,Y,true) + +@code_warntype apply(L,theta,Y,true) + +trial = @benchmark apply(L,theta,Y,true); + +Meganet.updatehistory!(history, trial) +hist = JLD.load(history, "hist") +judge(hist) \ No newline at end of file diff --git a/benchmarks/micro/bm_connector.jld b/benchmarks/micro/bm_connector.jld new file mode 100644 index 0000000..2942fe8 Binary files /dev/null and b/benchmarks/micro/bm_connector.jld differ diff --git a/benchmarks/micro/bm_doubleSymLayer.jld b/benchmarks/micro/bm_doubleSymLayer.jld index 0531a14..894730d 100644 Binary files a/benchmarks/micro/bm_doubleSymLayer.jld and b/benchmarks/micro/bm_doubleSymLayer.jld differ diff --git a/benchmarks/micro/bm_multConv2Dblock.jl b/benchmarks/micro/bm_multConv2Dblock.jl new file mode 100644 index 0000000..8044997 --- /dev/null +++ b/benchmarks/micro/bm_multConv2Dblock.jl @@ -0,0 +1,28 @@ +using Meganet, JLD, BenchmarkTools + +history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_multConv2Dblock.jld" +file = Pkg.dir("Meganet")*"/benchmarks/micro/vars_multConv2Dblock.jld" + +x = load(file, "x") +K = load(file, "K") +y = load(file, "y") +t = load(file, "t") +shiftX = load(file, "shiftX") +shiftT = load(file, "shiftT") +imIdx = load(file, "imIdx") +doDerivative = load(file, "doDerivative") + +trial = @benchmark Meganet.multConv2Dblock(x, K, y, t, shiftX, shiftT, imIdx, doDerivative = doDerivative) +@enter Meganet.multConv2Dblock(x, K, y, t, shiftX, shiftT, imIdx, doDerivative = doDerivative) +display(trial) +hist = load(history, "multConv2Dblock") + +if false + Meganet.updatehistory!(history, trial, "multConv2Dblock") +end + +y = load(file, "y") +Meganet.multConv2Dblock(x, K, y, t, shiftX, shiftT, imIdx, doDerivative = doDerivative) +y = load(file, "y") +Profile.clear_malloc_data() +Meganet.multConv2Dblock(x, K, y, t, shiftX, shiftT, imIdx, doDerivative = doDerivative); diff --git a/benchmarks/micro/bm_resNN.jl b/benchmarks/micro/bm_resNN.jl new file mode 100644 index 0000000..ab1f2e0 --- /dev/null +++ b/benchmarks/micro/bm_resNN.jl @@ -0,0 +1,94 @@ +using Meganet, BenchmarkTools + +const history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_resNN.jld" + +TYPE = Float32 +K = getConvGEMMKernel(TYPE,[96,96],[3,3,64,64]) 
+nex = 20 +# Bin = randn(TYPE,18,4) +# Bout = randn(TYPE,18,3) +nLayer = getBatchNormLayer(TYPE,[96*96,64],isTrainable=true) +lay = getSingleLayer(TYPE,K,nLayer) + +L = getResNN(TYPE,lay,4) + +theta = initTheta(L) + +Y = randn(TYPE,nFeatIn(L),nex) + +function benchmarkApply() + funcName = "apply" + Zd,Z,tmp = apply(L,theta,Y,true); + @code_warntype apply(L,theta,Y,true) + + trial = @benchmark apply(L,theta,Y,true) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +### +function benchmarkJYTmv() + funcName = "JYTmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp); + trial = @benchmark JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJTmv() + funcName = "JTmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp); + trial = @benchmark JTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJYmv() + funcName = "JYmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JYmv(L,copy(Zout),theta,Y,tmp); + trial = @benchmark JYmv(L,copy(Zout),theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJmv() + funcName = "Jmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + dtheta = randn(TYPE, size(theta)) + #Warmup + Z1 = Jmv(L,copy(dtheta),copy(Zout),theta,Y,tmp); + trial = @benchmark Jmv(L,copy(dtheta),copy(Zout),theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJthetamv() + funcName = "Jthetamv" + _,_,tmp = apply(L,theta,Y,true) + dtheta = randn(TYPE, size(theta)) + + #Warmup + Z1 = Jthetamv(L,dtheta,theta,Y,tmp); + trial = @benchmark Jthetamv(L,dtheta,theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end \ No newline at end of file diff --git a/benchmarks/micro/bm_singlelayer.jl b/benchmarks/micro/bm_singlelayer.jl new file mode 100644 index 0000000..8190b44 --- /dev/null +++ b/benchmarks/micro/bm_singlelayer.jl @@ -0,0 +1,104 @@ +using Meganet, BenchmarkTools + +const history = Pkg.dir("Meganet")*"//benchmarks//micro//bm_singleLayer.jld" + +TYPE = Float32 +K = getConvGEMMKernel(TYPE,[16,16],[3,3,64,64]) +nex = 1000 + +nLayer = getBatchNormLayer(TYPE,[16*16;64],isTrainable=true) +L = getSingleLayer(TYPE,K,nLayer) + +theta = initTheta(L) + +Y = randn(TYPE,nFeatIn(L),nex) + + +function benchmarkApply() + funcName = "apply" + Zd,Z,tmp = apply(L,theta,Y,true) + @code_warntype apply(L,theta,Y,true) + + trial = @benchmark apply(L,theta,Y,true) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJYTmv() + funcName = "JYTmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp); + trial = @benchmark JYTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJthetaTmv() + funcName = "JthetaTmv" + _,_,tmp = 
apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JthetaTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp); + trial = @benchmark JthetaTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJTmv() + funcName = "JTmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp); + trial = @benchmark JTmv(L,copy(Zout),(TYPE)[],theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJYmv() + funcName = "JYmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + + #Warmup + Z1 = JYmv(L,copy(Zout),theta,Y,tmp); + trial = @benchmark JYmv(L,copy(Zout),theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJmv() + funcName = "Jmv" + _,_,tmp = apply(L,theta,Y,true) + Zout = randn(TYPE,nFeatOut(L),nex) + dtheta = randn(TYPE, size(theta)) + #Warmup + Z1 = Jmv(L,copy(dtheta),copy(Zout),theta,Y,tmp); + trial = @benchmark Jmv(L,copy(dtheta),copy(Zout),theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end + +function benchmarkJthetamv() + funcName = "Jthetamv" + _,_,tmp = apply(L,theta,Y,true) + dtheta = randn(TYPE, size(theta)) + + #Warmup + Z1 = Jthetamv(L,dtheta,theta,Y,tmp); + trial = @benchmark Jthetamv(L,dtheta,theta,Y,tmp) + Meganet.updatehistory!(history, trial, funcName) + hist = JLD.load(history, funcName) + judge(hist) +end \ No newline at end of file diff --git a/examples/EResNN_CIFAR10.jl b/examples/EResNN_CIFAR10.jl index efa3f01..9afcbbc 100644 --- a/examples/EResNN_CIFAR10.jl +++ b/examples/EResNN_CIFAR10.jl @@ -85,8 +85,8 @@ solve(opt,objFun::dnnObjFctn,[vec(theta);vec(W)],Y_train,C_train,Y_test,C_test) # Profile.clear() # Profile.clear_malloc_data() # Profile.init(n = 10^7, delay = 0.01) -# @profile solve(opt,objFun::dnnObjFctn,[vec(theta);vec(W)],Y,C,Y,C) +# @profile solve(opt,objFun::dnnObjFctn,[vec(theta);vec(W)],Y_train,C_train,Y_test,C_test) # open("/tmp/EREsNN_CIFAR10.txt", "w") do s - # Profile.print(IOContext(s, :displaysize => (24, 500))) +# Profile.print(IOContext(s, :displaysize => (24, 500))) # end diff --git a/src/Meganet.jl b/src/Meganet.jl index 30d1aaf..3e63a76 100644 --- a/src/Meganet.jl +++ b/src/Meganet.jl @@ -22,6 +22,7 @@ include("kernelTypes/convCircKernel.jl"); include("layers/affineScalingLayer.jl") include("layers/normLayer.jl") +include("integrators/batchNormNN.jl") include("layers/doubleSymLayer.jl") include("layers/singleLayer.jl") diff --git a/src/activations/tanhActivation.jl b/src/activations/tanhActivation.jl index 1c7ab8a..06af3de 100644 --- a/src/activations/tanhActivation.jl +++ b/src/activations/tanhActivation.jl @@ -19,11 +19,11 @@ export tanhActivation """ function tanhActivation(Y::Array{T,2},doDerivative::Bool=false) where {T <: Number} -A = tanh.(Y) -if doDerivative - dA = 1-A.^2 -else - dA = zeros(T,0,0) -end -return A,dA + + A = tanh.(Y) + dA = zeros(A) + if doDerivative + dA .= one(T) .- A.^2 + end + return A, dA end diff --git a/src/integrators/NN.jl b/src/integrators/NN.jl index 7d3ee0f..9484a7e 100644 --- a/src/integrators/NN.jl +++ b/src/integrators/NN.jl @@ -5,10 +5,10 @@ NN Neural Network block Y_k+1 = layer{k}(theta{k},Y_k) """ -mutable struct NN{T} <: AbstractMeganetElement{T} - layers 
::Array{AbstractMeganetElement{T}, 1} # layers of Neural Network, cell array - outTimes - Q +mutable struct NN{T, TQ <: Union{Array{T,2},UniformScaling{Int}}} <: AbstractMeganetElement{T} + layers ::Array{AbstractMeganetElement{T}, 1} # layers of Neural Network, cell array + outTimes ::Array{Int,1} + Q :: TQ end function getNN(layers::Array{AbstractMeganetElement{T}},outTimes=eye(Int,length(layers))[:,end],Q=I) where {T <: Number} @@ -21,7 +21,7 @@ function getNN(layers::Array{AbstractMeganetElement{T}},outTimes=eye(Int,length( end nout = nFeatOut(layers[k]) end - return NN{T}(layers,outTimes,Q); + return NN(layers,outTimes,Q); end @@ -36,14 +36,14 @@ end # ---------- counting thetas, input and output features ----- function nTheta(this::NN) - n = 0; + n::Int = 0; for k=1:length(this.layers) - n = n + nTheta(this.layers[k]); + n += nTheta(this.layers[k]); end return n end nFeatIn(this::NN) = nFeatIn(this.layers[1]) -nFeatOut(this::NN) = nFeatOut(this.layers[end]) +nFeatOut(this::NN)::Int = nFeatOut(this.layers[end]) function nDataOut(this::NN) n=0; @@ -159,39 +159,41 @@ end function JthetaTmv(this::NN{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} - return JTmv(this,Wdata,W,theta,Y,tmp)[1]; + return JTmv(this,Wdata,W,theta,Y,tmp)[1]; # TODO: Why calculating both, Can be more efficient? end -function JTmv(this::NN,Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} - nex = div(length(Y),nFeatIn(this)) +function JTmv(this::NN{T},Wdata::Array{T},Win::Array{T},theta::Array{T},Y::Array{T},tmp)::Tuple{Array{T,1},Array{T,1}} where {T <: Number} + # WOW THIS IS HACKED BIG TIME. Need to find a way to type stabalize W (not ez) + #TODO: Make this type stable - Some internals are not stable + nex = div(length(Y),nFeatIn(this)::Int) if size(Wdata,1)>0 Wdata = reshape(Wdata,:,nex) end - if length(W)==0 + + if length(Win)==0 W = zeros(T,nFeatOut(this),nex) - elseif length(W)>1 - W = reshape(W,:,nex) + else + W = reshape(Win,:,nex) end - - dtheta = 0*theta + dtheta = zero(T)*theta nt = length(this.layers) cnt = 0; cnt2 = 0 for i=nt:-1:1 if this.outTimes[i]==1 - nn = nFeatOut(this.layers[i]) + nn = nFeatOut(this.layers[i])::Int W += this.Q'*Wdata[end-cnt2-nn+1:end-cnt2,:] cnt2 = cnt2 + nn end - ni = nTheta(this.layers[i]) + ni = nTheta(this.layers[i])::Int dmbi,W = JTmv(this.layers[i],W,zeros(T,0),theta[end-cnt-ni+1:end-cnt],tmp[i,1],tmp[i,2]) dtheta[end-cnt-ni+1:end-cnt] = dmbi - cnt = cnt+ni + cnt += ni end - return vec(dtheta), vec(W) + return vec(dtheta), vec(W) end diff --git a/src/integrators/ResNN.jl b/src/integrators/ResNN.jl index 7e5b7bd..86f43e9 100644 --- a/src/integrators/ResNN.jl +++ b/src/integrators/ResNN.jl @@ -5,20 +5,20 @@ Residual Neural Network block Y_k+1 = Y_k + h*layer{k}(theta{k},Y_k) """ -mutable struct ResNN{T} <: AbstractMeganetElement{T} - layer - nt - h - outTimes +mutable struct ResNN{T, TL <: AbstractMeganetElement{T}} <: AbstractMeganetElement{T} #TODO limit TL more + layer :: TL + nt :: Int + h :: T + outTimes :: Array{Int,1} Q end function getResNN(TYPE::Type,layer,nt,h=one(TYPE),outTimes=eye(Int,nt)[:,nt],Q=I) - h = convert(TYPE,h); + h = convert(TYPE,h); if nFeatIn(layer)!=nFeatOut(layer) error("ResNN layer must be square!") - end - return ResNN{TYPE}(layer,nt,h,outTimes,Q) + end + return ResNN(layer,nt,h,outTimes,Q) end @@ -48,7 +48,7 @@ function initTheta(this::ResNN{T}) where {T<:Number} end # ------- apply forward problems ----------- -function 
apply(this::ResNN{T},theta::Array{T},Y0::Array{T},doDerivative=true) where {T<:Number} +function apply(this::ResNN{T},theta_in::Array{T},Y0::Array{T},doDerivative=true) where {T<:Number} nex = div(length(Y0),nFeatIn(this)) Y = reshape(Y0,:,nex) @@ -57,9 +57,9 @@ function apply(this::ResNN{T},theta::Array{T},Y0::Array{T},doDerivative=true) w tmp[1,1] = Y0 end - theta = reshape(theta,:,this.nt) + theta = reshape(theta_in,:,this.nt) - Ydata = zeros(T,0,nex) + Ydata::Array{T,2} = zeros(T,0,nex) for i=1:this.nt Z,dummy,tmp[i,2] = apply(this.layer,theta[:,i],Y,doDerivative) Y += this.h * Z @@ -115,26 +115,26 @@ end # -------- Jacobian transpose matvecs ---------------- -function JYTmv(this::ResNN{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T<:Number} +function JYTmv(this::ResNN{T},Wdata_in::Array{T},Win::Array{T},theta_in::Array{T},Y::Array{T},tmp) where {T<:Number} nex = div(length(Y),nFeatIn(this)) - if length(Wdata)>0 - Wdata = reshape(Wdata,:,sum(this.outTimes),nex) + if length(Wdata_in)>0 + Wdata = reshape(Wdata_in,:,sum(this.outTimes),nex) end - if length(W)==0 - W = zero(T) + if length(Win)==0 + W = zeros(T,0,0) else - W = reshape(W,:,nex) + W = reshape(Win,:,nex) end - theta = reshape(theta,:,this.nt) + theta = reshape(theta_in,:,this.nt) cnt = sum(this.outTimes) for i=this.nt:-1:1 if this.outTimes[i]==1 - W += this.Q'*Wdata[:,cnt,:] + W::Array{T,2} = this.Q'*Wdata[:,cnt,:] cnt = cnt-1 end - dW = JYTmv(this.layer,W,zeros(T,0),theta[:,i],tmp[i,1],tmp[i,2]) + dW::Array{T,2} = JYTmv(this.layer,W,zeros(T,0),theta[:,i],tmp[i,1],tmp[i,2]) W += this.h*dW end return W @@ -166,7 +166,9 @@ function JTmv(this::ResNN{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Arra W += this.Q'* Wdata[:,cnt,:] cnt = cnt-1 end + dmbi,dW = JTmv(this.layer,W,zeros(T,0),theta[:,i],tmp[i,1],tmp[i,2]) + dtheta[:,i] = this.h*dmbi W += this.h*dW end diff --git a/src/integrators/batchNormNN.jl b/src/integrators/batchNormNN.jl new file mode 100644 index 0000000..7e1e789 --- /dev/null +++ b/src/integrators/batchNormNN.jl @@ -0,0 +1,190 @@ +export batchNormNN,getbatchNormNN,initTheta + +""" +batchNormNN Neural Network block + + Y_k+1 = layer{k}(theta{k},Y_k) +""" +#TODO: Can probably optimize some functions using the knowledge that we only have a norm and an AFS layer. +mutable struct batchNormNN{T,TQ <: Union{Array{T,2},UniformScaling{Int}}} <: AbstractMeganetElement{T} + layers ::Tuple{normLayer{T}, AffineScalingLayer{T}} + outTimes ::Array{Int,1} + Q ::TQ +end + +function getbatchNormNN(layers::Tuple{normLayer{T}, AffineScalingLayer{T}},outTimes=eye(Int,length(layers))[:,end],Q=I) where {T <: Number} + nt = length(layers) + nout = nFeatOut(layers[1]) + + for k=2:nt + if nFeatIn(layers[k]) != nout + error("Dim. of input features of block $k does not match dim. 
of output features of block $(k-1)"); + end + nout = nFeatOut(layers[k]) + end + return batchNormNN(layers,outTimes,Q); +end + +# ---------- counting thetas, input and output features ----- +function nTheta(this::batchNormNN) + n::Int = 0; + for k=1:length(this.layers) + n += nTheta(this.layers[k]); + end + return n +end +nFeatIn(this::batchNormNN) = nFeatIn(this.layers[1]) +nFeatOut(this::batchNormNN) = nFeatOut(this.layers[end]) + +function nDataOut(this::batchNormNN) + n=0; + for k=1:length(this.layers) + n = n+this.outTimes[k]* nFeatOut(this.layers[k]); + end +end + +function initTheta(this::batchNormNN{T}) where {T <: Number} + theta = zeros(T,0) + for k=1:length(this.layers) + theta = [theta; vec(initTheta(this.layers[k]))] + end + return convert(Array{T},theta) +end + + +# --------- forward problem ---------- +function apply(this::batchNormNN{T},theta::Array{T},Y0::Array{T,2},doDerivative=true) where {T<:Number} + Y::Array{T,2} = copy(Y0) + nex = div(length(Y),nFeatIn(this))::Int + nt = length(this.layers) + + tmp = Array{Any}(nt+1,2) + if doDerivative + tmp[1,1] = Y0 + end + + Ydata::Array{T,2} = zeros(T,0,nex) + cnt = 0 + for i=1:nt + ni = nTheta(this.layers[i])::Int + + Yd::Array{T,2}, Y, tmp[i,2] = apply(this.layers[i],theta[cnt+(1:ni)],Y,doDerivative) + if this.outTimes[i]==1 + Ydata = [Ydata; this.Q*Yd] + end + if doDerivative + tmp[i+1,1] = copy(Y) + end + cnt = cnt + ni + end + + return Ydata,Y,tmp +end + +# -------- Jacobian matvecs -------- +function JYmv(this::batchNormNN{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} + nex = div(length(Y),nFeatIn(this)) + nt = length(this.layers) + cnt = 0 + dYdata = zeros(T,0,nex) + for i=1:nt + ni = nTheta(this.layers[i]) + dY = JYmv(this.layers[i],dY,theta[cnt+(1:ni)],tmp[i,1],tmp[i,2])[2] + if this.outTimes[i]==1 + dYdata = [dYdata; this.Q*dY] + end + cnt = cnt+ni + end + return dYdata, dY +end + +function Jmv(this::batchNormNN{T},dtheta::Array{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} + nex = div(length(Y),nFeatIn(this)) + nt = length(this.layers); + if isempty(dY) + dY = 0*Y + end + + dYdata = zeros(T,0,nex) + cnt = 0 + for i=1:nt + ni = nTheta(this.layers[i]) + dY = Jmv(this.layers[i],dtheta[cnt+(1:ni)],dY,theta[cnt+(1:ni)], + tmp[i,1],tmp[i,2])[2] + if this.outTimes[i]==1 + dYdata = [dYdata; this.Q*dY] + end + cnt = cnt+ni + end + return dYdata,dY +end + +# -------- Jacobian' matvecs -------- +function JYTmv(this::batchNormNN{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} + + nex = div(length(Y),nFeatIn(this)); + if !isempty(Wdata) + Wdata = reshape(Wdata,:,nex); + end + if isempty(W) + W = zero(T) + elseif length(W)>1 + W = reshape(W,:,nex) + end + nt = length(this.layers) + + cnt = 0; cnt2 = 0; + for i=nt:-1:1 + ni = nTheta(this.layers[i]) + if this.outTimes[i]==1 + nn = nFeatOut(this.layers[i]) + W = W + this.Q'*Wdata[end-cnt2-nn+1:end-cnt2,:] + cnt2 = cnt2 + nn + end + W = JYTmv(this.layers[i], W,(T)[],theta[end-cnt-ni+1:end-cnt], + tmp[i,1],tmp[i,2]) + cnt = cnt+ni + end + return W +end + + +function JthetaTmv(this::batchNormNN{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} + return JTmv(this,Wdata,W,theta,Y,tmp)[1]; +end + + + +function JTmv(this::batchNormNN,Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} + + nex = div(length(Y),nFeatIn(this)) + + if size(Wdata,1)>0 + Wdata = reshape(Wdata,:,nex) + end + if length(W)==0 + W = zeros(T,nFeatOut(this),nex) + elseif 
length(W)>1 + W = reshape(W,:,nex) + end + + dtheta = 0*theta + nt = length(this.layers) + + cnt = 0; cnt2 = 0 + for i=nt:-1:1 + if this.outTimes[i]==1 + nn = nFeatOut(this.layers[i]) + W += this.Q'*Wdata[end-cnt2-nn+1:end-cnt2,:] + cnt2 = cnt2 + nn + end + + ni = nTheta(this.layers[i]) + + dmbi,W = JTmv(this.layers[i],W,zeros(T,0),theta[end-cnt-ni+1:end-cnt],tmp[i,1],tmp[i,2]) + dtheta[end-cnt-ni+1:end-cnt] = dmbi + cnt = cnt+ni + end + return vec(dtheta), vec(W) + +end diff --git a/src/integrators/connector.jl b/src/integrators/connector.jl index b18414b..5824642 100644 --- a/src/integrators/connector.jl +++ b/src/integrators/connector.jl @@ -1,10 +1,10 @@ export Connector,getConnector -mutable struct Connector{T} <: AbstractMeganetElement{T} - K - b - outTimes - Q +mutable struct Connector{T,TQ <: Union{Array{T,2},UniformScaling{Int}}, TK <: Union{Array{T,2},SparseMatrixCSC{T,Int}}} <: AbstractMeganetElement{T} + K::TK + b::T + outTimes::Int + Q::TQ # ??? end nTheta(this::Connector) = 0 @@ -14,7 +14,7 @@ nDataOut(this::Connector) = ((this.Q==I) ? nFeatOut(this) : size(this.Q,1)) initTheta(this::Connector{T}) where {T <: Number} = zeros(T,0) function getConnector(TYPE::Type, K; b = zero(TYPE),outTimes=0,Q=I) - return Connector{TYPE}(K,b,outTimes,Q); + return Connector(K,b,outTimes,Q); end @@ -22,39 +22,38 @@ function apply(this::Connector{T},theta::Array{T},Y0::Array{T},doDerivative=true nex = div(length(Y0),nFeatIn(this)) Y0 = reshape(Y0,:,nex) Y = this.K*Y0 .+ this.b + Ydata::Array{T,2} = Array{T, 2}(0, 0) # Temporary fix until we know what type Q is if this.outTimes==1 Ydata = this.Q*Y - else - Ydata = Array{T, 2}(0, 0) end tmp = Y0; return Ydata, Y, tmp end function Jmv(this::Connector{T},dtheta::Array{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} - + # ??? This doesn't seem to get used? 
nex = div(length(dY),nFeatIn(this)) dY = reshape(dY,:,nex) dY = this.K*dY + dYdata::Array{T,2} = Array{T, 2}(0, 0) # Temporary fix until we know what type Q is if this.outTimes==1 dYdata = this.Q*dY - else - dYdata = [] end + return dYdata,dY end -function JTmv(this::Connector{T},Wdata::Array{T},W::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} +function JTmv(this::Connector{T},Wdata::Array{T},Win::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} nex = div(length(Y),nFeatIn(this)) - if length(W)==0 - W = zero(T); + if length(Win)==0 + W = zeros(T,1,1); else - W = reshape(W,:,nex); + W = reshape(Win,:,nex); end if length(Wdata)>0 Wdata = reshape(Wdata,:,nex); - W = W+ this.Q'*Wdata; + W = W .+ this.Q'*Wdata end dtheta = zeros(T,0); diff --git a/src/kernelTypes/convGEMMKernel.jl b/src/kernelTypes/convGEMMKernel.jl index c9216bb..09c68b4 100644 --- a/src/kernelTypes/convGEMMKernel.jl +++ b/src/kernelTypes/convGEMMKernel.jl @@ -15,11 +15,11 @@ function Amv(this::convGEMMKernel{T},theta::Array{T},Y::Array{T}) where {T<:Numb nex = div(numel(Y),prod(nImgIn(this))) # compute convolution Y = reshape(Y,nImg[1],nImg[2],this.sK[3],nex); - AY = zeros(T,nImg[1]*nImg[2],this.sK[4],nex); + AY = Array{T, 3}(nImg[1]*nImg[2],this.sK[4],nex); aux = zeros(T,nImg[1],nImg[2],this.sK[3]); AYk = zeros(T,nImg[1]*nImg[2],this.sK[4]); ### reshape the kernels for gemm!: - K = reshape(theta,tuple(sK...)); + K = reshape(theta, sK[1], sK[2], sK[3], sK[4]) KK = Array{Array{T,2}}(sK[1],sK[2]); for k1 = 1:sK[1] for k2 = 1:sK[2] @@ -34,16 +34,16 @@ function Amv(this::convGEMMKernel{T},theta::Array{T},Y::Array{T}) where {T<:Numb @inbounds AY[:,:,k] = AYk; AYk[:] = zero(T) end - AY = reshape(AY,:,nex); - return AY + AY_out = reshape(AY,:,nex); + return AY_out end -function ATmv(this::convGEMMKernel{T},theta::Array{T},Z::Array{T}) where {T<:Number} +function ATmv(this::convGEMMKernel{T},theta::Array{T},Zin::Array{T}) where {T<:Number} nImg = this.nImg; sK = this.sK; - nex = div(numel(Z),prod(nImgOut(this))); - K = reshape(theta,tuple(sK...)); - Z = reshape(Z,nImg[1],nImg[2],sK[4],nex); + nex = div(numel(Zin),prod(nImgOut(this))); + K = reshape(theta, sK[1], sK[2], sK[3], sK[4]); + Z = reshape(Zin,nImg[1],nImg[2],sK[4],nex); aux = zeros(T,nImg[1],nImg[2],sK[4]); ATZ = zeros(T,nImg[1]*nImg[2],sK[3],nex); ATZk = zeros(T,nImg[1]*nImg[2],sK[3]); @@ -64,8 +64,8 @@ function ATmv(this::convGEMMKernel{T},theta::Array{T},Z::Array{T}) where {T<:Num @inbounds ATZ[:,:,k] = ATZk; ATZk[:] = zero(T) end - ATZ = reshape(ATZ,:,nex); - return ATZ + ATZ_out = reshape(ATZ,:,nex); + return ATZ_out end function Jthetamv(this::convGEMMKernel{T},dtheta::Array{T},dummy::Array{T},Y::Array{T},temp=nothing) where {T<:Number} @@ -74,38 +74,39 @@ function Jthetamv(this::convGEMMKernel{T},dtheta::Array{T},dummy::Array{T},Y::Ar return Z end -function JthetaTmv(this::convGEMMKernel{T},Z::Array{T},dummy::Array{T},Y::Array{T}) where {T<:Number} +function JthetaTmv(this::convGEMMKernel{T}, Zin::Array{T}, dummy::Array{T}, Yin::Array{T}) where {T<:Number} # derivative of Z*(A(theta)*Y) w.r.t. 
theta - sK = this.sK; - nImg = this.nImg; - nex = div(numel(Y),prod(nImgIn(this))) + sK = this.sK + nImg = this.nImg + nex = div(numel(Yin),prod(nImgIn(this))) # compute convolution - Y = reshape(Y,nImg[1],nImg[2],this.sK[3],nex); - Z = reshape(Z,nImg[1]*nImg[2],this.sK[4],nex); - Zk = zeros(T,nImg[1]*nImg[2],this.sK[4]); - aux = zeros(T,nImg[1],nImg[2],this.sK[3]); + Y = reshape(Yin, nImg[1], nImg[2], this.sK[3], nex) + Z = reshape(Zin, nImg[1]*nImg[2], this.sK[4], nex) + Zk = zeros(T, nImg[1]*nImg[2], this.sK[4]) + aux = zeros(T, nImg[1], nImg[2], this.sK[3]) + ### reshape the kernels for gemm!: - dtheta = zeros(T,tuple(sK...)); - KK = Array{Array{T,2}}(sK[1],sK[2]); + dtheta = zeros(T, sK[1], sK[2], sK[3], sK[4]) + KK = Array{Array{T, 2}}(sK[1], sK[2]) for k1 = 1:sK[1] for k2 = 1:sK[2] - @inbounds KK[k1,k2] = zeros(T,sK[3],sK[4]); + @inbounds KK[k1, k2] = zeros(T, sK[3], sK[4]) end end - shiftX = [0;-1;0;0;1;0]; - shiftT = [1;0;0;0;0;-1]; + shiftX = [0;-1;0;0;1;0] + shiftT = [1;0;0;0;0;-1] for k = 1:nex - getColumn!(Z,Zk,k); - multConv2Dblock(Y,KK, Zk,aux,shiftX,shiftT,k,doDerivative = 1); + getColumn!(Z, Zk, k) + multConv2Dblock(Y, KK, Zk, aux, shiftX, shiftT, k, doDerivative = 1) end ### Assemble the kernels from gemm!: for k1 = 1:sK[1] for k2 = 1:sK[2] - @inbounds dtheta[k1,k2,:,:] = KK[k1,k2]; + @inbounds dtheta[k1, k2, :, :] = KK[k1, k2] end end - dtheta = reshape(dtheta,tuple(this.sK...)); - return dtheta + dtheta_out = reshape(dtheta, sK[1], sK[2], sK[3], sK[4]) + return dtheta_out end @@ -118,60 +119,65 @@ for c=1:size(Z,2) end end -function multConv2Dblock(x::Array{T},K::Array{Array{T,2},2}, y::Array{T}, t::Array{T},shiftX,shiftT,imIdx;doDerivative = 0) where {T<:Number} -## y = K*x -## K - 3X3 array of Arrays -## x - a vector of length |nImgag+2|*cin (zero padded) -## y - a vector of length |nImgag|*cout - -nImg1 = size(x,1); -nImg2 = size(x,2); -cin = size(x,3); -cout = size(y,2); -OneType = one(T); - -kernelWidth = size(K,1); -# y = reshape(y,nImg1*nImg2,cout); # it is supposed to be of this shape... -k=1; -jt=0;it=0;jt=0;jx=0; -for p = 1:2:2*kernelWidth - for q = 1:2:2*kernelWidth - t = reshape(t,nImg1,nImg2,cin); - for cc = 1:cin - jx = 1+shiftX[p]; - jt = 1+shiftT[p]; - if jt > 1 - @inbounds t[:,1:(jt-1),cc] = 0.0; - end - while jt < nImg2+shiftT[p+1] - it = 1+shiftT[q]; - ix = 1+shiftX[q]; - if it > 1 - @inbounds t[1:(it-1),jt,cc] = 0.0; +function multConv2Dblock(x::Array{T},K::Array{Array{T,2},2}, y::Array{T}, tin::Array{T},shiftX,shiftT,imIdx;doDerivative = 0) where {T<:Number} + ## y = K*x + ## K - 3X3 array of Arrays + ## x - a vector of length |nImgag+2|*cin (zero padded) + ## y - a vector of length |nImgag|*cout + + nImg1 = size(x,1); + nImg2 = size(x,2); + cin = size(x,3); + cout = size(y,2); + OneType = one(T); + t = reshape(tin,nImg1,nImg2,cin); + kernelWidth = size(K,1); + # y = reshape(y,nImg1*nImg2,cout); # it is supposed to be of this shape... 
+ k=1; + jt=0;it=0;jt=0;jx=0; + for p = 1:2:2*kernelWidth + for q = 1:2:2*kernelWidth + lower = nImg2+shiftT[p+1] # Move outside of the forloop for increased speed + upper = nImg1+shiftT[q+1] # Move outside of the forloop for increased speed + for cc = 1:cin + jx = 1+shiftX[p]; # Moving these outside didn't seem to help + jt = 1+shiftT[p]; + if jt > 1 + @inbounds t[:,1:(jt-1),cc] = 0.0; end - while it < nImg1+shiftT[q+1] - @inbounds t[it,jt,cc] = x[ix,jx,cc,imIdx]; - it+=1;ix+=1; + while jt <= lower + it = 1+shiftT[q]; + ix = 1+shiftX[q]; + if it > 1 + for ii = 1:(it-1) + @inbounds t[ii,jt,cc] = zero(T) #@inbounds t[1:(it-1),jt,cc] = 0.0 - faster unvectorized + end + end + while it <= upper + @inbounds t[it,jt,cc] = x[ix,jx,cc,imIdx]; + it+=1;ix+=1; + end + if it <= nImg1 + for ii = it:nImg1 + @inbounds t[ii,jt,cc] = zero(T) #@inbounds t[it:nImg1,jt,cc] = 0.0 - faster unvectorized + end + end + jt+=1;jx+=1; + end - if it <= nImg1 - @inbounds t[it:nImg1,jt,cc] = 0.0; + if jt <= nImg2 + @inbounds t[:,jt:nImg2,cc] = 0.0; end - jt+=1;jx+=1; end - if jt <= nImg2 - @inbounds t[:,jt:nImg2,cc] = 0.0; + if doDerivative == 0 + BLAS.gemm!('N','T',OneType,reshape(t,nImg1*nImg2,cin),K[k],OneType,y); + else + BLAS.gemm!('T','N',OneType,reshape(t,nImg1*nImg2,cin),y,OneType,K[k]); end + k+=1; end - t = reshape(t,nImg1*nImg2,cin); - if doDerivative == 0 - BLAS.gemm!('N','T',OneType,t,K[k],OneType,y); - else - BLAS.gemm!('T','N',OneType,t,y,OneType,K[k]); - end - k+=1; end -end -return y; + return y; end diff --git a/src/layers/affineScalingLayer.jl b/src/layers/affineScalingLayer.jl index fc15333..64a2da0 100644 --- a/src/layers/affineScalingLayer.jl +++ b/src/layers/affineScalingLayer.jl @@ -21,24 +21,23 @@ function splitWeights(this::AffineScalingLayer{T},theta_in::Array{T}) where {T < return s2, b2 end -function scaleChannels(Y,s,b) +function scaleChannels!(Y::Array{T},s::Array{T},b::Array{T}) where {T <: Number} for i=1:length(s) - Y[:,i,:] = s[i]*Y[:,i,:] + b[i] + Y[:,i,:] .= Y[:,i,:].*s[i] .+ b[i] end - return Y end -function apply(this::AffineScalingLayer{T},theta::Array{T},Yin::Array{T},doDerivative=false) where {T <: Number} +function apply(this::AffineScalingLayer{T},theta::Array{T},Y::Array{T},doDerivative=false) where {T <: Number} - Y = reshape(Yin,this.nData[1], this.nData[2],:) + Y = reshape(copy(Y),this.nData[1], this.nData[2],:) dA = (T)[] nex = size(Y,3) s2,b2 = splitWeights(this,theta); - Yscaled = scaleChannels(Y,s2,b2); + scaleChannels!(Y,s2,b2); - Yout = reshape(Yscaled,:,nex) + Yout = reshape(Y,:,nex) Ydata = Yout return Ydata, Yout, dA end @@ -65,14 +64,14 @@ function initTheta(this::AffineScalingLayer{T}) where {T <: Number} end function Jthetamv(this::AffineScalingLayer{T},dtheta::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} - Y = reshape(copy(Y),this.nData[1], this.nData[2],:) nex = size(Y,3) ds2,db2 = splitWeights(this,dtheta) - dY = scaleChannels(Y,ds2,db2) - dY = reshape(dY,:,nex) + scaleChannels!(Y,ds2,db2) + + dY = reshape(Y,:,nex) dYdata = dY return dYdata, dY end @@ -88,12 +87,11 @@ function JthetaTmv(this::AffineScalingLayer{T},Z::Array{T},dummy::Array{T},theta end function JYmv(this::AffineScalingLayer{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} - dY = reshape(copy(dY),this.nData[1], this.nData[2],:); nex = size(dY,3) s2,b2 = splitWeights(this,theta); - dY = scaleChannels(dY,s2,b2*0) + scaleChannels!(dY,s2,b2*0) dY = reshape(dY,:,nex) dYdata = dY @@ -101,11 +99,10 @@ function 
JYmv(this::AffineScalingLayer{T},dY::Array{T},theta::Array{T},Y::Array{ end function JYTmv(this::AffineScalingLayer{T},Z::Array{T},dummy::Array{T},theta::Array{T},Y::Array{T},tmp=nothing) where {T <: Number} - Z = reshape(copy(Z),this.nData[1], this.nData[2],:) nex = size(Z,3) s2,b2 = splitWeights(this,theta) - Z = scaleChannels(Z,s2,b2*0) + scaleChannels!(Z,s2,b2*0) return reshape(Z,:,nex) end diff --git a/src/layers/doubleSymLayer.jl b/src/layers/doubleSymLayer.jl index 633e613..5f48cf3 100644 --- a/src/layers/doubleSymLayer.jl +++ b/src/layers/doubleSymLayer.jl @@ -5,7 +5,7 @@ export DoubleSymLayer,getDoubleSymLayer Y(theta,Y0) = K(th1)'(activation( K(th1)\*Y0 + trafo.Bin\*th2))) + trafo.Bout\*th3 """ -mutable struct DoubleSymLayer{T, TK <: AbstractConvKernel{T}, TN <: Union{NN{T}, normLayer{T}}} <: AbstractMeganetElement{T} +mutable struct DoubleSymLayer{T, TK <: AbstractConvKernel{T}, TN <: Union{batchNormNN{T}, normLayer{T}}} <: AbstractMeganetElement{T} activation :: Function # activation function K :: TK # Kernel model, e.g., convMod nLayer :: TN # normalization layer @@ -170,43 +170,46 @@ function JthetaTmv(this::DoubleSymLayer{T},Z::Array{T},dummy::Array{T},theta::Ar return dtheta end -function JYTmv(this::DoubleSymLayer{T},Z::Array{T},dummy::Array{T},theta::Array{T},Y::Array{T},tmp) where {T<:Number} +function JYTmv(this::DoubleSymLayer{T},Zin::Array{T},dummy::Array{T},theta::Array{T},Y::Array{T},tmp) where {T<:Number} nex = div(length(Y),nFeatIn(this)) - Z = reshape(Z,:,nex) + Z = reshape(Zin,:,nex) th1,th2,th3,th4 = splitWeights(this,theta) Kop = getOp(this.K,th1) A,dA = this.activation(tmp[2],true) dAZ = dA.*(Kop*Z) dAZ = JYTmv(this.nLayer,dAZ,(T)[],th4,Kop*Y,tmp[1]) - dAZ = reshape(dAZ,:,nex) - dY = -(Kop'*dAZ) + dAZ_out = reshape(dAZ,:,nex) + dY = -(Kop'*dAZ_out) return dY end -function JTmv(this::DoubleSymLayer{T},Z::Array{T},dummy::Array{T},theta::Array{T},Y::Array{T},tmp) where {T<:Number} +function JTmv(this::DoubleSymLayer{T}, Zin::Array{T}, dummy::Array{T}, + theta::Array{T}, Yin::Array{T}, tmp) where {T<:Number} - dY = (T)[] - nex = div(length(Y),nFeatIn(this)) - Z = reshape(Z,:,nex) - Yt = reshape(tmp[2],:,nex) - Y = reshape(Y,:,nex) - th1, th2,th3,th4 = splitWeights(this,theta) - Kop = getOp(this.K,th1) - A,dA = this.activation(Yt,true) + nex = div(length(Yin),nFeatIn(this)) + Z = reshape(Zin, :, nex) + Yt = reshape(tmp[2]::Array{T,2},:,nex) + Y = reshape(Yin,:,nex) + th1, th2, th3, th4 = splitWeights(this,theta) + #Kop = getOp(this.K,th1) + A::Array{T,2}, dA::Array{T,2} = this.activation(Yt,true) dth3 = vec(sum(this.Bout'*Z,2)) - dAZ = dA.*(Kop*Z) - dth2 = vec(sum(this.Bin'*dAZ,2)) - dth4,dAZ = JTmv(this.nLayer,dAZ,zeros(T,0),th4,Kop*Y,tmp[1]) - dth1 = JthetaTmv(this.K,dAZ,zeros(T,0),Y) + KopZ = Amv(this.K, th1, Z) + dAZ1 = dA.*KopZ - dth1 = dth1 + JthetaTmv(this.K,A,(T)[],Z) + dth2 = vec(sum(this.Bin'*dAZ1,2)) + KopY = Amv(this.K, th1, Y) + dth4, dAZ2 = JTmv(this.nLayer,dAZ1,zeros(T,0),th4,KopY,tmp[1]) + dth1 = JthetaTmv(this.K,dAZ2,zeros(T,0),Y) + dth1 += JthetaTmv(this.K,A,(T)[],Z) dtheta = [-vec(dth1); -vec(dth2); vec(dth3);-vec(dth4)] - dAZ = reshape(dAZ,:,nex) - dY = -(Kop'*dAZ) - return dtheta,dY + dAZ_out = reshape(dAZ2,:,nex) + KopTdAZ = ATmv(this.K, th1, dAZ_out) + dY = -KopTdAZ + return dtheta, dY end diff --git a/src/layers/normLayer.jl b/src/layers/normLayer.jl index 22ecfe1..be6d4f7 100644 --- a/src/layers/normLayer.jl +++ b/src/layers/normLayer.jl @@ -15,7 +15,7 @@ function getBatchNormLayer(TYPE::Type, nData; eps = convert(TYPE,1e-3),isTrainab L = 
normLayer{TYPE}(nData,3,eps) if isTrainable SL = AffineScalingLayer{TYPE}(nData) - return getNN([L;SL]); + return getbatchNormNN((L,SL)); else return L; end @@ -25,13 +25,13 @@ function getTVNormLayer(TYPE::Type,nData;eps = convert(TYPE,1e-3),isTrainable::B L = normLayer{TYPE}(nData,2,eps) if isTrainable SL = AffineScalingLayer{TYPE}(nData) - return getNN([L;SL]) + return getbatchNormNN((L,SL)) else return L end end -function apply(this::normLayer{T},theta::Array{T},Yin::Array{T},doDerivative=true) where {T <: Number} +function apply(this::normLayer{T},theta::Array{T},Yin::Array{T,2},doDerivative=true) where {T <: Number} # first organize Y with channels nf = this.nData[2]::Int @@ -41,17 +41,15 @@ function apply(this::normLayer{T},theta::Array{T},Yin::Array{T},doDerivative=tru dA = (T)[] # subtract mean across pixels - Ya = mean(Y,this.doNorm) - Y = Y.-Ya - # Y .-= Ya #TODO: This line is more efficient, but tests do not want Y to change. Why dont we want Y to change in place? + Yout = Y.-mean(Y,this.doNorm) # normalize - S2 = mean(Y.^2,this.doNorm) - Y ./= sqrt.(S2+this.eps) + S2 = sqrt.(mean(Yout.^2,this.doNorm) + this.eps) + Yout ./= S2 - Yout = reshape(Y,:,nex) + Yout2 = reshape(Yout,:,nex) - return Yout, Yout, dA + return Yout2, Yout2, dA end function nTheta(this::normLayer) @@ -79,26 +77,27 @@ function Jthetamv(this::normLayer,dtheta::Array{T},theta::Array{T},Y::Array{T},d return zeros(T,size(Y)), zeros(T,size(Y)) end -function JYmv(this::normLayer,dY::Array{T},theta::Array{T},Y::Array{T},dA=nothing) where {T <: Number} +function JYmv(this::normLayer,dYin::Array{T},theta::Array{T},Yin::Array{T},dA=nothing) where {T <: Number} - nex = div(length(dY),nFeatIn(this)) + nex = div(length(dYin),nFeatIn(this)) nf = this.nData[2] - dY = reshape(dY,:,nf,nex) - Y = reshape(Y,:,nf,nex) + dY = reshape(dYin,:,nf,nex) + Y = reshape(Yin,:,nf,nex) Ya = mean(Y,this.doNorm) - Y = Y .- Ya + Yout = Y .- Ya dYa = mean(dY,this.doNorm) - dY = dY .- dYa - S2y = mean(Y.^2,this.doNorm); + dYout = dY .- dYa + + S2y = mean(Yout.^2,this.doNorm); den = sqrt.(S2y+this.eps); + tmp = mean(Yout.*dYout,this.doNorm) + dYout ./= den - tmp = mean(Y.*dY,this.doNorm) - dY = dY ./ den + den .^= 3 + Yout .= Yout.*tmp ./den - Y = Y .* tmp - Y = Y ./ den.^3 - dZ = reshape(dY-Y,:,nex) + dZ = reshape(dYout-Yout,:,nex) return dZ,dZ end @@ -116,27 +115,27 @@ function JthetaTmv(this::normLayer{T},Z::Array{T},dummy::Array{T},theta::Array{T return zeros(T,0) end -function JYTmv(this::normLayer{T},Z::Array{T},dummy::Array{T},theta::Array{T},Y::Array{T},dA=nothing) where {T <: Number} +function JYTmv(this::normLayer{T},Zin::Array{T},dummy::Array{T},theta::Array{T},Yin::Array{T},dA=nothing) where {T <: Number} - nex = div(length(Y),nFeatIn(this)) + nex = div(length(Yin),nFeatIn(this)) nf = this.nData[2] - Z = reshape(Z,:,nf,nex) - Y = reshape(Y,:,nf,nex) + Z = reshape(Zin,:,nf,nex) + Y = reshape(Yin,:,nf,nex) Ya = mean(Y,this.doNorm) - Y = Y .- Ya + Yout = Y .- Ya Za = mean(Z,this.doNorm) - Z = Z .- Za - S2y = mean(Y.^2,this.doNorm) + Zout = Z .- Za + S2y = mean(Yout.^2,this.doNorm) den = sqrt.(S2y+this.eps) - tmp = mean(Y.*Z,this.doNorm) - Z = Z ./ den - Y = Y .* tmp - Y = Y ./ den.^3 - dY = Z-Y + tmp = mean(Yout.*Zout,this.doNorm) + Zout ./= den + Yout .*= tmp + Yout ./= den.^3 # TODO: look into doing both this division and multiplication above at same time + dY = Zout-Yout dYa = mean(dY,this.doNorm) - dY = dY .- dYa + dY .-= dYa return reshape(dY,:,nex) end diff --git a/src/layers/singleLayer.jl b/src/layers/singleLayer.jl index 
28fa27e..875955e 100644 --- a/src/layers/singleLayer.jl +++ b/src/layers/singleLayer.jl @@ -1,17 +1,16 @@ export singleLayer,getSingleLayer -mutable struct singleLayer{T} <: AbstractMeganetElement{T} +mutable struct singleLayer{T, TK <: AbstractConvKernel{T}, TN <: Union{batchNormNN{T}, normLayer{T}}} <: AbstractMeganetElement{T} activation :: Function # activation function - K # transformation type - nLayer :: Union{NN{T}, normLayer{T}, AffineScalingLayer{T}} # normalization layer + K :: TK # transformation type + nLayer :: TN # normalization layer Bin :: Array{T} # bias inside nonlinearity Bout :: Array{T} # bias outside nonlinearity - # singleLayer{T}(K,nLayer;Bin=zeros(T,nFeatOut(K),0),Bout=zeros(T,nFeatOut(K),0),activation=tanhActivation) = end function getSingleLayer(TYPE::Type, K,nLayer;Bin=zeros(TYPE,nFeatOut(K),0),Bout=zeros(TYPE,nFeatOut(K),0),activation=tanhActivation) - singleLayer{TYPE}(activation,K,nLayer,Bin,Bout); + singleLayer(activation,K,nLayer,Bin,Bout); end @@ -28,17 +27,19 @@ function splitWeights(this::singleLayer{T},theta::Array{T}) where {T <: Number} return th1, th2, th3, th4 end -function apply(this::singleLayer{T},theta::Array{T},Y::Array{T},doDerivative=false) where {T <: Number} +function apply(this::singleLayer{T},theta::Array{T},Yin::Array{T},doDerivative=false) where {T <: Number} tmp = Array{Any}(2) - nex = div(length(Y),nFeatIn(this)) - Y = reshape(Y,:,nex) + nex = div(length(Yin),nFeatIn(this)) + Y = reshape(Yin,:,nex) th1,th2,th3,th4 = splitWeights(this,theta) - Y = getOp(this.K,th1)*Y .+ this.Bin * th2 - Y,dummy,tmp[1] = apply(this.nLayer,th4,Y,doDerivative) - Y,tmp[2] = this.activation(Y,doDerivative) - Y = Y .+ this.Bout*th3 - Ydata = Y - return Ydata, Y, tmp + + Yout::Array{T,2} = getOp(this.K,th1)*Y + Yout .+= this.Bin * th2 + Yout,dummy,tmp[1] = apply(this.nLayer,th4,Yout,doDerivative) + Yout,tmp[2] = this.activation(Yout,doDerivative) + Yout .+= this.Bout*th3 + Ydata = Yout + return Ydata, Yout, tmp end function nTheta(this::singleLayer) @@ -62,76 +63,79 @@ function initTheta(this::singleLayer{T}) where {T <: Number} end -function Jthetamv(this::singleLayer{T},dtheta::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} - dA = tmp[2] - nex = div(length(Y),nFeatIn(this)) - Y = reshape(Y,:,nex) +function Jthetamv(this::singleLayer{T},dtheta::Array{T},theta::Array{T},Yin::Array{T},tmp) where {T <: Number} + dA::Array{T,2} = tmp[2] + nex = div(length(Yin),nFeatIn(this)) + Y = reshape(Yin,:,nex) th1,th2,th3,th4 = splitWeights(this,theta) dth1,dth2,dth3,dth4 = splitWeights(this,dtheta) - dZ = Jthetamv(this.K,dth1,th1,Y) .+ this.Bin*dth2 + dZ::Array{T,2} = Jthetamv(this.K,dth1,th1,Y) .+ this.Bin*dth2 Kop = getOp(this.K,th1) dZ = Jmv(this.nLayer,dth4,dZ,th4,Kop*Y.+this.Bin*th2,tmp[1])[2] - dZ = dA.*dZ .+ this.Bout*dth3; + dZ .*= dA + dZ .+= this.Bout*dth3 return dZ, dZ end -function JYmv(this::singleLayer{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} +function JYmv(this::singleLayer{T},dYin::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} dA = tmp[2] - nex = div(length(dY),nFeatIn(this)) + nex = div(length(dYin),nFeatIn(this)) th1,th2,th3,th4 = splitWeights(this,theta) Kop = getOp(this.K,th1) - dY = reshape(dY,:,nex) + dY = reshape(dYin,:,nex) dZ = Kop*dY dZ = JYmv(this.nLayer,dZ,th4,Kop*Y.+this.Bin*th2,tmp[1])[2] - dZ = dA.*dZ + # dZ = dA.*dZ + dZ .*= dA return dZ,dZ end -function Jmv(this::singleLayer{T},dtheta::Array{T},dY::Array{T},theta::Array{T},Y::Array{T},tmp) where {T <: Number} - dA = tmp[2] - nex = 
diff --git a/src/optimization/sgd.jl b/src/optimization/sgd.jl
index a1b31e6..a2feab5 100644
--- a/src/optimization/sgd.jl
+++ b/src/optimization/sgd.jl
@@ -72,7 +72,7 @@ function solve(this::SGD{T},objFun::dnnObjFctn,xc::Array{T},Y::Array{T},C::Array
         Jval,pVal = getMisfit(objFun,xc,Yv,Cv,false);

         if this.out;
-            @printf "%d\t%1.2e\t%1.2f\t%1.2e\t%1.2e\t%1.2f\n" epoch Jc 100*(1-para[3]/para[2]) norm(xOld-xc) Jval 100*(1-pVal[3]/para[2])
+            @printf "%d\t%1.2e\t%1.2f\t%1.2e\t%1.2e\t%1.2f\n" epoch Jc 100*(1-para[3]/para[2]) norm(xOld-xc) Jval 100*(1-pVal[3]/pVal[2])
         end

         xOld = copy(xc);
diff --git a/src/utils/Benchmark.jl b/src/utils/Benchmark.jl
index e207182..7b13246 100644
--- a/src/utils/Benchmark.jl
+++ b/src/utils/Benchmark.jl
@@ -11,21 +11,28 @@ function Benchmark(trial::BenchmarkTools.Trial)
 end

 """
-    Use: updatehistory!(history::String, trial::BenchmarkTools.Trial; pkg::Module = Meganet)
+    Use: updatehistory!(history::String, trial::BenchmarkTools.Trial, funcName::String; pkg::Module = Meganet)

-    Appends `hist` in the JLD file `history` with the latest trial and metadata contained in a `Benchmark` instance.
+    Appends the latest trial and its metadata, wrapped in a `Benchmark` instance, to the entry for `funcName` in the JLD file `history`.
 """
-function updatehistory!(history::String, trial::BenchmarkTools.Trial; pkg::Module = Meganet)
+function updatehistory!(history::String, trial::BenchmarkTools.Trial, funcName::String; pkg::Module = Meganet)
     cd(Pkg.dir("$pkg"))

     if isfile(history)
         println("Appending trial history: "*history)
-        hist = JLD.load(history, "hist")
-        push!(hist, Meganet.Benchmark(trial))
+
+        hist = JLD.load(history)
+        if haskey(hist, funcName)
+            histFunc = hist[funcName]
+        else
+            histFunc = Vector{Meganet.Benchmark}()
+        end
+
+        push!(histFunc, Meganet.Benchmark(trial))

         JLD.jldopen(history, "w") do file
-            write(file, "hist", hist)
+            write(file, funcName, histFunc)
         end
     else
         println("Creating trial history: "*history)
@@ -33,7 +40,7 @@ function updatehistory!(history::String, trial::BenchmarkTools.Trial; pkg::Modul
         push!(hist, Meganet.Benchmark(trial))

         JLD.jldopen(history, "w") do file
-            write(file, "hist", hist)
+            write(file, funcName, hist)
        end
     end
 end
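For reference, a minimal usage sketch of the reworked per-function benchmark history; the file path, the benchmarked expression, and the key name below are placeholders, not part of this change:

    using BenchmarkTools, JLD, Meganet

    history = Pkg.dir("Meganet")*"/benchmarks/history_example.jld"   # placeholder path
    trial   = @benchmark sum(randn(10^6))                            # any BenchmarkTools.Trial will do

    Meganet.updatehistory!(history, trial, "myFunction")             # appends under the key "myFunction"
    hist    = JLD.load(history, "myFunction")                        # Vector of Benchmark entries for that function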
diff --git a/src/utils/testAbstractMeganetElement.jl b/src/utils/testAbstractMeganetElement.jl
index 4dee293..8b38a2b 100644
--- a/src/utils/testAbstractMeganetElement.jl
+++ b/src/utils/testAbstractMeganetElement.jl
@@ -7,6 +7,7 @@ function testAbstractMeganetElement(L::AbstractMeganetElement{T};out::Bool=false
     @testset "features immutable" begin
         theta = initTheta(L)
+        theta .+= .1 # To test if Y changes for affineScalingLayer
         Y = randn(T,nFeatIn(L),nex)
         Yo = copy(Y)
         Zd,Z,tmp = apply(L,theta,Y,true)
diff --git a/test/kernel/convGEMMKernelTest.jl b/test/kernel/convGEMMKernelTest.jl
index efc8727..8b16941 100644
--- a/test/kernel/convGEMMKernelTest.jl
+++ b/test/kernel/convGEMMKernelTest.jl
@@ -32,4 +32,24 @@ for TYPE=[Float64,Float32]
         # println("derivativeTest t1=$t1\t t2=$t2")
         @test norm(t1-t2)/norm(t2) < 1e3*eps(TYPE)
     end
+
+    @testset "new derivative test" begin
+        nImage = [16,16];
+        sK = [3,3,2,4];
+        K = randn(TYPE,tuple(sK...));
+        Y = randn(TYPE,nImage[1],nImage[2],sK[3],2);
+        Z = randn(TYPE,nImage[1],nImage[2],sK[4],2);
+        Kernel2 = getConvGEMMKernel(TYPE,nImage,sK);
+        AY = Amv(Kernel2,K,Y);
+        ATZ = ATmv(Kernel2,K,Z);
+
+        v1 = vecdot(Z,AY);
+        v2 = vecdot(ATZ,Y);
+
+        v3 = vecdot(Z,Jthetamv(Kernel2,K,(TYPE)[],Y));
+        v4 = vecdot(K,JthetaTmv(Kernel2,Z,(TYPE)[],Y));
+        @test norm(v1-v2)/norm(v2) < 1e3*eps(TYPE) &&
+              norm(v2-v3)/norm(v3) < 1e3*eps(TYPE) &&
+              norm(v3-v4)/norm(v4) < 1e3*eps(TYPE)
+    end
 end
diff --git a/test/layer/singleLayerTest.jl b/test/layer/singleLayerTest.jl
index ce94ee4..cf009ef 100644
--- a/test/layer/singleLayerTest.jl
+++ b/test/layer/singleLayerTest.jl
@@ -24,7 +24,7 @@ for TYPE=[Float64,Float32]
     nex = 4
     K = getSparseConvKernel2D(TYPE,nImg,[3,3,1,nc])
     Bin = randn(TYPE, nFeatOut(K),4)
-    nLayer = getBatchNormLayer(TYPE,[prod(nImg),nc],isTrainable=true).layers[2]
+    nLayer = getBatchNormLayer(TYPE,[prod(nImg),nc],isTrainable=true) # Do we need this to be .layers[2]?
     L = getSingleLayer(TYPE,K,nLayer,Bin=Bin)
     @testset "singleLayer (conv/Batch/not trainable) $TYPE" begin
         testAbstractMeganetElement(L,nex=nex)