From a23001210232c91beec9643e444c08b3c15bcf34 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Mon, 25 Mar 2024 11:16:36 -0400 Subject: [PATCH 1/4] Added common examples --- examples/common/idxio.py | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 examples/common/idxio.py diff --git a/examples/common/idxio.py b/examples/common/idxio.py new file mode 100644 index 0000000..c891b99 --- /dev/null +++ b/examples/common/idxio.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +####################################################### +# Copyright (c) 2024, ArrayFire +# All rights reserved. +# +# This file is distributed under 3-clause BSD license. +# The complete license agreement can be obtained at: +# http://arrayfire.com/licenses/BSD-3-Clause +######################################################## + +def reverse_char(b): + b = (b & 0xF0) >> 4 | (b & 0x0F) << 4 + b = (b & 0xCC) >> 2 | (b & 0x33) << 2 + b = (b & 0xAA) >> 1 | (b & 0x55) << 1 + return b + + +# http://stackoverflow.com/a/9144870/2192361 +def reverse(x): + x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1) + x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2) + x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4) + x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8) + x = ((x >> 16) & 0xffff) | ((x & 0xffff) << 16); + return x + + +def read_idx(name): + with open(name, 'rb') as f: + # In the C++ version, bytes the size of 4 chars are being read + # May not work properly in machines where a char is not 1 byte + bytes_read = f.read(4) + bytes_read = bytearray(bytes_read) + + if bytes_read[2] != 8: + raise RuntimeError('Unsupported data type') + + numdims = bytes_read[3] + elemsize = 1 + + # Read the dimensions + elem = 1 + dims = [0] * numdims + for i in range(numdims): + bytes_read = bytearray(f.read(4)) + + # Big endian to little endian + for j in range(4): + bytes_read[j] = reverse_char(bytes_read[j]) + bytes_read_int = int.from_bytes(bytes_read, 'little') + dim = reverse(bytes_read_int) + + elem = elem * dim; + dims[i] = dim; + + # Read the data + cdata = f.read(elem * elemsize) + cdata = list(cdata) + data = [float(cdata_elem) for cdata_elem in cdata] + + return (dims, data) + +if __name__ == '__main__': + # Example usage of reverse_char + byte_value = 0b10101010 + reversed_byte = reverse_char(byte_value) + print(f"Original byte: {byte_value:08b}, Reversed byte: {reversed_byte:08b}") + + # Example usage of reverse + int_value = 0x12345678 + reversed_int = reverse(int_value) + print(f"Original int: {int_value:032b}, Reversed int: {reversed_int:032b}") \ No newline at end of file From 8de08863c9446cf2cee6c9286a09aa4eacda0624 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Mon, 25 Mar 2024 23:22:02 -0400 Subject: [PATCH 2/4] added common example --- examples/common/idxio.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/common/idxio.py b/examples/common/idxio.py index c891b99..dd06855 100644 --- a/examples/common/idxio.py +++ b/examples/common/idxio.py @@ -9,6 +9,7 @@ # http://arrayfire.com/licenses/BSD-3-Clause ######################################################## + def reverse_char(b): b = (b & 0xF0) >> 4 | (b & 0x0F) << 4 b = (b & 0xCC) >> 2 | (b & 0x33) << 2 @@ -18,23 +19,23 @@ def reverse_char(b): # http://stackoverflow.com/a/9144870/2192361 def reverse(x): - x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1) - x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2) - x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4) - x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8) - x = ((x >> 16) & 0xffff) | ((x & 0xffff) << 16); + x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1) + x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2) + x = ((x >> 4) & 0x0F0F0F0F) | ((x & 0x0F0F0F0F) << 4) + x = ((x >> 8) & 0x00FF00FF) | ((x & 0x00FF00FF) << 8) + x = ((x >> 16) & 0xFFFF) | ((x & 0xFFFF) << 16) return x def read_idx(name): - with open(name, 'rb') as f: + with open(name, "rb") as f: # In the C++ version, bytes the size of 4 chars are being read # May not work properly in machines where a char is not 1 byte bytes_read = f.read(4) bytes_read = bytearray(bytes_read) if bytes_read[2] != 8: - raise RuntimeError('Unsupported data type') + raise RuntimeError("Unsupported data type") numdims = bytes_read[3] elemsize = 1 @@ -48,11 +49,11 @@ def read_idx(name): # Big endian to little endian for j in range(4): bytes_read[j] = reverse_char(bytes_read[j]) - bytes_read_int = int.from_bytes(bytes_read, 'little') + bytes_read_int = int.from_bytes(bytes_read, "little") dim = reverse(bytes_read_int) - elem = elem * dim; - dims[i] = dim; + elem = elem * dim + dims[i] = dim # Read the data cdata = f.read(elem * elemsize) @@ -61,13 +62,14 @@ def read_idx(name): return (dims, data) -if __name__ == '__main__': + +if __name__ == "__main__": # Example usage of reverse_char - byte_value = 0b10101010 + byte_value = 0b10101010 reversed_byte = reverse_char(byte_value) print(f"Original byte: {byte_value:08b}, Reversed byte: {reversed_byte:08b}") # Example usage of reverse - int_value = 0x12345678 + int_value = 0x12345678 reversed_int = reverse(int_value) - print(f"Original int: {int_value:032b}, Reversed int: {reversed_int:032b}") \ No newline at end of file + print(f"Original int: {int_value:032b}, Reversed int: {reversed_int:032b}") From ae0d0a07e4541b9df5df43622310a49ad4a88c4d Mon Sep 17 00:00:00 2001 From: sakchal Date: Mon, 22 Apr 2024 09:48:14 -0400 Subject: [PATCH 3/4] rebased changes for machine_learning_examples branch, resolved marge conflicts --- arrayfire/__init__.py | 1 + arrayfire/library/array_functions.py | 15 ++ .../machine_learning/logistic_regression.py | 203 ++++++++++++++++++ examples/machine_learning/mnist_common.py | 106 +++++++++ 4 files changed, 325 insertions(+) create mode 100644 examples/machine_learning/logistic_regression.py create mode 100644 examples/machine_learning/mnist_common.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index e5f2079..b26e78c 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -124,6 +124,7 @@ tile, transpose, upper, + lookup ) __all__ += ["gloh", "orb", "sift", "dog", "fast", "harris", "susan", "hamming_matcher", "nearest_neighbour"] diff --git a/arrayfire/library/array_functions.py b/arrayfire/library/array_functions.py index e8028e5..203c7f9 100644 --- a/arrayfire/library/array_functions.py +++ b/arrayfire/library/array_functions.py @@ -23,7 +23,11 @@ "shift", "tile", "transpose", +<<<<<<< HEAD "lookup", +======= + "lookup" +>>>>>>> d006f32 (Fixed machine learning example, added functionality, fixed sum_function) ] import warnings @@ -1209,6 +1213,7 @@ def transpose(array: Array, /, *, conjugate: bool = False, inplace: bool = False return cast(Array, wrapper.transpose(array.arr, conjugate)) +<<<<<<< HEAD @afarray_as_array def lookup(array: Array, indices: Array, /, *, axis: int = 0) -> Array: @@ -1262,3 +1267,13 @@ def lookup(array: Array, indices: Array, /, *, axis: int = 0) -> Array: - The dimension specified by `axis` must not exceed the number of dimensions in `array`. """ return cast(Array, wrapper.lookup(array.arr, indices.arr, axis)) +======= +@afarray_as_array +def lookup(array: Array, indices: Array, /, dim: int = 0,) -> Array: + if dim >= array.ndim: + raise ValueError(f"Dimension must be < {array.ndim}") + + return cast(Array, wrapper.lookup(array.arr, indices.arr, dim)) + + +>>>>>>> d006f32 (Fixed machine learning example, added functionality, fixed sum_function) diff --git a/examples/machine_learning/logistic_regression.py b/examples/machine_learning/logistic_regression.py new file mode 100644 index 0000000..8092f8a --- /dev/null +++ b/examples/machine_learning/logistic_regression.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python + +####################################################### +# Copyright (c) 2019, ArrayFire +# All rights reserved. +# +# This file is distributed under 3-clause BSD license. +# The complete license agreement can be obtained at: +# http://arrayfire.com/licenses/BSD-3-Clause +######################################################## + +from mnist_common import display_results, setup_mnist + +import sys +import time + +import arrayfire as af + +def accuracy(predicted, target): + _, tlabels = af.imax(target, axis=1) + _, plabels = af.imax(predicted, axis=1) + return 100 * af.count(plabels == tlabels) / tlabels.size + + +def abserr(predicted, target): + return 100 * af.sum(af.abs(predicted - target)) / predicted.size + + +# Predict (probability) based on given parameters +def predict_prob(X, Weights): + Z = af.matmul(X, Weights) + return af.sigmoid(Z) + + +# Predict (log probability) based on given parameters +def predict_log_prob(X, Weights): + return af.log(predict_prob(X, Weights)) + + +# Give most likely class based on given parameters +def predict_class(X, Weights): + probs = predict_prob(X, Weights) + _, classes = af.imax(probs, 1) + return classes + + +def cost(Weights, X, Y, lambda_param=1.0): + # Number of samples + m = Y.shape[0] + + dim0 = Weights.shape[0] + dim1 = Weights.shape[1] if len(Weights.shape) > 1 else 1 + dim2 = Weights.shape[2] if len(Weights.shape) > 2 else 1 + dim3 = Weights.shape[3] if len(Weights.shape) > 3 else 1 + # Make the lambda corresponding to Weights(0) == 0 + lambdat = af.constant(lambda_param, (dim0, dim1, dim2, dim3)) + + # No regularization for bias weights + lambdat[0, :] = 0 + + # Get the prediction + H = predict_prob(X, Weights) + + # Cost of misprediction + Jerr = -1 * af.sum(Y * af.log(H) + (1 - Y) * af.log(1 - H), axis=0) + + # Regularization cost + Jreg = 0.5 * af.sum(lambdat * Weights * Weights, axis=0) + + # Total cost + J = (Jerr + Jreg) / m + + # Find the gradient of cost + D = (H - Y) + dJ = (af.matmul(X, D, af.MatProp.TRANS) + lambdat * Weights) / m + + return J, dJ + + +def train(X, Y, alpha=0.1, lambda_param=1.0, maxerr=0.01, maxiter=1000, verbose=False): + # Initialize parameters to 0 + Weights = af.constant(0, (X.shape[1], Y.shape[1])) + + for i in range(maxiter): + # Get the cost and gradient + J, dJ = cost(Weights, X, Y, lambda_param) + + err = af.max(af.abs(J)) + if err < maxerr: + print('Iteration {0:4d} Err: {1:4f}'.format(i + 1, err)) + print('Training converged') + return Weights + + if verbose and ((i+1) % 10 == 0): + print('Iteration {0:4d} Err: {1:4f}'.format(i + 1, err)) + + # Update the parameters via gradient descent + Weights = Weights - alpha * dJ + + if verbose: + print('Training stopped after {0:d} iterations'.format(maxiter)) + + return Weights + + +def benchmark_logistic_regression(train_feats, train_targets, test_feats): + t0 = time.time() + Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000) + af.eval(Weights) + af.sync(-1) + t1 = time.time() + dt = t1 - t0 + print('Training time: {0:4.4f} s'.format(dt)) + + t0 = time.time() + iters = 100 + for i in range(iters): + test_outputs = predict_prob(test_feats, Weights) + af.eval(test_outputs) + af.sync(-1) + t1 = time.time() + dt = t1 - t0 + print('Prediction time: {0:4.4f} s'.format(dt / iters)) + + +# Demo of one vs all logistic regression +def logit_demo(console, perc): + # Load mnist data + frac = float(perc) / 100.0 + mnist_data = setup_mnist(frac, True) + num_classes = mnist_data[0] + num_train = mnist_data[1] + num_test = mnist_data[2] + train_images = mnist_data[3] + test_images = mnist_data[4] + train_targets = mnist_data[5] + test_targets = mnist_data[6] + + # Reshape images into feature vectors + feature_length = int(train_images.size / num_train); + train_feats = af.transpose(af.moddims(train_images, (feature_length, num_train))) + + + test_feats = af.transpose(af.moddims(test_images, (feature_length, num_test))) + + train_targets = af.transpose(train_targets) + test_targets = af.transpose(test_targets) + + num_train = train_feats.shape[0] + num_test = test_feats.shape[0] + + + # Add a bias that is always 1 + train_bias = af.constant(1, (num_train, 1)) + test_bias = af.constant(1, (num_test, 1)) + train_feats = af.join(1, train_bias, train_feats) + test_feats = af.join(1, test_bias, test_feats) + + + # Train logistic regression parameters + Weights = train(train_feats, train_targets, + 0.1, # learning rate + 1.0, # regularization constant + 0.01, # max error + 1000, # max iters + True # verbose mode + ) + af.eval(Weights) + af.sync(-1) + + # Predict the results + train_outputs = predict_prob(train_feats, Weights) + test_outputs = predict_prob(test_feats, Weights) + + print('Accuracy on training data: {0:2.2f}'.format(accuracy(train_outputs, train_targets))) + print('Accuracy on testing data: {0:2.2f}'.format(accuracy(test_outputs, test_targets))) + print('Maximum error on testing data: {0:2.2f}'.format(abserr(test_outputs, test_targets))) + + benchmark_logistic_regression(train_feats, train_targets, test_feats) + + if not console: + test_outputs = af.transpose(test_outputs) + # Get 20 random test images + display_results(test_images, test_outputs, af.transpose(test_targets), 20, True) + +def main(): + argc = len(sys.argv) + + device = int(sys.argv[1]) if argc > 1 else 0 + console = sys.argv[2][0] == '-' if argc > 2 else False + perc = int(sys.argv[3]) if argc > 3 else 60 + + + try: + af.set_device(device) + af.info() + logit_demo(console, perc) + except Exception as e: + print('Error: ', str(e)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/machine_learning/mnist_common.py b/examples/machine_learning/mnist_common.py new file mode 100644 index 0000000..63b1dea --- /dev/null +++ b/examples/machine_learning/mnist_common.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python + +####################################################### +# Copyright (c) 2024, ArrayFire +# All rights reserved. +# +# This file is distributed under 3-clause BSD license. +# The complete license agreement can be obtained at: +# http://arrayfire.com/licenses/BSD-3-Clause +######################################################## + +import os +import sys +# sys.path.insert(0, '../common') +from examples.common.idxio import read_idx + +import arrayfire as af + + +def classify(arr, k, expand_labels): + ret_str = '' + if expand_labels: + vec = af.cast(arr[:, k], af.f32) + h_vec = vec.to_list() + data = [] + + for i in range(vec.size): + data.append((h_vec[i], i)) + + data = sorted(data, key=lambda pair: pair[0], reverse=True) + + ret_str = str(data[0][1]) + + else: + ret_str = str(int(af.cast(arr[k], af.float32).scalar())) + + return ret_str + + +def setup_mnist(frac, expand_labels): + root_path = os.path.dirname(os.path.abspath(__file__)) + file_path = root_path + '/../../assets/examples/data/mnist/' + idims, idata = read_idx(file_path + 'images-subset') + ldims, ldata = read_idx(file_path + 'labels-subset') + + idims.reverse() + numdims = len(idims) + images = af.Array(idata, af.float32, tuple(idims)) + + R = af.randu((10000, 1)); + cond = R < min(frac, 0.8) + train_indices = af.where(cond) + test_indices = af.where(~cond) + + train_images = af.lookup(images, train_indices, 2) / 255 + test_images = af.lookup(images, test_indices, 2) / 255 + + + num_classes = 10 + num_train = train_images.shape[2] + num_test = test_images.shape[2] + + + if expand_labels: + train_labels = af.constant(0, (num_classes, num_train)) + test_labels = af.constant(0, (num_classes, num_test)) + + h_train_idx = train_indices.copy() + h_test_idx = test_indices.copy() + + ldata = list(map(int, ldata)) + + for i in range(num_train): + ldata_ind = ldata[h_train_idx[i].scalar()] + train_labels[ldata_ind, i] = 1 + + for i in range(num_test): + ldata_ind = ldata[h_test_idx[i].scalar()] + test_labels[ldata_ind, i] = 1 + + else: + labels = af.Array(idata, af.float32, tuple(idims)) + train_labels = labels[train_indices] + test_labels = labels[test_indices] + + return (num_classes, + num_train, + num_test, + train_images, + test_images, + train_labels, + test_labels) + + +def display_results(test_images, test_output, test_actual, num_display, expand_labels): + for i in range(num_display): + print('Predicted: ', classify(test_output, i, expand_labels)) + print('Actual: ', classify(test_actual, i, expand_labels)) + + img = af.cast((test_images[:, :, i] > 0.1), af.u8) + img = af.moddims(img, (img.size,)).to_list() + for j in range(28): + for k in range(28): + print('\u2588' if img[j * 28 + k] > 0 else ' ', end='') + print() + input() \ No newline at end of file From 9ed3daf6b8093f0cf065a867fde14fbd63b773d4 Mon Sep 17 00:00:00 2001 From: sakchal Date: Mon, 15 Apr 2024 18:00:14 -0400 Subject: [PATCH 4/4] changed data for machine learnin example --- examples/machine_learning/logistic_regression.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/machine_learning/logistic_regression.py b/examples/machine_learning/logistic_regression.py index 8092f8a..da35c8d 100644 --- a/examples/machine_learning/logistic_regression.py +++ b/examples/machine_learning/logistic_regression.py @@ -1,7 +1,7 @@ #!/usr/bin/env python ####################################################### -# Copyright (c) 2019, ArrayFire +# Copyright (c) 2024, ArrayFire # All rights reserved. # # This file is distributed under 3-clause BSD license. @@ -107,7 +107,7 @@ def benchmark_logistic_regression(train_feats, train_targets, test_feats): t0 = time.time() Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000) af.eval(Weights) - af.sync(-1) + af.sync() t1 = time.time() dt = t1 - t0 print('Training time: {0:4.4f} s'.format(dt)) @@ -117,7 +117,7 @@ def benchmark_logistic_regression(train_feats, train_targets, test_feats): for i in range(iters): test_outputs = predict_prob(test_feats, Weights) af.eval(test_outputs) - af.sync(-1) + af.sync() t1 = time.time() dt = t1 - t0 print('Prediction time: {0:4.4f} s'.format(dt / iters))