From 89828a6fa7315cda2bc4cf05fd8b0ea725a50716 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Thu, 19 Nov 2020 22:01:54 +0100 Subject: [PATCH 01/16] [SYSTEMDS-2550] Test case update --- .../cp/ParamservBuiltinCPInstruction.java | 19 +++- .../paramserv/FederatedParamservTest.java | 30 ++++--- .../functions/federated/paramserv/CNN.dml | 90 ++++++++----------- .../paramserv/FederatedParamservTest.dml | 23 +---- .../functions/federated/paramserv/TwoNN.dml | 73 +++++++-------- 5 files changed, 101 insertions(+), 134 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 7a285f6580c..9cb2f245517 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -56,6 +56,7 @@ import org.apache.sysds.parser.Statement.PSFrequency; import org.apache.sysds.parser.Statement.PSModeType; import org.apache.sysds.parser.Statement.PSScheme; +import org.apache.sysds.parser.Statement.FederatedPSScheme; import org.apache.sysds.parser.Statement.PSUpdateType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.LocalVariableMap; @@ -86,6 +87,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private static final int DEFAULT_BATCH_SIZE = 64; private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH; private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS; + private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = FederatedPSScheme.KEEP_DATA_ON_WORKER; private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL; private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP; @@ -124,11 +126,12 @@ private void runFederated(ExecutionContext ec) { // get inputs PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); + FederatedPSScheme federatedPSScheme = getFederatedScheme(); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); // partition federated data - DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER) + DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme) .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS))); List pFeatures = result.pFeatures; List pLabels = result.pLabels; @@ -141,8 +144,7 @@ private void runFederated(ExecutionContext ec) { // Get the compiled execution context LocalVariableMap newVarsMap = createVarsMap(ec); - // Level of par is 1 because one worker will be launched per task - // TODO: Fix recompilation + // Level of par is -1 so each federated worker can scale to its cpu cores ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, -1, true); // Create workers' execution context List federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum); @@ -469,4 +471,15 @@ private PSScheme getScheme() { return scheme; } + private FederatedPSScheme getFederatedScheme() { + FederatedPSScheme federated_scheme = DEFAULT_FEDERATED_SCHEME; + if (getParameterMap().containsKey(PS_SCHEME)) { + try { + federated_scheme = FederatedPSScheme.valueOf(getParam(PS_SCHEME)); + } catch (IllegalArgumentException e) { + 
throw new DMLRuntimeException(String.format("Paramserv function in federated mode: not support data partition scheme '%s'", getParam(PS_SCHEME))); + } + } + return federated_scheme; + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index cc0af078f2c..f6a6d2ac577 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -54,6 +54,7 @@ public class FederatedParamservTest extends AutomatedTestBase { private final double _eta; private final String _utype; private final String _freq; + private final String _scheme; // parameters @Parameterized.Parameters @@ -61,23 +62,21 @@ public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, - {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"}, - // {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"}, - // {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"}, - // {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"}, - // {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"}, - // {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"}, - {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"}, - // {"CNN", 5, 1000, 200, 2, 0.01, "ASP", "EPOCH"} + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "SHUFFLE"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, + {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "SHUFFLE"} }); } public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, - int epochs, double eta, String utype, String freq) { + int epochs, double eta, String utype, String freq, String scheme) { _networkType = networkType; _numFederatedWorkers = numFederatedWorkers; _examplesPerWorker = examplesPerWorker; @@ -86,6 +85,7 @@ public FederatedParamservTest(String networkType, int numFederatedWorkers, int e _eta = eta; _utype = utype; _freq = freq; + _scheme = scheme; } @Override @@ -131,10 +131,12 @@ private void federatedParamserv(ExecMode mode) { "eta=" + _eta, "utype=" + _utype, "freq=" + _freq, + "scheme=" + _scheme, "network_type=" + _networkType, "channels=" + C, "hin=" + Hin, - "win=" + Win)); + "win=" + Win, + "seed=" + 25)); // for each worker List ports = new ArrayList<>(); diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml index d622c13cbbe..a2196001782 100644 --- a/src/test/scripts/functions/federated/paramserv/CNN.dml +++ 
b/src/test/scripts/functions/federated/paramserv/CNN.dml @@ -67,8 +67,10 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov */ train = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int C, int Hin, int Win, int epochs, int batch_size, double learning_rate) - return (list[unknown] model_trained) { + int epochs, string utype, string freq, int batch_size, string scheme, double eta, + int C, int Hin, int Win, + int seed = -1) + return (list[unknown] model) { N = nrow(X) K = ncol(y) @@ -84,74 +86,45 @@ train = function(matrix[double] X, matrix[double] y, N3 = 512 # num nodes in affine3 # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) - [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win) - [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N, F1*(Hin/2)*(Win/2)) - [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) - [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3) + [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win) + lseed = ifelse(seed==-1, -1, seed + 1); + [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N, F1*(Hin/2)*(Win/2)) + lseed = ifelse(seed==-1, -1, seed + 2); + [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) + lseed = ifelse(seed==-1, -1, seed + 3); + [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3) W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu # Initialize SGD w/ Nesterov momentum optimizer - learning_rate = learning_rate # learning rate mu = 0.9 # momentum decay = 0.95 # learning rate decay constant vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2) vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3) vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4) + + model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4) + # Regularization lambda = 5e-04 # Create the hyper parameter list - hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) + hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Calculate iterations iters = ceil(N / batch_size) - print_interval = floor(iters / 25) - - print("[+] Starting optimization") - print("[+] Learning rate: " + learning_rate) - print("[+] Batch size: " + batch_size) - print("[+] Iterations per epoch: " + iters + "\n") for (e in 1:epochs) { - print("[+] Starting epoch: " + e) - print("|") for(i in 1:iters) { - # Create the model list - model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4) - # Get next batch beg = ((i-1) * batch_size) %% N + 1 end = min(N, beg + batch_size - 1) X_batch = X[beg:end,] y_batch = y[beg:end,] - gradients_list = gradients(model_list, hyperparams, X_batch, y_batch) - model_updated = aggregation(model_list, hyperparams, gradients_list) - - W1 = as.matrix(model_updated[1]) - W2 = as.matrix(model_updated[2]) - W3 = as.matrix(model_updated[3]) - W4 = as.matrix(model_updated[4]) - b1 = as.matrix(model_updated[5]) - b2 = as.matrix(model_updated[6]) - b3 = as.matrix(model_updated[7]) - b4 = as.matrix(model_updated[8]) - vW1 = as.matrix(model_updated[9]) - vW2 = 
as.matrix(model_updated[10]) - vW3 = as.matrix(model_updated[11]) - vW4 = as.matrix(model_updated[12]) - vb1 = as.matrix(model_updated[13]) - vb2 = as.matrix(model_updated[14]) - vb3 = as.matrix(model_updated[15]) - vb4 = as.matrix(model_updated[16]) - if((i %% print_interval) == 0) { - print("█") - } + gradients_list = gradients(model, hyperparams, X_batch, y_batch) + model = aggregation(model, hyperparams, gradients_list) } - print("|") } - - model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4) } /* @@ -190,9 +163,10 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int C, int Hin, int Win, int epochs, int workers, - string utype, string freq, int batch_size, string scheme, string mode, double learning_rate) - return (list[unknown] model_trained) { + int epochs, string utype, string freq, int batch_size, string scheme, double eta, + int C, int Hin, int Win, + int seed = -1) + return (list[unknown] model) { N = nrow(X) K = ncol(y) @@ -208,14 +182,17 @@ train_paramserv = function(matrix[double] X, matrix[double] y, N3 = 512 # num nodes in affine3 # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) - [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win) - [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N, F1*(Hin/2)*(Win/2)) - [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) - [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3) + [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win) + lseed = ifelse(seed==-1, -1, seed + 1); + [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N, F1*(Hin/2)*(Win/2)) + lseed = ifelse(seed==-1, -1, seed + 2); + [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) + lseed = ifelse(seed==-1, -1, seed + 3); + [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3) W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu # Initialize SGD w/ Nesterov momentum optimizer - learning_rate = learning_rate # learning rate + learning_rate = eta # learning rate mu = 0.9 # momentum decay = 0.95 # learning rate decay constant vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) @@ -225,12 +202,15 @@ train_paramserv = function(matrix[double] X, matrix[double] y, # Regularization lambda = 5e-04 # Create the model list - model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4) + model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4) # Create the hyper parameter list - params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) + hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Use paramserv function - model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE") + model = 
paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, + upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", + utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, + scheme=scheme, hyperparams=hyperparams) } /* diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml index 16c72c4d2e7..9b643ac6c53 100644 --- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml +++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml @@ -31,27 +31,10 @@ labels = federated(addresses=list($y0, $y1), ranges=list(list(0, 0), list($examples_per_worker, $num_labels), list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels))) -epochs = $epochs -batch_size = $batch_size -learning_rate = $eta -utype = $utype -freq = $freq -network_type = $network_type - -# currently ignored parameters -workers = 1 -scheme = "DISJOINT_CONTIGUOUS" -paramserv_mode = "LOCAL" - -# config for the cnn -channels = $channels -hin = $hin -win = $win - -if(network_type == "TwoNN") { - model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate) +if($network_type == "TwoNN") { + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $eta, $seed) } else { - model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate) + model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $eta, $channels, $hin, $win, $seed) } print(toString(model)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml index 31e889a9245..2e5f9b5bfb2 100644 --- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml +++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml @@ -57,8 +57,9 @@ source("nn/optim/sgd.dml") as sgd */ train = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, int batch_size, double learning_rate) - return (list[unknown] model_trained) { + int epochs, int batch_size, double eta, + int seed = -1) + return (list[unknown] model) { N = nrow(X) # num examples D = ncol(X) # num features @@ -66,53 +67,31 @@ train = function(matrix[double] X, matrix[double] y, # Create the network: ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax - [W1, b1] = affine::init(D, 200, -1) - [W2, b2] = affine::init(200, 200, -1) - [W3, b3] = affine::init(200, K, -1) + [W1, b1] = affine::init(D, 200, seed = seed) + lseed = ifelse(seed==-1, -1, seed + 1); + [W2, b2] = affine::init(200, 200, seed = lseed) + lseed = ifelse(seed==-1, -1, seed + 2); + [W3, b3] = affine::init(200, K, seed = lseed) W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu + model = list(W1, W2, W3, b1, b2, b3) # Create the hyper parameter list - hyperparams = list(learning_rate=learning_rate) + hyperparams = list(learning_rate=eta) # Calculate iterations 
iters = ceil(N / batch_size) - print_interval = floor(iters / 25) - - print("[+] Starting optimization") - print("[+] Learning rate: " + learning_rate) - print("[+] Batch size: " + batch_size) - print("[+] Iterations per epoch: " + iters + "\n") for (e in 1:epochs) { - print("[+] Starting epoch: " + e) - print("|") for(i in 1:iters) { - # Create the model list - model_list = list(W1, W2, W3, b1, b2, b3) - # Get next batch beg = ((i-1) * batch_size) %% N + 1 end = min(N, beg + batch_size - 1) X_batch = X[beg:end,] y_batch = y[beg:end,] - gradients_list = gradients(model_list, hyperparams, X_batch, y_batch) - model_updated = aggregation(model_list, hyperparams, gradients_list) - - W1 = as.matrix(model_updated[1]) - W2 = as.matrix(model_updated[2]) - W3 = as.matrix(model_updated[3]) - b1 = as.matrix(model_updated[4]) - b2 = as.matrix(model_updated[5]) - b3 = as.matrix(model_updated[6]) - - if((i %% print_interval) == 0) { - print("█") - } + gradients_list = gradients(model, hyperparams, X_batch, y_batch) + model = aggregation(model, hyperparams, gradients_list) } - print("|") } - - model_trained = list(W1, W2, W3, b1, b2, b3) } /* @@ -146,9 +125,9 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, int workers, - string utype, string freq, int batch_size, string scheme, string mode, double learning_rate) - return (list[unknown] model_trained) { + int epochs, string utype, string freq, int batch_size, string scheme, double eta, + int seed = -1) + return (list[unknown] model) { N = nrow(X) # num examples D = ncol(X) # num features @@ -156,16 +135,26 @@ train_paramserv = function(matrix[double] X, matrix[double] y, # Create the network: ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax - [W1, b1] = affine::init(D, 200, -1) - [W2, b2] = affine::init(200, 200, -1) - [W3, b3] = affine::init(200, K, -1) + [W1, b1] = affine::init(D, 200, seed = seed) + lseed = ifelse(seed==-1, -1, seed + 1); + [W2, b2] = affine::init(200, 200, seed = lseed) + lseed = ifelse(seed==-1, -1, seed + 2); + [W3, b3] = affine::init(200, K, seed = lseed) + # W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu + + # [W1, b1] = affine::init(D, 200) + # [W2, b2] = affine::init(200, 200) + # [W3, b3] = affine::init(200, K) # Create the model list - model_list = list(W1, W2, W3, b1, b2, b3) + model = list(W1, W2, W3, b1, b2, b3) # Create the hyper parameter list - params = list(learning_rate=learning_rate) + hyperparams = list(learning_rate=eta) # Use paramserv function - model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE") + model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, + upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", + utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, + scheme=scheme, hyperparams=hyperparams) } /* From 34c9341a359aaac4ed3f3612d0bb8ed06b8a1dc1 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Sat, 21 Nov 2020 11:58:44 +0100 
Subject: [PATCH 02/16] [SYSTEMDS-2550] implemented shuffle partitioner --- .../runtime/controlprogram/ProgramBlock.java | 2 +- .../federated/FederationMap.java | 2 +- .../paramserv/FederatedPSControlThread.java | 10 +-- .../paramserv/dp/ShuffleFederatedScheme.java | 66 +++++++++++++++++++ .../fed/MatrixIndexingFEDInstruction.java | 2 +- .../fed/VariableFEDInstruction.java | 4 +- .../runtime/io/ReaderWriterFederated.java | 2 +- .../paramserv/FederatedParamservTest.java | 10 +-- 8 files changed, 79 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index 263ecf45ee4..e6e305880bc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -376,7 +376,7 @@ private static void checkFederated( Instruction lastInst, LocalVariableMap vars CacheableData mo = (CacheableData)dat; if( mo.isFederated() ) { - if( mo.getFedMapping().getFedMapping().isEmpty() ) + if( mo.getFedMapping().getFRangeFDataMap().isEmpty() ) throw new DMLRuntimeException("Invalid empty FederationMap for: "+mo); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 95905103f58..45ddea6efae 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -114,7 +114,7 @@ public FederatedRange[] getFederatedRanges() { return _fedMap.keySet().toArray(new FederatedRange[0]); } - public Map getFedMapping() { + public Map getFRangeFDataMap() { return _fedMap; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 80418acc09e..7ca402ae238 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -72,14 +72,8 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque */ public void setup() { // prepare features and labels - _features.getFedMapping().forEachParallel((range, data) -> { - _featuresData = data; - return null; - }); - _labels.getFedMapping().forEachParallel((range, data) -> { - _labelsData = data; - return null; - }); + _featuresData = (FederatedData) _features.getFedMapping().getFRangeFDataMap().values().toArray()[0]; + _labelsData = (FederatedData) _labels.getFedMapping().getFRangeFDataMap().values().toArray()[0]; // calculate number of batches and get data size long dataSize = _features.getNumRows(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index d6d8cfcbf70..dc123681feb 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -19,15 +19,81 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp; +import org.apache.sysds.runtime.DMLRuntimeException; import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import java.util.List; +import java.util.concurrent.Future; public class ShuffleFederatedScheme extends DataPartitionFederatedScheme { @Override public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); + + for(int i = 0; i < pFeatures.size(); i++) { + // Works, because the map contains a single entry + FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + + Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}))); + + try { + FederatedResponse response = udfResponse.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF returned fail"); + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed" + e.getMessage()); + } + } + return new Result(pFeatures, pLabels, pFeatures.size()); } + + /** + * Shuffle UDF executed on the federated worker + */ + private static class shuffleDataOnFederatedWorker extends FederatedUDF { + protected shuffleDataOnFederatedWorker(long[] inIDs) { + super(inIDs); + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + MatrixObject features = (MatrixObject) data[0]; + MatrixObject labels = (MatrixObject) data[1]; + + // generate permutation matrix + MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + + // matrix multiplies + features.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, features.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + features.release(); + + labels.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, labels.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + labels.release(); + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java index 477379c493f..9418d9ceaf3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java @@ -61,7 +61,7 @@ private void rightIndexing(ExecutionContext ec) //modify federated ranges in place Map ixs = new HashMap<>(); - for(FederatedRange range : fedMap.getFedMapping().keySet()) { + for(FederatedRange range : fedMap.getFRangeFDataMap().keySet()) { long rs = range.getBeginDims()[0], re = range.getEndDims()[0], cs = range.getBeginDims()[1], ce = range.getEndDims()[1]; long rsn = (ixrange.rowStart >= rs) ? 
(ixrange.rowStart - rs) : 0; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index da25122135e..f2edb0bd08d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -104,7 +104,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { MatrixObject out = ec.getMatrixObject(_in.getOutput()); FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); Map newMap = new HashMap<>(); - for(Map.Entry pair : outMap.getFedMapping().entrySet()) { + for(Map.Entry pair : outMap.getFRangeFDataMap().entrySet()) { FederatedData om = pair.getValue(); FederatedData nf = new FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(), om.getVarID()); @@ -131,7 +131,7 @@ private void processCastAsFrameVariableInstruction(ExecutionContext ec) { out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz()); FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); Map newMap = new HashMap<>(); - for(Map.Entry pair : outMap.getFedMapping().entrySet()) { + for(Map.Entry pair : outMap.getFRangeFDataMap().entrySet()) { FederatedData om = pair.getValue(); FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(), om.getVarID()); diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java index 0527d232f71..694eb4feabc 100644 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java @@ -107,7 +107,7 @@ public static void write(String file, FederationMap fedMap) { FileSystem fs = IOUtilFunctions.getFileSystem(path, job); DataOutputStream out = fs.create(path, true); ObjectMapper mapper = new ObjectMapper(); - FederatedDataAddress[] outObjects = parseMap(fedMap.getFedMapping()); + FederatedDataAddress[] outObjects = parseMap(fedMap.getFRangeFDataMap()); try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) { mapper.writeValue(pw, outObjects); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index f6a6d2ac577..3616fc60256 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -63,15 +63,15 @@ public static Collection parameters() { // Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update // type, update frequency {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "SHUFFLE"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, {"CNN", 2, 
2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "SHUFFLE"} + {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER"} }); } From ef3efda31f70f55e6dfeff22476cc0f0467729e4 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Sun, 22 Nov 2020 14:21:39 +0100 Subject: [PATCH 03/16] [SYSTEMDS-2550] Added federated testing method to AutomatedTestBase --- .../cp/ParamservBuiltinCPInstruction.java | 2 - .../apache/sysds/test/AutomatedTestBase.java | 54 ++++++++++ .../paramserv/FederatedParamservTest.java | 101 ++++++++---------- .../paramserv/FederatedParamservTest.dml | 9 +- 4 files changed, 100 insertions(+), 66 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 9cb2f245517..b12c9fb21ed 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -52,7 +52,6 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.lops.LopProperties; -import org.apache.sysds.parser.Statement; import org.apache.sysds.parser.Statement.PSFrequency; import org.apache.sysds.parser.Statement.PSModeType; import org.apache.sysds.parser.Statement.PSScheme; @@ -98,7 +97,6 @@ public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap @Override public void processInstruction(ExecutionContext ec) { // check if the input is federated - if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() || ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) { runFederated(ec); diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 4f08e88aa4c..9de23f95eda 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -19,6 +19,7 @@ package org.apache.sysds.test; +import static java.lang.Math.ceil; import static java.lang.Thread.sleep; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -27,10 +28,12 @@ import java.io.File; import java.io.IOException; import java.io.PrintStream; +import java.net.InetSocketAddress; import java.net.ServerSocket; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Properties; @@ -43,6 +46,7 @@ import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.SparkSession.Builder; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.FileFormat; @@ -52,12 +56,15 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.LopProperties.ExecType; +import org.apache.sysds.lops.compile.Dag; import org.apache.sysds.parser.DataExpression; import org.apache.sysds.parser.ParseException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.DMLScriptException; import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FrameReader; @@ -67,6 +74,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.privacy.CheckedConstraintsLog; import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; @@ -586,6 +594,52 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ inputDirectories.add(baseDirectory + INPUT_DIR + name); } + protected void federateBalancedAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, + List ports) { + // check matrix non empty + if(matrix.length == 0 || matrix[0].length == 0) + return; + + int nrows = matrix.length; + int ncol = matrix[0].length; + + // create federated MatrixObject + MatrixObject federatedMatrixObject = new MatrixObject(ValueType.FP64, Dag.getNextUniqueVarname(Types.DataType.MATRIX)); + federatedMatrixObject.setMetaData(new MetaDataFormat( + new MatrixCharacteristics(nrows, ncol), + Types.FileFormat.BINARY) + ); + + // write parts balanced and generate FederationMap + HashMap fedHashMap = new HashMap<>(); + double examplesPerWorker = ceil( (double) nrows / (double) numFederatedWorkers); + + for(int i = 0; i < numFederatedWorkers; i++) { + double lowerBound = examplesPerWorker * i; + double upperBound = Math.min(examplesPerWorker * (i + 1), nrows); + double examplesForWorkerI = upperBound - lowerBound; + String path = name + "_" + (i + 1); + + // write slice + writeInputMatrixWithMTD(path, + Arrays.copyOfRange(matrix, (int) lowerBound, (int) upperBound), + false, + new MatrixCharacteristics((long) examplesForWorkerI, ncol, + OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol)); + + // generate fedmap entry + FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol}); + FederatedData data = new FederatedData(DataType.MATRIX, new InetSocketAddress(ports.get(i)), input(path)); + fedHashMap.put(range, data); + } + + // TODO: How to generate the ID + federatedMatrixObject.setFedMapping(new FederationMap(1, fedHashMap)); + federatedMatrixObject.getFedMapping().setType(FederationMap.FType.ROW); + + writeInputFederatedWithMTD(name, federatedMatrixObject, new PrivacyConstraint()); + } + /** *

* Adds a matrix to the input path and writes it to a file. diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 3616fc60256..56903381cc0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -27,7 +27,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ExecMode; -import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; @@ -44,11 +43,10 @@ public class FederatedParamservTest extends AutomatedTestBase { private final static String TEST_DIR = "functions/federated/paramserv/"; private final static String TEST_NAME = "FederatedParamservTest"; private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/"; - private final static int _blocksize = 1024; private final String _networkType; private final int _numFederatedWorkers; - private final int _examplesPerWorker; + private final int _dataSetSize; private final int _epochs; private final int _batch_size; private final double _eta; @@ -60,26 +58,26 @@ public class FederatedParamservTest extends AutomatedTestBase { @Parameterized.Parameters public static Collection parameters() { return Arrays.asList(new Object[][] { - // Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update + // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE"}, {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER"} }); } - public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, + public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, int epochs, double eta, String utype, String freq, String scheme) { _networkType = networkType; _numFederatedWorkers = numFederatedWorkers; - _examplesPerWorker = examplesPerWorker; + _dataSetSize = dataSetSize; _batch_size = batch_size; _epochs = epochs; _eta = eta; @@ -117,64 +115,53 @@ 
private void federatedParamserv(ExecMode mode) { ExecMode platformOld = setExecMode(mode); try { - - // dml name - fullDMLScriptName = HOME + TEST_NAME + ".dml"; - // generate program args - List programArgsList = new ArrayList<>(Arrays.asList("-stats", - "-nvargs", - "examples_per_worker=" + _examplesPerWorker, - "num_features=" + numFeatures, - "num_labels=" + numLabels, - "epochs=" + _epochs, - "batch_size=" + _batch_size, - "eta=" + _eta, - "utype=" + _utype, - "freq=" + _freq, - "scheme=" + _scheme, - "network_type=" + _networkType, - "channels=" + C, - "hin=" + Hin, - "win=" + Win, - "seed=" + 25)); - - // for each worker + // start threads List ports = new ArrayList<>(); List threads = new ArrayList<>(); for(int i = 0; i < _numFederatedWorkers; i++) { - // write row partitioned features to disk - writeInputMatrixWithMTD("X" + i, - generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), - false, - new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, - _examplesPerWorker * numFeatures)); - // write row partitioned labels to disk - writeInputMatrixWithMTD("y" + i, - generateDummyMNISTLabels(_examplesPerWorker, numLabels), - false, - new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, - _examplesPerWorker * numLabels)); - - // start worker ports.add(getRandomAvailablePort()); threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S)); - - // add worker to program args - programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i))); - programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i))); } + + double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win); + double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels); + String featuresName = "X_" + _numFederatedWorkers; + String labelsName = "y_" + _numFederatedWorkers; + + federateBalancedAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports); + federateBalancedAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports); + try { - Thread.sleep(1000); + Thread.sleep(2000); } catch(InterruptedException e) { e.printStackTrace(); } - + + // dml name + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + // generate program args + List programArgsList = new ArrayList<>(Arrays.asList("-stats", + "-nvargs", + "features=" + input(featuresName), + "labels=" + input(labelsName), + "epochs=" + _epochs, + "batch_size=" + _batch_size, + "eta=" + _eta, + "utype=" + _utype, + "freq=" + _freq, + "scheme=" + _scheme, + "network_type=" + _networkType, + "channels=" + C, + "hin=" + Hin, + "win=" + Win, + "seed=" + 25)); + programArgs = programArgsList.toArray(new String[0]); LOG.debug(runTest(null)); Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst()); - // cleanup + // shut down threads for(int i = 0; i < _numFederatedWorkers; i++) { TestUtils.shutdownThreads(threads.get(i)); } diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml index 9b643ac6c53..33d1a0c7b73 100644 --- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml +++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml @@ -23,13 +23,8 @@ source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN # create federated input matrices -features = 
federated(addresses=list($X0, $X1), - ranges=list(list(0, 0), list($examples_per_worker, $num_features), - list($examples_per_worker, 0), list($examples_per_worker * 2, $num_features))) - -labels = federated(addresses=list($y0, $y1), - ranges=list(list(0, 0), list($examples_per_worker, $num_labels), - list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels))) +features = read($features) +labels = read($labels) if($network_type == "TwoNN") { model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $eta, $seed) From 9061fb6e3b6e41ebc6cec2bacd365276b79c85a5 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Tue, 24 Nov 2020 11:34:03 +0100 Subject: [PATCH 04/16] [SYSTEMDS-2550] Added suggestions by Sebastian W. --- .../java/org/apache/sysds/test/AutomatedTestBase.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 9de23f95eda..18338418edf 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -66,6 +66,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FrameReader; import org.apache.sysds.runtime.io.FrameReaderFactory; @@ -632,12 +633,11 @@ protected void federateBalancedAndWriteInputMatrixWithMTD(String name, double[][ FederatedData data = new FederatedData(DataType.MATRIX, new InetSocketAddress(ports.get(i)), input(path)); fedHashMap.put(range, data); } - - // TODO: How to generate the ID - federatedMatrixObject.setFedMapping(new FederationMap(1, fedHashMap)); + + federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap)); federatedMatrixObject.getFedMapping().setType(FederationMap.FType.ROW); - writeInputFederatedWithMTD(name, federatedMatrixObject, new PrivacyConstraint()); + writeInputFederatedWithMTD(name, federatedMatrixObject, null); } /** From 3d901f8436318b7b26c0a255d9c4d33c0839549f Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Wed, 25 Nov 2020 14:23:10 +0100 Subject: [PATCH 05/16] [SYSTEMDS-2550] Built replication, subsamling, and balancing partitioner and extendend automated testbase --- .../org/apache/sysds/parser/Statement.java | 2 +- .../paramserv/ParamservUtils.java | 30 ++++++ .../paramserv/dp/BalanceFederatedScheme.java | 92 +++++++++++++++++++ .../dp/DataPartitionFederatedScheme.java | 58 +++++++++++- .../dp/FederatedDataPartitioner.java | 9 ++ .../dp/ReplicateFederatedScheme.java | 89 ++++++++++++++++++ .../paramserv/dp/ShuffleFederatedScheme.java | 26 +----- .../dp/SubsampleFederatedScheme.java | 89 ++++++++++++++++++ .../apache/sysds/test/AutomatedTestBase.java | 22 +++-- .../paramserv/FederatedParamservTest.java | 18 ++-- 10 files changed, 394 insertions(+), 41 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java create mode 100644 src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java create mode 100644 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index b61b0d645d0..bddd91aa9ef 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -95,7 +95,7 @@ public enum PSScheme { DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE } public enum FederatedPSScheme { - KEEP_DATA_ON_WORKER, SHUFFLE + KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE, SUBSAMPLE, BALANCE } public static final String PS_HYPER_PARAMS = "hyperparams"; public static final String PS_CHECKPOINTING = "checkpointing"; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java index e63fb141bd3..1ee9b561d8e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java @@ -214,6 +214,36 @@ public static MatrixBlock generatePermutation(int numEntries, long seed) { new MatrixBlock(numEntries, numEntries, true)); } + /** + * Generates a matrix which when left multiplied with the input matrix will subsample + * @param nsamples number of samples + * @param nrows number of rows in input matrix + * @param seed seed used to generate random number + * @return subsample matrix + */ + public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) { + MatrixBlock seq = new MatrixBlock(nsamples, 1, false); + // No replacement to preserve as much of the original data as possible + MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed); + return seq.ctableSeqOperations(sample, 1.0, + new MatrixBlock(nsamples, nrows, true)); + } + + /** + * Generates a matrix which when left multiplied with the input matrix will replicate n data rows + * @param nsamples number of samples + * @param nrows number of rows in input matrix + * @param seed seed used to generate random number + * @return replication matrix + */ + public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) { + MatrixBlock seq = new MatrixBlock(nsamples, 1, false); + // Replacement set to true to provide random replication + MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed); + return seq.ctableSeqOperations(sample, 1.0, + new MatrixBlock(nsamples, nrows, true)); + } + public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap, String updFunc, String aggFunc, int k) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java new file mode 100644 index 00000000000..e61fa4b88ff --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv.dp; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.instructions.cp.Data; + +import java.util.List; +import java.util.concurrent.Future; + +public class BalanceFederatedScheme extends DataPartitionFederatedScheme { + @Override + public Result doPartitioning(MatrixObject features, MatrixObject labels) { + List pFeatures = sliceFederatedMatrix(features); + List pLabels = sliceFederatedMatrix(labels); + + int average_num_rows = (int) pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN); + + for(int i = 0; i < pFeatures.size(); i++) { + // Works, because the map contains a single entry + FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + + Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows))); + + try { + FederatedResponse response = udfResponse.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: balance UDF returned fail"); + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage()); + } + } + return new Result(pFeatures, pLabels, pFeatures.size()); + } + + /** + * Balance UDF executed on the federated worker + */ + private static class balanceDataOnFederatedWorker extends FederatedUDF { + int _average_num_rows; + protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) { + super(inIDs); + _average_num_rows = average_num_rows; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + MatrixObject features = (MatrixObject) data[0]; + MatrixObject labels = (MatrixObject) data[1]; + + if(features.getNumRows() > _average_num_rows) { + // subsample down to average + subsampleTo(features, _average_num_rows); + subsampleTo(labels, _average_num_rows); + } + else if(features.getNumRows() < _average_num_rows) { + // replicate up to the average + replicateTo(features, _average_num_rows); + replicateTo(labels, _average_num_rows); + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index 4183372c05c..32fa106b965 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -26,6 +26,12 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; @@ -57,12 +63,10 @@ public Result(List pFeatures, List pLabels, int work */ static List sliceFederatedMatrix(MatrixObject fedMatrix) { if (fedMatrix.isFederated(FederationMap.FType.ROW)) { - List slices = Collections.synchronizedList(new ArrayList<>()); fedMatrix.getFedMapping().forEachParallel((range, data) -> { // Create sliced matrix object MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX)); - // Warning needs MetaDataFormat instead of MetaData slice.setMetaData(new MetaDataFormat( new MatrixCharacteristics(range.getSize(0), range.getSize(1)), Types.FileFormat.BINARY) @@ -85,4 +89,54 @@ static List sliceFederatedMatrix(MatrixObject fedMatrix) { "currently only supports row federated data"); } } + + /** + * Takes a MatrixObjects and shuffles it + * + * @param m the input matrix object + */ + static void shuffle(MatrixObject m) { + // generate permutation matrix + MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(m.getNumRows()), System.currentTimeMillis()); + // matrix multiplies + m.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + m.release(); + } + + /** + * Takes a MatrixObjects and extends it to the chosen number of rows by random replication + * + * @param m the input matrix object + */ + static void replicateTo(MatrixObject m, int rows) { + int num_rows_needed = rows - Math.toIntExact(m.getNumRows()); + // generate replication matrix + MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(m.getNumRows()), 
System.currentTimeMillis()); + // matrix multiplies and append + MatrixBlock replicatedFeatures = replicateMatrixBlock.aggregateBinaryOperations( + replicateMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))); + + m.acquireModify(m.acquireReadAndRelease().append(replicatedFeatures, new MatrixBlock(), false)); + m.release(); + } + + /** + * Takes a MatrixObjects and shrinks it to the given number of rows by subsampling + * + * @param m the input matrix object + */ + static void subsampleTo(MatrixObject m, int rows) { + // generate subsampling matrix + MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(rows, Math.toIntExact(m.getNumRows()), System.currentTimeMillis()); + // matrix multiplies + m.acquireModify(subsampleMatrixBlock.aggregateBinaryOperations( + subsampleMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + m.release(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java index 4cdfb95f033..c2978669b04 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java @@ -35,6 +35,15 @@ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) { case SHUFFLE: _scheme = new ShuffleFederatedScheme(); break; + case REPLICATE: + _scheme = new ReplicateFederatedScheme(); + break; + case SUBSAMPLE: + _scheme = new SubsampleFederatedScheme(); + break; + case BALANCE: + _scheme = new BalanceFederatedScheme(); + break; default: throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme)); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java new file mode 100644 index 00000000000..97e60173dd6 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.controlprogram.paramserv.dp; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.instructions.cp.Data; + +import java.util.List; +import java.util.concurrent.Future; + +public class ReplicateFederatedScheme extends DataPartitionFederatedScheme { + @Override + public Result doPartitioning(MatrixObject features, MatrixObject labels) { + List pFeatures = sliceFederatedMatrix(features); + List pLabels = sliceFederatedMatrix(labels); + + int max_rows = 0; + for (MatrixObject pFeature : pFeatures) { + max_rows = (pFeature.getNumRows() > max_rows) ? Math.toIntExact(pFeature.getNumRows()) : max_rows; + } + + for(int i = 0; i < pFeatures.size(); i++) { + // Works, because the map contains a single entry + FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + + Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows))); + + try { + FederatedResponse response = udfResponse.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: replicate UDF returned fail"); + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: executing replicate UDF failed" + e.getMessage()); + } + } + return new Result(pFeatures, pLabels, pFeatures.size()); + } + + /** + * Replicate UDF executed on the federated worker + */ + private static class replicateDataOnFederatedWorker extends FederatedUDF { + int _max_rows; + protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) { + super(inIDs); + _max_rows = max_rows; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + MatrixObject features = (MatrixObject) data[0]; + MatrixObject labels = (MatrixObject) data[1]; + + // replicate up to the max + if(features.getNumRows() < _max_rows) { + replicateTo(features, _max_rows); + replicateTo(labels, _max_rows); + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index dc123681feb..82908c625e3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -26,13 +26,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import java.util.List; import java.util.concurrent.Future; @@ -60,7 +54,6 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed" + e.getMessage()); } } - return new Result(pFeatures, pLabels, pFeatures.size()); } @@ -76,23 +69,8 @@ protected shuffleDataOnFederatedWorker(long[] inIDs) { public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixObject features = (MatrixObject) data[0]; MatrixObject labels = (MatrixObject) data[1]; - - // generate permutation matrix - MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); - - // matrix multiplies - features.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( - permutationMatrixBlock, features.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - features.release(); - - labels.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( - permutationMatrixBlock, labels.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - labels.release(); - + shuffle(features); + shuffle(labels); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java new file mode 100644 index 00000000000..20c543e2e09 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv.dp; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.instructions.cp.Data; + +import java.util.List; +import java.util.concurrent.Future; + +public class SubsampleFederatedScheme extends DataPartitionFederatedScheme { + @Override + public Result doPartitioning(MatrixObject features, MatrixObject labels) { + List pFeatures = sliceFederatedMatrix(features); + List pLabels = sliceFederatedMatrix(labels); + + int min_rows = Integer.MAX_VALUE; + for (MatrixObject pFeature : pFeatures) { + min_rows = (pFeature.getNumRows() < min_rows) ? Math.toIntExact(pFeature.getNumRows()) : min_rows; + } + + for(int i = 0; i < pFeatures.size(); i++) { + // Works, because the map contains a single entry + FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0]; + + Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows))); + + try { + FederatedResponse response = udfResponse.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: subsample UDF returned fail"); + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: executing subsample UDF failed" + e.getMessage()); + } + } + return new Result(pFeatures, pLabels, pFeatures.size()); + } + + /** + * Subsample UDF executed on the federated worker + */ + private static class subsampleDataOnFederatedWorker extends FederatedUDF { + int _min_rows; + protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) { + super(inIDs); + _min_rows = min_rows; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + MatrixObject features = (MatrixObject) data[0]; + MatrixObject labels = (MatrixObject) data[1]; + + // subsample down to minimum + if(features.getNumRows() > _min_rows) { + subsampleTo(features, _min_rows); + subsampleTo(labels, _min_rows); + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 18338418edf..8d59ff4f9cd 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -595,8 +595,8 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ inputDirectories.add(baseDirectory + INPUT_DIR + name); } - protected void federateBalancedAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, - List ports) { + protected void federateLocallyAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, + List ports, double[][] addresses) { // check matrix non empty if(matrix.length == 0 || matrix[0].length == 0) return; @@ -611,13 +611,11 @@ protected void federateBalancedAndWriteInputMatrixWithMTD(String name, double[][ Types.FileFormat.BINARY) ); - // write parts balanced and generate FederationMap + // write parts and generate FederationMap HashMap fedHashMap = new HashMap<>(); - double examplesPerWorker = ceil( (double) nrows / (double) numFederatedWorkers); - for(int i = 0; i < numFederatedWorkers; i++) { - double lowerBound = examplesPerWorker * i; - double upperBound = Math.min(examplesPerWorker * (i + 1), nrows); + double lowerBound = addresses[i][0]; + double upperBound = addresses[i][1]; double examplesForWorkerI = upperBound - lowerBound; String path = name + "_" + (i + 1); @@ -640,6 +638,16 @@ protected void federateBalancedAndWriteInputMatrixWithMTD(String name, double[][ writeInputFederatedWithMTD(name, federatedMatrixObject, null); } + protected double[][] generateBalancedFederatedRanges(int numFederatedWorkers, int dataSetSize) { + double[][] addresses = new double[numFederatedWorkers][2]; + double examplesPerWorker = ceil( (double) dataSetSize / (double) numFederatedWorkers); + for(int i = 0; i < numFederatedWorkers; i++) { + addresses[i][0] = examplesPerWorker * i; + addresses[i][1] = Math.min(examplesPerWorker * (i + 1), dataSetSize); + } + return addresses; + } + /** *

* Adds a matrix to the input path and writes it to a file. diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 56903381cc0..c8856c0521d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -60,11 +60,11 @@ public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SHUFFLE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "REPLICATE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "SUBSAMPLE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "BALANCE"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER"}, {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, @@ -97,10 +97,12 @@ public void federatedParamservSingleNode() { federatedParamserv(ExecMode.SINGLE_NODE); } + /* @Test public void federatedParamservHybrid() { federatedParamserv(ExecMode.HYBRID); } + */ private void federatedParamserv(ExecMode mode) { // config @@ -128,8 +130,10 @@ private void federatedParamserv(ExecMode mode) { String featuresName = "X_" + _numFederatedWorkers; String labelsName = "y_" + _numFederatedWorkers; - federateBalancedAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports); - federateBalancedAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports); + federateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, + generateBalancedFederatedRanges(_numFederatedWorkers, features.length)); + federateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, + generateBalancedFederatedRanges(_numFederatedWorkers, labels.length)); try { Thread.sleep(2000); From a9e75f9179929b18c13811c01dfa452fb4b8df84 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Thu, 26 Nov 2020 12:05:56 +0100 Subject: [PATCH 06/16] [SYSTEMDS-2550] Added new functionality to ctableSeqOperation --- .../paramserv/ParamservUtils.java | 8 +-- .../paramserv/dp/BalanceFederatedScheme.java | 25 ++++++--- .../dp/DataPartitionFederatedScheme.java | 25 ++++----- .../dp/ReplicateFederatedScheme.java | 15 +++++- .../paramserv/dp/ShuffleFederatedScheme.java | 9 +++- .../dp/SubsampleFederatedScheme.java | 14 ++++- .../runtime/matrix/data/MatrixBlock.java | 33 ++++++++---- .../apache/sysds/test/AutomatedTestBase.java | 32 ++++++++---- .../paramserv/FederatedParamservTest.java | 51 ++++++++++++------- 9 files changed, 142 insertions(+), 70 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java index 1ee9b561d8e..b20a6cdc88b 100644 --- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java @@ -222,11 +222,11 @@ public static MatrixBlock generatePermutation(int numEntries, long seed) { * @return subsample matrix */ public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) { - MatrixBlock seq = new MatrixBlock(nsamples, 1, false); + MatrixBlock seq = new MatrixBlock(nsamples, nrows, false); // No replacement to preserve as much of the original data as possible MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed); return seq.ctableSeqOperations(sample, 1.0, - new MatrixBlock(nsamples, nrows, true)); + new MatrixBlock(nsamples, nrows, true), false); } /** @@ -237,11 +237,11 @@ public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long * @return replication matrix */ public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) { - MatrixBlock seq = new MatrixBlock(nsamples, 1, false); + MatrixBlock seq = new MatrixBlock(nsamples, nrows, false); // Replacement set to true to provide random replication MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed); return seq.ctableSeqOperations(sample, 1.0, - new MatrixBlock(nsamples, nrows, true)); + new MatrixBlock(nsamples, nrows, true), false); } public static ExecutionContext createExecutionContext(ExecutionContext ec, diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java index e61fa4b88ff..e04e949be7f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java @@ -27,7 +27,10 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.List; import java.util.concurrent.Future; @@ -38,7 +41,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); - int average_num_rows = (int) pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN); + int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN)); for(int i = 0; i < pFeatures.size(); i++) { // Works, because the map contains a single entry @@ -56,6 +59,11 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { catch(Exception e) { throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage()); } + + DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(average_num_rows); + pFeatures.get(i).updateDataCharacteristics(update); + update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows); + 
pLabels.get(i).updateDataCharacteristics(update); } return new Result(pFeatures, pLabels, pFeatures.size()); } @@ -76,14 +84,17 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixObject labels = (MatrixObject) data[1]; if(features.getNumRows() > _average_num_rows) { - // subsample down to average - subsampleTo(features, _average_num_rows); - subsampleTo(labels, _average_num_rows); + // generate subsampling matrix + MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + subsampleTo(features, subsampleMatrixBlock); + subsampleTo(labels, subsampleMatrixBlock); } else if(features.getNumRows() < _average_num_rows) { - // replicate up to the average - replicateTo(features, _average_num_rows); - replicateTo(labels, _average_num_rows); + int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows()); + // generate replication matrix + MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + replicateTo(features, replicateMatrixBlock); + replicateTo(labels, replicateMatrixBlock); } return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index 32fa106b965..845804a479b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -91,14 +91,13 @@ static List sliceFederatedMatrix(MatrixObject fedMatrix) { } /** - * Takes a MatrixObjects and shuffles it + * Just a mat multiply used to shuffle with a provided shuffle matrixBlock * * @param m the input matrix object + * @param permutationMatrixBlock the shuffle matrix block */ - static void shuffle(MatrixObject m) { - // generate permutation matrix - MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(m.getNumRows()), System.currentTimeMillis()); - // matrix multiplies + static void shuffle(MatrixObject m, MatrixBlock permutationMatrixBlock) { + // matrix multiply m.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( permutationMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) @@ -111,11 +110,8 @@ static void shuffle(MatrixObject m) { * * @param m the input matrix object */ - static void replicateTo(MatrixObject m, int rows) { - int num_rows_needed = rows - Math.toIntExact(m.getNumRows()); - // generate replication matrix - MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(m.getNumRows()), System.currentTimeMillis()); - // matrix multiplies and append + static void replicateTo(MatrixObject m, MatrixBlock replicateMatrixBlock) { + // matrix multiply and append MatrixBlock replicatedFeatures = replicateMatrixBlock.aggregateBinaryOperations( replicateMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))); @@ -125,14 +121,13 @@ static void replicateTo(MatrixObject m, int rows) { } 
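For intuition: shuffle, replicateTo and subsampleTo above all reduce to one left multiplication of the data with a 0/1 row-selection matrix (a permutation matrix reorders rows, a sample drawn without replacement keeps a subset of rows, and a sample drawn with replacement yields the extra rows that replicateTo then appends). The following is only a minimal sketch of that trick in plain Java arrays; it deliberately avoids the MatrixBlock API, and the class and variable names are illustrative, not part of this patch.

public class RowSelectionSketch {
	// C = P %*% X, where P is a (k x n) 0/1 row-selection matrix and X is (n x m) data.
	static double[][] multiply(double[][] P, double[][] X) {
		int k = P.length, n = X.length, m = X[0].length;
		double[][] C = new double[k][m];
		for(int i = 0; i < k; i++)
			for(int j = 0; j < n; j++)
				if(P[i][j] != 0)
					for(int c = 0; c < m; c++)
						C[i][c] += P[i][j] * X[j][c];
		return C;
	}

	public static void main(String[] args) {
		double[][] X = {{1,1}, {2,2}, {3,3}, {4,4}}; // 4 data rows
		// permutation matrix: swaps the first and last row (shuffle)
		double[][] perm = {{0,0,0,1}, {0,1,0,0}, {0,0,1,0}, {1,0,0,0}};
		// sample without replacement: keeps rows 0 and 2 (subsample)
		double[][] sub = {{1,0,0,0}, {0,0,1,0}};
		// sample with replacement: row 1 twice (the rows to append for replication)
		double[][] rep = {{0,1,0,0}, {0,1,0,0}};
		System.out.println(java.util.Arrays.deepToString(multiply(perm, X))); // [[4.0, 4.0], [2.0, 2.0], [3.0, 3.0], [1.0, 1.0]]
		System.out.println(java.util.Arrays.deepToString(multiply(sub, X)));  // [[1.0, 1.0], [3.0, 3.0]]
		System.out.println(java.util.Arrays.deepToString(multiply(rep, X)));  // [[2.0, 2.0], [2.0, 2.0]]
	}
}

In the schemes themselves the same product is computed with MatrixBlock.aggregateBinaryOperations using a Multiply/Plus aggregate, as in the hunks above.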
/** - * Takes a MatrixObjects and shrinks it to the given number of rows by subsampling + * Just a mat multiply used to subsample with a provided subsample matrixBlock * * @param m the input matrix object + * @param subsampleMatrixBlock the subsample matrix block */ - static void subsampleTo(MatrixObject m, int rows) { - // generate subsampling matrix - MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(rows, Math.toIntExact(m.getNumRows()), System.currentTimeMillis()); - // matrix multiplies + static void subsampleTo(MatrixObject m, MatrixBlock subsampleMatrixBlock) { + // matrix multiply m.acquireModify(subsampleMatrixBlock.aggregateBinaryOperations( subsampleMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java index 97e60173dd6..5f5bfa2f3a0 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java @@ -26,7 +26,10 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.List; import java.util.concurrent.Future; @@ -58,6 +61,11 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { catch(Exception e) { throw new DMLRuntimeException("FederatedDataPartitioner ReplicateFederatedScheme: executing replicate UDF failed" + e.getMessage()); } + + DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(max_rows); + pFeatures.get(i).updateDataCharacteristics(update); + update = pLabels.get(i).getDataCharacteristics().setRows(max_rows); + pLabels.get(i).updateDataCharacteristics(update); } return new Result(pFeatures, pLabels, pFeatures.size()); } @@ -79,8 +87,11 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { // replicate up to the max if(features.getNumRows() < _max_rows) { - replicateTo(features, _max_rows); - replicateTo(labels, _max_rows); + int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows()); + // generate replication matrix + MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + replicateTo(features, replicateMatrixBlock); + replicateTo(labels, replicateMatrixBlock); } return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index 82908c625e3..5607a5b83e5 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -26,7 +26,9 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import java.util.List; import java.util.concurrent.Future; @@ -69,8 +71,11 @@ protected shuffleDataOnFederatedWorker(long[] inIDs) { public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixObject features = (MatrixObject) data[0]; MatrixObject labels = (MatrixObject) data[1]; - shuffle(features); - shuffle(labels); + + // generate permutation matrix + MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + shuffle(features, permutationMatrixBlock); + shuffle(labels, permutationMatrixBlock); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java index 20c543e2e09..81c0a99823b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java @@ -26,7 +26,10 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.List; import java.util.concurrent.Future; @@ -58,6 +61,11 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { catch(Exception e) { throw new DMLRuntimeException("FederatedDataPartitioner SubsampleFederatedScheme: executing subsample UDF failed" + e.getMessage()); } + + DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(min_rows); + pFeatures.get(i).updateDataCharacteristics(update); + update = 
pLabels.get(i).getDataCharacteristics().setRows(min_rows); + pLabels.get(i).updateDataCharacteristics(update); } return new Result(pFeatures, pLabels, pFeatures.size()); } @@ -79,8 +87,10 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // subsample down to minimum if(features.getNumRows() > _min_rows) { - subsampleTo(features, _min_rows); - subsampleTo(labels, _min_rows); + // generate subsampling matrix + MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + subsampleTo(features, subsampleMatrixBlock); + subsampleTo(labels, subsampleMatrixBlock); } return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 4cea82779cf..91614101707 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -5326,21 +5326,15 @@ public void ctableOperations(Operator op, MatrixValue thatVal, double scalarThat if( resultBlock!=null ) resultBlock.recomputeNonZeros(); } - + /** - * D = ctable(seq,A,w) - * this <- seq; thatMatrix <- A; thatScalar <- w; result <- D - * - * (i1,j1,v1) from input1 (this) - * (i1,j1,v2) from input2 (that) - * (w) from scalar_input3 (scalarThat2) - * * @param thatMatrix matrix value * @param thatScalar scalar double * @param resultBlock result matrix block + * @param updateClen when this matrix already has the desired number of columns updateClen can be set to false * @return resultBlock */ - public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) { + public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock, boolean updateClen) { MatrixBlock that = checkType(thatMatrix); CTable ctable = CTable.getCTableFnObject(); double w = thatScalar; @@ -5357,9 +5351,28 @@ public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar //update meta data (initially unknown number of columns) //note: nnz maintained in ctable (via quickset) - resultBlock.clen = maxCol; + if(updateClen) { + resultBlock.clen = maxCol; + } return resultBlock; } + + /** + * D = ctable(seq,A,w) + * this <- seq; thatMatrix <- A; thatScalar <- w; result <- D + * + * (i1,j1,v1) from input1 (this) + * (i1,j1,v2) from input2 (that) + * (w) from scalar_input3 (scalarThat2) + * + * @param thatMatrix matrix value + * @param thatScalar scalar double + * @param resultBlock result matrix block + * @return resultBlock + */ + public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) { + return ctableSeqOperations(thatMatrix, thatScalar, resultBlock, true); + } /** * D = ctable(A,B,W) diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 8d59ff4f9cd..0979deda73d 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -595,8 +595,22 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ inputDirectories.add(baseDirectory + INPUT_DIR + name); } - protected void federateLocallyAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, - List ports, double[][] addresses) { + 
/** + *
+	 * Takes a matrix (double[][]) and writes it in row slices locally. It then creates a federated MatrixObject
+	 * containing the local paths and the given ports; this federated matrix is also written to disk under the
+	 * provided name. When federated workers are running locally on the specified ports, the federated matrix
+	 * can then be used for testing purposes: just use read on input(name).
+	 *
+ * + * @param name name of the matrix when writing to disk + * @param matrix two dimensional matrix + * @param numFederatedWorkers the number of federated workers + * @param ports a list of port the length of the number of federated workers + * @param ranges an array containing arrays of length to with the upper and lower bound (rows) for the slices + */ + protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, + List ports, double[][] ranges) { // check matrix non empty if(matrix.length == 0 || matrix[0].length == 0) return; @@ -614,8 +628,8 @@ protected void federateLocallyAndWriteInputMatrixWithMTD(String name, double[][] // write parts and generate FederationMap HashMap fedHashMap = new HashMap<>(); for(int i = 0; i < numFederatedWorkers; i++) { - double lowerBound = addresses[i][0]; - double upperBound = addresses[i][1]; + double lowerBound = ranges[i][0]; + double upperBound = ranges[i][1]; double examplesForWorkerI = upperBound - lowerBound; String path = name + "_" + (i + 1); @@ -638,14 +652,14 @@ protected void federateLocallyAndWriteInputMatrixWithMTD(String name, double[][] writeInputFederatedWithMTD(name, federatedMatrixObject, null); } - protected double[][] generateBalancedFederatedRanges(int numFederatedWorkers, int dataSetSize) { - double[][] addresses = new double[numFederatedWorkers][2]; + protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) { + double[][] ranges = new double[numFederatedWorkers][2]; double examplesPerWorker = ceil( (double) dataSetSize / (double) numFederatedWorkers); for(int i = 0; i < numFederatedWorkers; i++) { - addresses[i][0] = examplesPerWorker * i; - addresses[i][1] = Math.min(examplesPerWorker * (i + 1), dataSetSize); + ranges[i][0] = examplesPerWorker * i; + ranges[i][1] = Math.min(examplesPerWorker * (i + 1), dataSetSize); } - return addresses; + return ranges; } /** diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index c8856c0521d..0641f3eebec 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -53,6 +53,7 @@ public class FederatedParamservTest extends AutomatedTestBase { private final String _utype; private final String _freq; private final String _scheme; + private final String _data_distribution; // parameters @Parameterized.Parameters @@ -60,21 +61,21 @@ public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "REPLICATE"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "SUBSAMPLE"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "BALANCE"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER"}, - {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE"}, - {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER"} + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", 
"REPLICATE", "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SUBSAMPLE", "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "BALANCE", "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE", "IMBALANCED"}, + /*{"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE", "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, + {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE", "BALANCED"}, + {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "BALANCED"}*/ }); } public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, - int epochs, double eta, String utype, String freq, String scheme) { + int epochs, double eta, String utype, String freq, String scheme, String data_distribution) { _networkType = networkType; _numFederatedWorkers = numFederatedWorkers; _dataSetSize = dataSetSize; @@ -84,6 +85,7 @@ public FederatedParamservTest(String networkType, int numFederatedWorkers, int d _utype = utype; _freq = freq; _scheme = scheme; + _data_distribution = data_distribution; } @Override @@ -111,7 +113,6 @@ private void federatedParamserv(ExecMode mode) { setOutputBuffering(true); int C = 1, Hin = 28, Win = 28; - int numFeatures = C * Hin * Win; int numLabels = 10; ExecMode platformOld = setExecMode(mode); @@ -125,15 +126,27 @@ private void federatedParamserv(ExecMode mode) { threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S)); } + // generate test data double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win); double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels); - String featuresName = "X_" + _numFederatedWorkers; - String labelsName = "y_" + _numFederatedWorkers; - - federateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, - generateBalancedFederatedRanges(_numFederatedWorkers, features.length)); - federateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, - generateBalancedFederatedRanges(_numFederatedWorkers, labels.length)); + String featuresName = ""; + String labelsName = ""; + + // federate test data balanced or imbalanced + if(_data_distribution.equals("IMBALANCED")) { + featuresName = "X_IMBALANCED_" + _numFederatedWorkers; + labelsName = "y_IMBALANCED_" + _numFederatedWorkers; + double[][] ranges = {{0,1}, {1,4}}; + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges); + } + else { + featuresName = "X_BALANCED_" + _numFederatedWorkers; + labelsName = "y_BALANCED_" + _numFederatedWorkers; + double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length); + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges); + } try { Thread.sleep(2000); From bd85323993fed64175fcd3ee87014d2276ca3b0f Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Wed, 9 Dec 2020 11:39:12 +0100 Subject: [PATCH 07/16] [SYSTEMDS-2550] Added balancing metrics to the data partitioning --- .../dp/DataPartitionFederatedScheme.java | 95 ++++++++----------- 
.../dp/KeepDataOnWorkerFederatedScheme.java | 2 +- .../paramserv/dp/ShuffleFederatedScheme.java | 23 ++++- .../cp/ParamservBuiltinCPInstruction.java | 6 +- 4 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index 845804a479b..b75acb6817e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp; +import org.apache.sysds.api.mlcontext.Matrix; import org.apache.sysds.common.Types; import org.apache.sysds.lops.compile.Dag; import org.apache.sysds.runtime.DMLRuntimeException; @@ -26,12 +27,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; @@ -43,14 +38,17 @@ public abstract class DataPartitionFederatedScheme { public static final class Result { - public final List pFeatures; - public final List pLabels; - public final int workerNum; - - public Result(List pFeatures, List pLabels, int workerNum) { - this.pFeatures = pFeatures; - this.pLabels = pLabels; - this.workerNum = workerNum; + public final List _pFeatures; + public final List _pLabels; + public final int _workerNum; + public final BalanceMetrics _balanceMetrics; + + + public Result(List pFeatures, List pLabels, int workerNum, BalanceMetrics balanceMetrics) { + this._pFeatures = pFeatures; + this._pLabels = pLabels; + this._workerNum = workerNum; + this._balanceMetrics = balanceMetrics; } } @@ -63,10 +61,12 @@ public Result(List pFeatures, List pLabels, int work */ static List sliceFederatedMatrix(MatrixObject fedMatrix) { if (fedMatrix.isFederated(FederationMap.FType.ROW)) { + List slices = Collections.synchronizedList(new ArrayList<>()); fedMatrix.getFedMapping().forEachParallel((range, data) -> { // Create sliced matrix object MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX)); + // Warning needs MetaDataFormat instead of MetaData slice.setMetaData(new MetaDataFormat( new MatrixCharacteristics(range.getSize(0), range.getSize(1)), Types.FileFormat.BINARY) @@ -90,48 +90,35 @@ static List sliceFederatedMatrix(MatrixObject fedMatrix) { } } - /** - * Just a mat multiply used to shuffle with a provided shuffle matrixBlock - * - * @param m the input matrix object - * @param permutationMatrixBlock the shuffle matrix block - */ - static void shuffle(MatrixObject m, MatrixBlock permutationMatrixBlock) { - // matrix multiply - m.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( - permutationMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), - new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - m.release(); - } + static BalanceMetrics getBalanceMetrics(List slices) { + if (slices == null || slices.size() == 0) + return new BalanceMetrics(0, 0, 0); - /** - * Takes a MatrixObjects and extends it to the chosen number of rows by random replication - * - * @param m the input matrix object - */ - static void replicateTo(MatrixObject m, MatrixBlock replicateMatrixBlock) { - // matrix multiply and append - MatrixBlock replicatedFeatures = replicateMatrixBlock.aggregateBinaryOperations( - replicateMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))); - - m.acquireModify(m.acquireReadAndRelease().append(replicatedFeatures, new MatrixBlock(), false)); - m.release(); + long minRows = slices.get(0).getNumRows(); + long maxRows = minRows; + long sum = 0; + + for (MatrixObject slice : slices) { + if (slice.getNumRows() < minRows) + minRows = slice.getNumRows(); + else if (slice.getNumRows() > maxRows) + maxRows = slice.getNumRows(); + + sum += slice.getNumRows(); + } + + return new BalanceMetrics(minRows, sum / slices.size(), maxRows); } - /** - * Just a mat multiply used to subsample with a provided subsample matrixBlock - * - * @param m the input matrix object - * @param subsampleMatrixBlock the subsample matrix block - */ - static void subsampleTo(MatrixObject m, MatrixBlock subsampleMatrixBlock) { - // matrix multiply - m.acquireModify(subsampleMatrixBlock.aggregateBinaryOperations( - subsampleMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - m.release(); + public static final class BalanceMetrics { + public final long _minRows; + public final long _avgRows; + public final long _maxRows; + + public BalanceMetrics(long minRows, long avgRows, long maxRows) { + this._minRows = minRows; + this._avgRows = avgRows; + this._maxRows = maxRows; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java index 06feded8474..e306f25d29c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java @@ -27,6 +27,6 @@ public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedSchem public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); - return new Result(pFeatures, pLabels, pFeatures.size()); + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index 5607a5b83e5..f5f11e1fdbd 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -27,8 +27,12 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import java.util.List; import java.util.concurrent.Future; @@ -56,7 +60,8 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed" + e.getMessage()); } } - return new Result(pFeatures, pLabels, pFeatures.size()); + + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); } /** @@ -74,8 +79,20 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // generate permutation matrix MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); - shuffle(features, permutationMatrixBlock); - shuffle(labels, permutationMatrixBlock); + + // matrix multiplies + features.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, features.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + features.release(); + + labels.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, labels.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + labels.release(); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index b12c9fb21ed..8a1da5c766c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -131,9 +131,9 @@ private void runFederated(ExecutionContext ec) { // partition federated data DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme) .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS))); - List pFeatures = result.pFeatures; - List pLabels = result.pLabels; - int workerNum = result.workerNum; + List pFeatures = result._pFeatures; + List pLabels = result._pLabels; + int workerNum = result._workerNum; // setup threading BasicThreadFactory factory = new BasicThreadFactory.Builder() From dbbc8aa63abcbe02129911fef7a0debe98af5980 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Thu, 10 Dec 2020 12:41:30 +0100 Subject: [PATCH 08/16] [SYSTEMDS-2550] Added runtime balancing parameter --- ...arameterizedBuiltinFunctionExpression.java | 3 +- .../org/apache/sysds/parser/Statement.java | 4 +++ .../paramserv/FederatedPSControlThread.java | 8 +++-- .../paramserv/dp/BalanceFederatedScheme.java | 3 +- .../dp/ReplicateFederatedScheme.java | 3 +- 
.../dp/SubsampleFederatedScheme.java | 3 +- .../cp/ParamservBuiltinCPInstruction.java | 33 ++++++++++--------- .../paramserv/FederatedParamservTest.java | 30 ++++++++++------- .../functions/federated/paramserv/CNN.dml | 6 ++-- .../paramserv/FederatedParamservTest.dml | 4 +-- .../functions/federated/paramserv/TwoNN.dml | 6 ++-- 11 files changed, 63 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 2e33676de37..9df1e635382 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -286,7 +286,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) { raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS); } //check for invalid parameters - Set valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING); + Set valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING); checkInvalidParameters(getOpCode(), getVarParams(), valid); // check existence and correctness of parameters @@ -304,6 +304,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) { checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional); checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional); checkStringParam(true, fname, Statement.PS_SCHEME, conditional); + checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional); checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional); checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional); diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index bddd91aa9ef..8cf7786229a 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -87,6 +87,10 @@ public boolean isASP() { public enum PSFrequency { BATCH, EPOCH } + public static final String PS_RUNTIME_BALANCING = "runtime_balancing"; + public enum PSRuntimeBalancing { + NONE, CYCLE + } public static final String PS_EPOCHS = "epochs"; public static final String PS_BATCH_SIZE = "batchsize"; public static final String PS_PARALLELISM = "k"; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 7ca402ae238..91ac14bc65f 100644 --- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -53,15 +53,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable { private static final long serialVersionUID = 6846648059569648791L; + Statement.PSRuntimeBalancing _runtimeBalancing; FederatedData _featuresData; FederatedData _labelsData; final long _batchCounterVarID; final long _modelVarID; int _totalNumBatches; - public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { + public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { super(workerID, updFunc, freq, epochs, batchSize, ec, ps); - + + _runtimeBalancing = runtimeBalancing; // generate the IDs for model and batch counter. These get overwritten on the federated worker each time _batchCounterVarID = FederationUtils.getNextFedDataID(); _modelVarID = FederationUtils.getNextFedDataID(); @@ -245,6 +247,8 @@ public Void call() throws Exception { case EPOCH: computeEpoch(); break; + /*case NBATCH: + break;*/ default: throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq)); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java index e04e949be7f..5cca9383a1e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java @@ -65,7 +65,8 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows); pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size()); + + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); } /** diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java index 5f5bfa2f3a0..0fdaf25bbbf 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java @@ -67,7 +67,8 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { update = pLabels.get(i).getDataCharacteristics().setRows(max_rows); pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size()); + + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); } /** diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java index 81c0a99823b..e21a813a75c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java @@ -67,7 +67,8 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { update = pLabels.get(i).getDataCharacteristics().setRows(min_rows); pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size()); + + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 8a1da5c766c..ec6903a02cc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -19,20 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN; -import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE; -import static org.apache.sysds.parser.Statement.PS_EPOCHS; -import static org.apache.sysds.parser.Statement.PS_FEATURES; -import static org.apache.sysds.parser.Statement.PS_FREQUENCY; -import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS; -import static org.apache.sysds.parser.Statement.PS_LABELS; -import static org.apache.sysds.parser.Statement.PS_MODE; -import static org.apache.sysds.parser.Statement.PS_MODEL; -import static org.apache.sysds.parser.Statement.PS_PARALLELISM; -import static org.apache.sysds.parser.Statement.PS_SCHEME; -import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN; -import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE; - import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -52,6 +38,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.lops.LopProperties; +import org.apache.sysds.parser.Statement; import org.apache.sysds.parser.Statement.PSFrequency; import org.apache.sysds.parser.Statement.PSModeType; import org.apache.sysds.parser.Statement.PSScheme; @@ -80,12 +67,15 @@ import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.utils.Statistics; +import static org.apache.sysds.parser.Statement.*; + public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName()); private static final int DEFAULT_BATCH_SIZE = 64; private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH; private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS; + private static final PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = PSRuntimeBalancing.NONE; private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = FederatedPSScheme.KEEP_DATA_ON_WORKER; private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL; private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP; @@ -124,6 +114,7 @@ private void runFederated(ExecutionContext ec) { // get inputs PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); + PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing(); FederatedPSScheme federatedPSScheme = getFederatedScheme(); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); @@ -153,7 +144,7 @@ private void runFederated(ExecutionContext ec) { 
ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers List threads = IntStream.range(0, workerNum) - .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps)) + .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps)) .collect(Collectors.toList()); if(workerNum != threads.size()) { @@ -379,6 +370,18 @@ private PSFrequency getFrequency() { } } + private PSRuntimeBalancing getRuntimeBalancing() { + if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) { + return DEFAULT_RUNTIME_BALANCING; + } + try { + return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING)); + } catch (IllegalArgumentException e) { + throw new DMLRuntimeException(String.format("Paramserv function: " + + "not support '%s' runtime balancing.", getParam(PS_FREQUENCY))); + } + } + private static int getRemainingCores() { return InfrastructureAnalyzer.getLocalParallelism(); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 0641f3eebec..dab2a41a211 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -53,6 +53,7 @@ public class FederatedParamservTest extends AutomatedTestBase { private final String _utype; private final String _freq; private final String _scheme; + private final String _runtime_balancing; private final String _data_distribution; // parameters @@ -61,21 +62,26 @@ public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE", "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "SUBSAMPLE", "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "BALANCE", "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "SHUFFLE", "IMBALANCED"}, - /*{"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE", "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "BALANCED"}, - {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE", "BALANCED"}, - {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "BALANCED"}*/ + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"CNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"CNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + /*{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE", "NONE" , 
"IMBALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE", "" , "BALANCED"}, + {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, + {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE", "" , "BALANCED"}, + {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}*/ }); } public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, - int epochs, double eta, String utype, String freq, String scheme, String data_distribution) { + int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) { _networkType = networkType; _numFederatedWorkers = numFederatedWorkers; _dataSetSize = dataSetSize; @@ -85,6 +91,7 @@ public FederatedParamservTest(String networkType, int numFederatedWorkers, int d _utype = utype; _freq = freq; _scheme = scheme; + _runtime_balancing = runtime_balancing; _data_distribution = data_distribution; } @@ -168,6 +175,7 @@ private void federatedParamserv(ExecMode mode) { "utype=" + _utype, "freq=" + _freq, "scheme=" + _scheme, + "runtime_balancing=" + _runtime_balancing, "network_type=" + _networkType, "channels=" + C, "hin=" + Hin, diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml index a2196001782..ccfb3d22a7e 100644 --- a/src/test/scripts/functions/federated/paramserv/CNN.dml +++ b/src/test/scripts/functions/federated/paramserv/CNN.dml @@ -163,8 +163,8 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, string utype, string freq, int batch_size, string scheme, double eta, - int C, int Hin, int Win, + int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + double eta, int C, int Hin, int Win, int seed = -1) return (list[unknown] model) { @@ -210,7 +210,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y, model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, - scheme=scheme, hyperparams=hyperparams) + scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) } /* diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml index 33d1a0c7b73..affbea7ee5d 100644 --- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml +++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml @@ -27,9 +27,9 @@ features = read($features) labels = read($labels) if($network_type == "TwoNN") { - model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $eta, $seed) + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed) } else { - model = CNN::train_paramserv(features, labels, 
matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $eta, $channels, $hin, $win, $seed) + model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed) } print(toString(model)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml index 2e5f9b5bfb2..92d85c75c70 100644 --- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml +++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml @@ -125,8 +125,8 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, string utype, string freq, int batch_size, string scheme, double eta, - int seed = -1) + int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + double eta, int seed = -1) return (list[unknown] model) { N = nrow(X) # num examples @@ -154,7 +154,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y, model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, - scheme=scheme, hyperparams=hyperparams) + scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) } /* From 86aa40852f61f4e06a72d76c3c7a22378c08a4ed Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Sun, 13 Dec 2020 13:14:50 +0100 Subject: [PATCH 09/16] [SYSTEMDS-2550] Implemented runtime balancing: cycling min avg and max --- .../org/apache/sysds/parser/Statement.java | 5 +- .../paramserv/FederatedPSControlThread.java | 110 +++++++++++------- .../cp/ParamservBuiltinCPInstruction.java | 15 ++- .../paramserv/FederatedParamservTest.java | 14 ++- 4 files changed, 93 insertions(+), 51 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 8cf7786229a..5fbd9678eb4 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -89,7 +89,7 @@ public enum PSFrequency { } public static final String PS_RUNTIME_BALANCING = "runtime_balancing"; public enum PSRuntimeBalancing { - NONE, CYCLE + NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH } public static final String PS_EPOCHS = "epochs"; public static final String PS_BATCH_SIZE = "batchsize"; @@ -111,7 +111,8 @@ public enum PSCheckpointing { // prefixed with code: "1701-NCC-" to not overwrite anything public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size"; public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size"; - public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches"; + public static final String PS_FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local"; + public static final String PS_FED_NUM_BATCHES_GLOBAL = "1701-NCC-num_batches_global"; public static final String PS_FED_NAMESPACE = "1701-NCC-namespace"; public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname"; public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname"; diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 91ac14bc65f..c3687c0a866 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -56,16 +56,18 @@ public class FederatedPSControlThread extends PSWorker implements Callable Statement.PSRuntimeBalancing _runtimeBalancing; FederatedData _featuresData; FederatedData _labelsData; - final long _batchCounterVarID; + final long _localBatchNumVarID; final long _modelVarID; - int _totalNumBatches; + int _numBatchesPerGlobalEpoch; + int _possibleBatchesPerLocalEpoch; - public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { + public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) { super(workerID, updFunc, freq, epochs, batchSize, ec, ps); + _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch; _runtimeBalancing = runtimeBalancing; // generate the IDs for model and batch counter. These get overwritten on the federated worker each time - _batchCounterVarID = FederationUtils.getNextFedDataID(); + _localBatchNumVarID = FederationUtils.getNextFedDataID(); _modelVarID = FederationUtils.getNextFedDataID(); } @@ -79,7 +81,10 @@ public void setup() { // calculate number of batches and get data size long dataSize = _features.getNumRows(); - _totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize); + _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize); + if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) { + _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch; + } // serialize program // create program blocks for the instruction filtering @@ -112,13 +117,14 @@ public void setup() { _featuresData.getVarID(), new setupFederatedWorker(_batchSize, dataSize, - _totalNumBatches, + _possibleBatchesPerLocalEpoch, + _numBatchesPerGlobalEpoch, programSerialized, _inst.getNamespace(), _inst.getFunctionName(), _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), - _batchCounterVarID, + _localBatchNumVarID, _modelVarID ) )); @@ -140,7 +146,8 @@ private static class setupFederatedWorker extends FederatedUDF { private static final long serialVersionUID = -3148991224792675607L; long _batchSize; long _dataSize; - long _numBatches; + int _possibleBatchesPerLocalEpoch; + int _numBatchesPerGlobalEpoch; String _programString; String _namespace; String _gradientsFunctionName; @@ -149,11 +156,13 @@ private static class setupFederatedWorker extends FederatedUDF { long _batchCounterVarID; long _modelVarID; - protected setupFederatedWorker(long batchSize, long dataSize, long numBatches, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) { + protected setupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, int 
numBatchesPerGlobalEpoch, String programString, + String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) { super(new long[]{}); _batchSize = batchSize; _dataSize = dataSize; - _numBatches = numBatches; + _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch; + _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch; _programString = programString; _namespace = namespace; _gradientsFunctionName = gradientsFunctionName; @@ -171,7 +180,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // set variables to ec ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize)); ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize)); - ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches)); + ec.setVariable(Statement.PS_FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch)); + ec.setVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL, new IntObject(_numBatchesPerGlobalEpoch)); ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace)); ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName)); ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName)); @@ -218,7 +228,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // remove variables from ec ec.removeVariable(Statement.PS_FED_BATCH_SIZE); ec.removeVariable(Statement.PS_FED_DATA_SIZE); - ec.removeVariable(Statement.PS_FED_NUM_BATCHES); + ec.removeVariable(Statement.PS_FED_POSS_BATCHES_LOCAL); + ec.removeVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL); ec.removeVariable(Statement.PS_FED_NAMESPACE); ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME); ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME); @@ -242,10 +253,10 @@ public Void call() throws Exception { try { switch (_freq) { case BATCH: - computeBatch(_totalNumBatches); + computeBatches(); break; case EPOCH: - computeEpoch(); + computeEpochs(); break; /*case NBATCH: break;*/ @@ -269,19 +280,25 @@ protected void pushGradients(ListObject gradients) { _ps.push(_workerID, gradients); } + static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) { + return currentLocalBatchNumber % possibleBatchesPerLocalEpoch; + } + /** * Computes all epochs and synchronizes after each batch - * - * @param numBatches the number of batches per epoch */ - protected void computeBatch(int numBatches) { + protected void computeBatches() { + int currentLocalBatchNumber = 0; + for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { - for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) { + for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) { + int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); ListObject model = pullModel(); - ListObject gradients = computeBatchGradients(model, batchCounter); + ListObject gradients = computeBatchGradients(model, localBatchNum); pushGradients(gradients); ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); + System.out.println("[+] " + this.getWorkerName() + " completed batch " + localBatchNum); } System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); } @@ -291,13 +308,12 @@ protected void computeBatch(int numBatches) { * Computes a single specified batch on the federated worker * * @param model 
the current model from the parameter server - * @param batchCounter the current batch number needed for slicing the features and labels + * @param localBatchNum the current batch number needed for slicing the features and labels * @return the gradient vector */ - protected ListObject computeBatchGradients(ListObject model, int batchCounter) { - // put batch counter on federated worker - Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter))); - + protected ListObject computeBatchGradients(ListObject model, int localBatchNum) { + // put local batch num on federated worker + Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _localBatchNumVarID, new IntObject(localBatchNum))); // put current model on federated worker Future putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model)); @@ -312,7 +328,7 @@ protected ListObject computeBatchGradients(ListObject model, int batchCounter) { // create and execute the udf on the remote worker Future udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(), - new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID}) + new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _localBatchNumVarID, _modelVarID}) )); try { @@ -339,7 +355,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // read in data by varid MatrixObject features = (MatrixObject) data[0]; MatrixObject labels = (MatrixObject) data[1]; - long batchCounter = ((IntObject) data[2]).getLongValue(); + long localBatchNum = ((IntObject) data[2]).getLongValue(); ListObject model = (ListObject) data[3]; // get data from execution context @@ -349,8 +365,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue(); // slice batch from feature and label matrix - long begin = batchCounter * batchSize + 1; - long end = Math.min((batchCounter + 1) * batchSize, dataSize); + long begin = localBatchNum * batchSize + 1; + long end = Math.min((localBatchNum + 1) * batchSize, dataSize); MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end); MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end); @@ -393,11 +409,14 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
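For readability of the batch-slicing UDF logic above, the following standalone sketch spells out the same 1-based, inclusive row-range arithmetic. The class and method names and the example sizes are illustrative only and are not part of this patch.

// Minimal sketch of the slice bounds used when a federated worker cuts batch
// `localBatchNum` out of its local features/labels (rows are 1-based, inclusive).
public class BatchSliceSketch {
    static long[] sliceBounds(long localBatchNum, long batchSize, long dataSize) {
        long begin = localBatchNum * batchSize + 1;
        long end = Math.min((localBatchNum + 1) * batchSize, dataSize);
        return new long[] {begin, end};
    }

    public static void main(String[] args) {
        // e.g., 10 local rows, batch size 4 -> batches cover rows [1,4], [5,8], [9,10]
        for (long b = 0; b < 3; b++) {
            long[] r = sliceBounds(b, 4, 10);
            System.out.println("batch " + b + " -> rows " + r[0] + ".." + r[1]);
        }
    }
}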
data) { /** * Computes all epochs and synchronizes after each one */ - protected void computeEpoch() { + protected void computeEpochs() { for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { + // TODO: Calc localStartBatchNum + int localStartBatchNum = 0; + // Pull the global parameters from ps ListObject model = pullModel(); - ListObject gradients = computeEpochGradients(model); + ListObject gradients = computeEpochGradients(model, localStartBatchNum); pushGradients(gradients); System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); ParamservUtils.cleanupListObject(model); @@ -411,12 +430,14 @@ protected void computeEpoch() { * @param model the current model from the parameter server * @return the gradient vector */ - protected ListObject computeEpochGradients(ListObject model) { + protected ListObject computeEpochGradients(ListObject model, int localStartBatchNum) { + // put local batch num on federated worker + Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _localBatchNumVarID, new IntObject(localStartBatchNum))); // put current model on federated worker Future putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model)); try { - if(!putParamsResponse.get().isSuccessful()) + if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful()) throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful"); } catch(Exception e) { @@ -426,7 +447,7 @@ protected ListObject computeEpochGradients(ListObject model) { // create and execute the udf on the remote worker Future udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(), - new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID}) + new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _localBatchNumVarID, _modelVarID}) )); try { @@ -453,12 +474,14 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // read in data by varid MatrixObject features = (MatrixObject) data[0]; MatrixObject labels = (MatrixObject) data[1]; - ListObject model = (ListObject) data[2]; + int localStartBatchNum = (int) ((IntObject) data[2]).getLongValue(); + ListObject model = (ListObject) data[3]; // get data from execution context long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue(); long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue(); - long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue(); + int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue(); + int numBatchesPerGlobalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL)).getLongValue(); String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue(); String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue(); String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue(); @@ -488,22 +511,23 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs, func.getInputParamNames(), outputNames, "aggregation function"); DataIdentifier aggregationOutput = outputs.get(0); - - ListObject accGradients = null; + + int currentLocalBatchNumber = localStartBatchNum; // prepare execution context ec.setVariable(Statement.PS_MODEL, model); - for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) { + for (int batchCounter = 0; batchCounter < numBatchesPerGlobalEpoch; batchCounter++) { + int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch); + // slice batch from feature and label matrix - long begin = batchCounter * batchSize + 1; - long end = Math.min((batchCounter + 1) * batchSize, dataSize); + long begin = localBatchNum * batchSize + 1; + long end = Math.min((localBatchNum + 1) * batchSize, dataSize); MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end); MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end); // prepare execution context ec.setVariable(Statement.PS_FEATURES, bFeatures); ec.setVariable(Statement.PS_LABELS, bLabels); - boolean localUpdate = batchCounter < numBatches - 1; // calculate intermediate gradients gradientsInstruction.processInstruction(ec); @@ -513,7 +537,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false); // Update the local model with gradients - if(localUpdate) { + if(batchCounter < numBatchesPerGlobalEpoch - 1) { // Invoke the aggregate function aggregationInstruction.processInstruction(ec); // Get the new model @@ -526,8 +550,10 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { } // clean up sliced batch + ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); ParamservUtils.cleanupData(ec, Statement.PS_LABELS); + System.out.println("[+]" + " completed batch " + localBatchNum); } // model clean up - doing this twice is not an issue diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index ec6903a02cc..3132e180eeb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -126,6 +126,16 @@ private void runFederated(ExecutionContext ec) { List pLabels = result._pLabels; int workerNum = result._workerNum; + // calculate runtime balancing + int numBatchesPerEpoch = 0; + if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) { + numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) { + numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) { + numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize()); + } + // setup threading BasicThreadFactory factory = new BasicThreadFactory.Builder() .namingPattern("workers-pool-thread-%d").build(); @@ -143,8 +153,9 @@ private void runFederated(ExecutionContext ec) { ListObject model = ec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers + int finalNumBatchesPerEpoch = numBatchesPerEpoch; List threads = IntStream.range(0, workerNum) - .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps)) + .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps)) .collect(Collectors.toList()); if(workerNum != threads.size()) { @@ -378,7 +389,7 @@ private PSRuntimeBalancing getRuntimeBalancing() { return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING)); } catch (IllegalArgumentException e) { throw new DMLRuntimeException(String.format("Paramserv function: " - + "not support '%s' runtime balancing.", getParam(PS_FREQUENCY))); + + "not support '%s' runtime balancing.", getParam(PS_RUNTIME_BALANCING))); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index dab2a41a211..da8efb843b1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -62,10 +62,12 @@ public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update // type, update frequency - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , 
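The logic that connects the partition balance metrics, the new runtime-balancing modes, and the wrap-around batch counter is spread over several hunks above. The self-contained sketch below mirrors that logic under stated assumptions; RuntimeBalancingSketch, deriveBatchesPerEpoch and the example partition sizes are illustrative stand-ins, not the actual SystemDS classes.

import java.util.ArrayList;
import java.util.List;

// Illustrative sketch: how RUN_MIN / CYCLE_AVG / CYCLE_MAX turn the partition sizes
// into a common number of batches per global epoch, and how a worker with fewer
// local batches wraps around its data via the modulo helper.
public class RuntimeBalancingSketch {
    enum PSRuntimeBalancing { NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX }

    static int deriveBatchesPerEpoch(PSRuntimeBalancing rb, long minRows, long avgRows, long maxRows, int batchSize) {
        switch (rb) {
            case RUN_MIN:   return (int) Math.ceil(minRows / (double) batchSize);
            case CYCLE_AVG: return (int) Math.ceil(avgRows / (double) batchSize);
            case CYCLE_MAX: return (int) Math.ceil(maxRows / (double) batchSize);
            default:        return -1; // NONE (sketch sentinel): each worker falls back to its own local batch count
        }
    }

    // same wrap-around rule as the getNextLocalBatchNum helper in the patch
    static int nextLocalBatchNum(int current, int possibleBatchesPerLocalEpoch) {
        return current % possibleBatchesPerLocalEpoch;
    }

    public static void main(String[] args) {
        int batchSize = 1;
        // two imbalanced partitions with 1 and 3 rows -> min=1, avg=2, max=3
        int numBatchesPerEpoch = deriveBatchesPerEpoch(PSRuntimeBalancing.CYCLE_MAX, 1, 2, 3, batchSize);
        int possibleBatchesSmallWorker = 1; // the small worker only has one local batch
        List<Integer> visited = new ArrayList<>();
        for (int b = 0; b < numBatchesPerEpoch; b++)
            visited.add(nextLocalBatchNum(b, possibleBatchesSmallWorker));
        // the small worker cycles over its single batch three times: [0, 0, 0]
        System.out.println("batches per epoch = " + numBatchesPerEpoch + ", small worker visits " + visited);
    }
}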
"BALANCED"}, - {"CNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"CNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"TwoNN", 2, 4, 1, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, + {"CNN", 2, 4, 1, 1, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"CNN", 2, 4, 1, 1, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, /*{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE", "NONE" , "IMBALANCED"}, @@ -183,7 +185,9 @@ private void federatedParamserv(ExecMode mode) { "seed=" + 25)); programArgs = programArgsList.toArray(new String[0]); - LOG.debug(runTest(null)); + // TODO: Switch back + //LOG.debug(runTest(null)); + System.out.println(runTest(null)); Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst()); // shut down threads From c123a0dd5aa3c2914106328057267b4e4161e250 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Sun, 13 Dec 2020 13:52:48 +0100 Subject: [PATCH 10/16] [SYSTEMDS-2550] Refactoring and make code more compact --- .../org/apache/sysds/parser/Statement.java | 1 - .../paramserv/FederatedPSControlThread.java | 215 ++++++------------ 2 files changed, 66 insertions(+), 150 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 5fbd9678eb4..14ce90e29fb 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -112,7 +112,6 @@ public enum PSCheckpointing { public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size"; public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size"; public static final String PS_FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local"; - public static final String PS_FED_NUM_BATCHES_GLOBAL = "1701-NCC-num_batches_global"; public static final String PS_FED_NAMESPACE = "1701-NCC-namespace"; public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname"; public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname"; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index c3687c0a866..3cbb3c0286d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -56,10 +56,11 @@ public class FederatedPSControlThread extends PSWorker implements Callable Statement.PSRuntimeBalancing _runtimeBalancing; FederatedData _featuresData; FederatedData _labelsData; - final long _localBatchNumVarID; + final long _localStartBatchNumVarID; final long _modelVarID; int _numBatchesPerGlobalEpoch; int _possibleBatchesPerLocalEpoch; + boolean _cycleStartAt0 = false; public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing 
runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) { super(workerID, updFunc, freq, epochs, batchSize, ec, ps); @@ -67,7 +68,7 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch; _runtimeBalancing = runtimeBalancing; // generate the IDs for model and batch counter. These get overwritten on the federated worker each time - _localBatchNumVarID = FederationUtils.getNextFedDataID(); + _localStartBatchNumVarID = FederationUtils.getNextFedDataID(); _modelVarID = FederationUtils.getNextFedDataID(); } @@ -118,13 +119,12 @@ public void setup() { new setupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, - _numBatchesPerGlobalEpoch, programSerialized, _inst.getNamespace(), _inst.getFunctionName(), _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), - _localBatchNumVarID, + _localStartBatchNumVarID, _modelVarID ) )); @@ -147,7 +147,6 @@ private static class setupFederatedWorker extends FederatedUDF { long _batchSize; long _dataSize; int _possibleBatchesPerLocalEpoch; - int _numBatchesPerGlobalEpoch; String _programString; String _namespace; String _gradientsFunctionName; @@ -156,13 +155,12 @@ private static class setupFederatedWorker extends FederatedUDF { long _batchCounterVarID; long _modelVarID; - protected setupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, int numBatchesPerGlobalEpoch, String programString, + protected setupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) { super(new long[]{}); _batchSize = batchSize; _dataSize = dataSize; _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch; - _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch; _programString = programString; _namespace = namespace; _gradientsFunctionName = gradientsFunctionName; @@ -181,7 +179,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize)); ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize)); ec.setVariable(Statement.PS_FED_POSS_BATCHES_LOCAL, new IntObject(_possibleBatchesPerLocalEpoch)); - ec.setVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL, new IntObject(_numBatchesPerGlobalEpoch)); ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace)); ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName)); ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName)); @@ -229,7 +226,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { ec.removeVariable(Statement.PS_FED_BATCH_SIZE); ec.removeVariable(Statement.PS_FED_DATA_SIZE); ec.removeVariable(Statement.PS_FED_POSS_BATCHES_LOCAL); - ec.removeVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL); ec.removeVariable(Statement.PS_FED_NAMESPACE); ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME); ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME); @@ -253,13 +249,14 @@ public Void call() throws Exception { try { switch (_freq) { case BATCH: - computeBatches(); + computeWithBatchUpdates(); break; + /*case NBATCH: + computeWithNBatchUpdates(); + break; */ case EPOCH: - computeEpochs(); + computeWithEpochUpdates(); break; - /*case NBATCH: - break;*/ default: throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq)); } @@ -285,138 +282,41 @@ static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possi } /** - * Computes all epochs and synchronizes after each batch + * Computes all epochs and updates after each batch */ - protected void computeBatches() { - int currentLocalBatchNumber = 0; - + protected void computeWithBatchUpdates() { for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { + int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch; + for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) { - int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); + int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); ListObject model = pullModel(); - ListObject gradients = computeBatchGradients(model, localBatchNum); + ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum); pushGradients(gradients); ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); - System.out.println("[+] " + this.getWorkerName() + " completed batch " + localBatchNum); } System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); } } /** - * Computes a single specified batch on the federated worker - * - * @param model the current model from the parameter server - * @param localBatchNum the current batch number needed for slicing the features and labels - * @return the gradient vector - */ - protected ListObject computeBatchGradients(ListObject model, int localBatchNum) { - // put local batch num on federated worker - Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _localBatchNumVarID, new IntObject(localBatchNum))); - // put current model on federated worker - Future putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model)); - - try { - if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful()) - throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful"); - } - catch(Exception e) { - throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage()); - } - - // create and execute the udf on the remote worker - Future udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - _featuresData.getVarID(), - new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _localBatchNumVarID, _modelVarID}) - )); - - 
try { - Object[] responseData = udfResponse.get().getData(); - return (ListObject) responseData[0]; - } - catch(Exception e) { - throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage()); - } - } - - /** - * This is the code that will be executed on the federated Worker when computing a single batch + * Computes all epochs and updates after N batches */ - private static class federatedComputeBatchGradients extends FederatedUDF { - private static final long serialVersionUID = -3652112393963053475L; - - protected federatedComputeBatchGradients(long[] inIDs) { - super(inIDs); - } - - @Override - public FederatedResponse execute(ExecutionContext ec, Data... data) { - // read in data by varid - MatrixObject features = (MatrixObject) data[0]; - MatrixObject labels = (MatrixObject) data[1]; - long localBatchNum = ((IntObject) data[2]).getLongValue(); - ListObject model = (ListObject) data[3]; - - // get data from execution context - long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue(); - long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue(); - String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue(); - String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue(); - - // slice batch from feature and label matrix - long begin = localBatchNum * batchSize + 1; - long end = Math.min((localBatchNum + 1) * batchSize, dataSize); - MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end); - MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end); - - // prepare execution context - ec.setVariable(Statement.PS_MODEL, model); - ec.setVariable(Statement.PS_FEATURES, bFeatures); - ec.setVariable(Statement.PS_LABELS, bLabels); - - // recreate gradient instruction and output - FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false); - ArrayList inputs = func.getInputParams(); - ArrayList outputs = func.getOutputParams(); - CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); - ArrayList outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); - Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs, - func.getInputParamNames(), outputNames, "gradient function"); - DataIdentifier gradientsOutput = outputs.get(0); - - // calculate and gradients - gradientsInstruction.processInstruction(ec); - ListObject gradients = ec.getListObject(gradientsOutput.getName()); - - // clean up sliced batch - ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); - ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); - ParamservUtils.cleanupData(ec, Statement.PS_LABELS); - - // model clean up - doing this twice is not an issue - ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString()); - ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); - - // return - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients); - } + protected void computeWithNBatchUpdates() { + System.out.println("Not implemented yet"); } /** - * Computes all epochs and synchronizes after each one + * Computes all epochs and updates after each epoch */ - protected 
void computeEpochs() { + protected void computeWithEpochUpdates() { for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { - // TODO: Calc localStartBatchNum - int localStartBatchNum = 0; + int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch; // Pull the global parameters from ps ListObject model = pullModel(); - ListObject gradients = computeEpochGradients(model, localStartBatchNum); + ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true); pushGradients(gradients); System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); ParamservUtils.cleanupListObject(model); @@ -424,15 +324,23 @@ protected void computeEpochs() { } } + protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum) { + return computeGradientsForNBatches(model, numBatchesToCompute, localStartBatchNum, false); + } + /** - * Computes one epoch on the federated worker and updates the model local + * Computes the gradients of n batches on the federated worker and is able to update the model local. + * Returns the gradients. * * @param model the current model from the parameter server + * @param localStartBatchNum the batch to start from + * @param localUpdate whether to update the model locally + * * @return the gradient vector */ - protected ListObject computeEpochGradients(ListObject model, int localStartBatchNum) { - // put local batch num on federated worker - Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _localBatchNumVarID, new IntObject(localStartBatchNum))); + protected ListObject computeGradientsForNBatches(ListObject model, int numBatchesToCompute, int localStartBatchNum, boolean localUpdate) { + // put local start batch num on federated worker + Future putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _localStartBatchNumVarID, new IntObject(localStartBatchNum))); // put current model on federated worker Future putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model)); @@ -447,7 +355,7 @@ protected ListObject computeEpochGradients(ListObject model, int localStartBatch // create and execute the udf on the remote worker Future udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(), - new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _localBatchNumVarID, _modelVarID}) + new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _localStartBatchNumVarID, _modelVarID}, numBatchesToCompute, localUpdate) )); try { @@ -460,13 +368,17 @@ protected ListObject computeEpochGradients(ListObject model, int localStartBatch } /** - * This is the code that will be executed on the federated Worker when computing one epoch + * This is the code that will be executed on the federated Worker when computing one gradients for n batches */ - private static class federatedComputeEpochGradients extends FederatedUDF { + private static class federatedComputeGradientsForNBatches extends FederatedUDF { private static final long serialVersionUID = -3075901536748794832L; + int _numBatchesToCompute; + boolean 
_localUpdate; - protected federatedComputeEpochGradients(long[] inIDs) { + protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate) { super(inIDs); + _numBatchesToCompute = numBatchesToCompute; + _localUpdate = localUpdate; } @Override @@ -481,7 +393,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue(); long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue(); int possibleBatchesPerLocalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_POSS_BATCHES_LOCAL)).getLongValue(); - int numBatchesPerGlobalEpoch = (int) ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES_GLOBAL)).getLongValue(); String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue(); String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue(); String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue(); @@ -499,24 +410,28 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { func.getInputParamNames(), outputNames, "gradient function"); DataIdentifier gradientsOutput = outputs.get(0); - // recreate aggregation instruction and output - func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false); - inputs = func.getInputParams(); - outputs = func.getOutputParams(); - boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); - outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); - Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs, - func.getInputParamNames(), outputNames, "aggregation function"); - DataIdentifier aggregationOutput = outputs.get(0); - ListObject accGradients = null; + // recreate aggregation instruction and output if needed + Instruction aggregationInstruction = null; + DataIdentifier aggregationOutput = null; + if(_localUpdate && _numBatchesToCompute > 1) { + func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false); + inputs = func.getInputParams(); + outputs = func.getOutputParams(); + boundInputs = inputs.stream() + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); + outputNames = outputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs, + func.getInputParamNames(), outputNames, "aggregation function"); + aggregationOutput = outputs.get(0); + } + ListObject accGradients = null; int currentLocalBatchNumber = localStartBatchNum; // prepare execution context ec.setVariable(Statement.PS_MODEL, model); - for (int batchCounter = 0; batchCounter < numBatchesPerGlobalEpoch; batchCounter++) { + for (int batchCounter = 0; batchCounter < _numBatchesToCompute; batchCounter++) { int localBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch); // slice batch from feature and label matrix @@ -529,16 +444,18 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { ec.setVariable(Statement.PS_FEATURES, bFeatures); ec.setVariable(Statement.PS_LABELS, bLabels); - // calculate intermediate gradients + // calculate gradients for batch gradientsInstruction.processInstruction(ec); ListObject gradients = ec.getListObject(gradientsOutput.getName()); + // accrue the computed gradients - In the single batch case this is just a list copy // TODO: is this equivalent for momentum based and AMS prob? accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false); - // Update the local model with gradients - if(batchCounter < numBatchesPerGlobalEpoch - 1) { + // update the local model with gradients if needed + if(_localUpdate && batchCounter < _numBatchesToCompute - 1) { // Invoke the aggregate function + assert aggregationInstruction != null; aggregationInstruction.processInstruction(ec); // Get the new model model = ec.getListObject(aggregationOutput.getName()); From 39b4cf30b83c637c14e6919ebdaa94bb64ce6474 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Sun, 13 Dec 2020 14:29:29 +0100 Subject: [PATCH 11/16] [SYSTEMDS-2550] new test cases for runtime balancing --- .../org/apache/sysds/parser/Statement.java | 2 +- .../paramserv/FederatedPSControlThread.java | 18 +++++--- .../dp/FederatedDataPartitioner.java | 6 +-- .../cp/ParamservBuiltinCPInstruction.java | 1 - .../paramserv/FederatedParamservTest.java | 46 ++++++++++--------- .../functions/federated/paramserv/CNN.dml | 9 ++-- .../paramserv/FederatedParamservTest.dml | 4 +- .../functions/federated/paramserv/TwoNN.dml | 7 +-- 8 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 14ce90e29fb..6767d857e35 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -99,7 +99,7 @@ public enum PSScheme { DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE } public enum FederatedPSScheme { - KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE, SUBSAMPLE, BALANCE + KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX, SUBSAMPLE_TO_MIN, BALANCE_TO_AVG } public static final String PS_HYPER_PARAMS = "hyperparams"; public static final String PS_CHECKPOINTING = "checkpointing"; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 3cbb3c0286d..b289716700c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -87,6 +87,11 @@ public void setup() { _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch; } + if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH || _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) { + System.out.println("ERROR: Not implemented yet!"); + return; + } + // serialize program // create program blocks for the instruction filtering String programSerialized; @@ -232,7 +237,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
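As a compact reference for what the N-batch gradient UDF above does per call, this sketch reproduces only the control flow: compute a batch gradient, accumulate it, and, when local updates are enabled, apply the aggregation function between batches but not after the last one, so only the accumulated gradients are returned to the parameter server. The generic types, functional interfaces and helper names are assumptions of the sketch rather than SystemDS APIs.

import java.util.function.BiFunction;

// Sketch of the per-call loop: n batches, gradient accumulation, optional local update.
// Model and Gradient are placeholders for the ListObject-based model/gradients in the patch.
public class NBatchGradientSketch<Model, Gradient> {

    public Gradient computeForNBatches(Model model, int numBatchesToCompute, int startBatch,
            int possibleBatchesPerLocalEpoch, boolean localUpdate,
            BiFunction<Model, Integer, Gradient> gradientFn,      // stands in for the "gradients" function
            BiFunction<Model, Gradient, Model> aggregationFn,     // stands in for the "aggregation" function
            BiFunction<Gradient, Gradient, Gradient> accrueFn) {  // stands in for accrueGradients
        Gradient accGradients = null;
        int current = startBatch;
        for (int i = 0; i < numBatchesToCompute; i++) {
            int localBatchNum = current++ % possibleBatchesPerLocalEpoch; // wrap around local data
            Gradient g = gradientFn.apply(model, localBatchNum);
            accGradients = (accGradients == null) ? g : accrueFn.apply(accGradients, g);
            // update the local model between batches, but not after the last one:
            // the parameter server only ever receives the accumulated gradients
            if (localUpdate && i < numBatchesToCompute - 1)
                model = aggregationFn.apply(model, g);
        }
        return accGradients;
    }
}

In the real code path, accGradients corresponds to the accumulator in the hunk above, and the conditional local update mirrors the aggregationInstruction call guarded by _localUpdate.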
data) { ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID); ec.removeVariable(Statement.PS_FED_MODEL_VARID); ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS); - ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); } @@ -304,7 +308,7 @@ protected void computeWithBatchUpdates() { * Computes all epochs and updates after N batches */ protected void computeWithNBatchUpdates() { - System.out.println("Not implemented yet"); + System.out.println("ERROR: Not implemented yet!"); } /** @@ -449,7 +453,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ListObject gradients = ec.getListObject(gradientsOutput.getName()); // accrue the computed gradients - In the single batch case this is just a list copy - // TODO: is this equivalent for momentum based and AMS prob? + // is this equivalent for momentum based and AMS prob? accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false); // update the local model with gradients if needed @@ -462,18 +466,18 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // Set new model in execution context ec.setVariable(Statement.PS_MODEL, model); // clean up gradients and result - ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS); ParamservUtils.cleanupListObject(ec, aggregationOutput.getName()); } - // clean up sliced batch - ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); + // clean up + ParamservUtils.cleanupListObject(ec, gradientsOutput.getName()); ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); ParamservUtils.cleanupData(ec, Statement.PS_LABELS); + ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); System.out.println("[+]" + " completed batch " + localBatchNum); } - // model clean up - doing this twice is not an issue + // model clean up ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString()); ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java index c2978669b04..29be7821478 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java @@ -35,13 +35,13 @@ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) { case SHUFFLE: _scheme = new ShuffleFederatedScheme(); break; - case REPLICATE: + case REPLICATE_TO_MAX: _scheme = new ReplicateFederatedScheme(); break; - case SUBSAMPLE: + case SUBSAMPLE_TO_MIN: _scheme = new SubsampleFederatedScheme(); break; - case BALANCE: + case BALANCE_TO_AVG: _scheme = new BalanceFederatedScheme(); break; default: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 3132e180eeb..ca57cb230b6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -38,7 +38,6 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.lops.LopProperties; 
-import org.apache.sysds.parser.Statement; import org.apache.sysds.parser.Statement.PSFrequency; import org.apache.sysds.parser.Statement.PSModeType; import org.apache.sysds.parser.Statement.PSScheme; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index da8efb843b1..686a212359a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -60,25 +60,31 @@ public class FederatedParamservTest extends AutomatedTestBase { @Parameterized.Parameters public static Collection parameters() { return Arrays.asList(new Object[][] { - // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update - // type, update frequency - {"TwoNN", 2, 4, 1, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, - {"CNN", 2, 4, 1, 1, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"CNN", 2, 4, 1, 1, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - /*{"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE", "NONE" , "IMBALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "BATCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "BSP", "EPOCH", "SHUFFLE", "" , "BALANCED"}, - {"CNN", 2, 4, 1, 5, 0.01, "ASP", "EPOCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}, - {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH", "SHUFFLE", "" , "BALANCED"}, - {"CNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "" , "BALANCED"}*/ + // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency + + // basic functionality + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE" , "IMBALANCED"}, + {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "IMBALANCED"}, + + // runtime balancing + /*{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", 
"CYCLE_MAX" , "IMBALANCED"}, + + // data partitioning + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"}, + + // complex balanced tests + {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}*/ }); } @@ -108,12 +114,10 @@ public void federatedParamservSingleNode() { federatedParamserv(ExecMode.SINGLE_NODE); } - /* @Test public void federatedParamservHybrid() { federatedParamserv(ExecMode.HYBRID); } - */ private void federatedParamserv(ExecMode mode) { // config diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml index ccfb3d22a7e..69c7e760442 100644 --- a/src/test/scripts/functions/federated/paramserv/CNN.dml +++ b/src/test/scripts/functions/federated/paramserv/CNN.dml @@ -67,7 +67,7 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov */ train = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, string utype, string freq, int batch_size, string scheme, double eta, + int epochs, int batch_size, double eta, int C, int Hin, int Win, int seed = -1) return (list[unknown] model) { @@ -163,7 +163,7 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, double eta, int C, int Hin, int Win, int seed = -1) return (list[unknown] model) { @@ -208,8 +208,9 @@ train_paramserv = function(matrix[double] X, matrix[double] y, # Use paramserv function model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, - upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", - utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, + upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", + agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", + k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) } diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml index affbea7ee5d..10d2cc7f028 100644 --- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml +++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml @@ -27,9 +27,9 @@ features = read($features) labels = read($labels) if($network_type == "TwoNN") { - model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed) + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $eta, $seed) } else { - model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed) + model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed) } print(toString(model)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml index 92d85c75c70..9bd49d85917 100644 --- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml +++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml @@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, double eta, int seed = -1) return (list[unknown] model) { @@ -152,8 +152,9 @@ train_paramserv = function(matrix[double] X, matrix[double] y, hyperparams = list(learning_rate=eta) # Use paramserv function model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val, - upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", - utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, + upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", + agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", + k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) } From 32bf44c77e92cb96cd53aa00c6378a3eb68244c5 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Fri, 18 Dec 2020 17:15:39 +0100 Subject: [PATCH 12/16] [SYSTEMDS-2550] Rebase fixes before pullrequest --- .../dp/DataPartitionFederatedScheme.java | 53 +++++++++++++++++-- .../paramserv/dp/ShuffleFederatedScheme.java | 22 ++------ .../paramserv/FederatedParamservTest.java | 4 +- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index b75acb6817e..e82c3ba6d94 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -19,7 +19,6 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp; -import org.apache.sysds.api.mlcontext.Matrix; import org.apache.sysds.common.Types; import org.apache.sysds.lops.compile.Dag; import org.apache.sysds.runtime.DMLRuntimeException; @@ -27,6 +26,11 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.functionobjects.Multiply; +import 
org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; @@ -61,12 +65,10 @@ public Result(List pFeatures, List pLabels, int work */ static List sliceFederatedMatrix(MatrixObject fedMatrix) { if (fedMatrix.isFederated(FederationMap.FType.ROW)) { - List slices = Collections.synchronizedList(new ArrayList<>()); fedMatrix.getFedMapping().forEachParallel((range, data) -> { // Create sliced matrix object MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX)); - // Warning needs MetaDataFormat instead of MetaData slice.setMetaData(new MetaDataFormat( new MatrixCharacteristics(range.getSize(0), range.getSize(1)), Types.FileFormat.BINARY) @@ -121,4 +123,49 @@ public BalanceMetrics(long minRows, long avgRows, long maxRows) { this._maxRows = maxRows; } } + + /** + * Just a mat multiply used to shuffle with a provided shuffle matrixBlock + * + * @param m the input matrix object + * @param permutationMatrixBlock the shuffle matrix block + */ + static void shuffle(MatrixObject m, MatrixBlock permutationMatrixBlock) { + // matrix multiply + m.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( + permutationMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + m.release(); + } + + /** + * Takes a MatrixObjects and extends it to the chosen number of rows by random replication + * + * @param m the input matrix object + */ + static void replicateTo(MatrixObject m, MatrixBlock replicateMatrixBlock) { + // matrix multiply and append + MatrixBlock replicatedFeatures = replicateMatrixBlock.aggregateBinaryOperations( + replicateMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))); + + m.acquireModify(m.acquireReadAndRelease().append(replicatedFeatures, new MatrixBlock(), false)); + m.release(); + } + + /** + * Just a mat multiply used to subsample with a provided subsample matrixBlock + * + * @param m the input matrix object + * @param subsampleMatrixBlock the subsample matrix block + */ + static void subsampleTo(MatrixObject m, MatrixBlock subsampleMatrixBlock) { + // matrix multiply + m.acquireModify(subsampleMatrixBlock.aggregateBinaryOperations( + subsampleMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(), + new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) + )); + m.release(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index f5f11e1fdbd..73b39a4c844 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -27,12 +27,8 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import 
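The shuffle, replicateTo and subsampleTo helpers above all reduce to a matrix multiply (plus an append for replication) with a suitably constructed permutation, replication or selection matrix, and BalanceMetrics carries the min/avg/max row counts that steer them. A small sketch of the per-worker row targets implied by the scheme names (illustrative arithmetic only; the real schemes build MatrixBlocks and go through the helpers above):

// Sketch: target rows per worker implied by the data-partitioning scheme names.
// Only the row arithmetic is shown; the actual schemes operate on MatrixBlocks.
import java.util.Arrays;

public class PartitionTargetSketch {
    static long targetRows(long[] workerRows, String scheme) {
        long min = Arrays.stream(workerRows).min().orElse(0);
        long max = Arrays.stream(workerRows).max().orElse(0);
        long avg = Math.round(Arrays.stream(workerRows).average().orElse(0));
        switch (scheme) {
            case "REPLICATE_TO_MAX": return max;  // grow small workers by replication
            case "SUBSAMPLE_TO_MIN": return min;  // shrink large workers by subsampling
            case "BALANCE_TO_AVG":   return avg;  // move every worker to the average
            default:                 return -1;   // KEEP_DATA_ON_WORKER / SHUFFLE: unchanged
        }
    }

    public static void main(String[] args) {
        long[] rows = {1, 3};
        for (String s : new String[]{"REPLICATE_TO_MAX", "SUBSAMPLE_TO_MIN", "BALANCE_TO_AVG"})
            System.out.println(s + " -> " + targetRows(rows, s) + " rows per worker");
    }
}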
org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysds.runtime.functionobjects.Multiply; -import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import java.util.List; import java.util.concurrent.Future; @@ -79,21 +75,9 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // generate permutation matrix MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); - - // matrix multiplies - features.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( - permutationMatrixBlock, features.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - features.release(); - - labels.acquireModify(permutationMatrixBlock.aggregateBinaryOperations( - permutationMatrixBlock, labels.acquireReadAndRelease(), new MatrixBlock(), - new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())) - )); - labels.release(); - + shuffle(features, permutationMatrixBlock); + shuffle(labels, permutationMatrixBlock); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); } } -} +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 686a212359a..b0d4065ed18 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -69,7 +69,7 @@ public static Collection parameters() { {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "IMBALANCED"}, // runtime balancing - /*{"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, @@ -84,7 +84,7 @@ public static Collection parameters() { // complex balanced tests {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}*/ + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"} }); } From d45fd1ff4aebdad634d0a6c39fe0dd8658337b33 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Fri, 18 Dec 2020 17:42:03 +0100 Subject: [PATCH 13/16] [SYSTEMDS-2550] Simplified test cases for master --- .../paramserv/FederatedParamservTest.java | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java 
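With the refactoring above, ShuffleFederatedScheme simply builds a permutation MatrixBlock and passes it to the shared shuffle(...) helper for both features and labels. The underlying idea is that left-multiplying a data matrix by a permutation matrix reorders its rows; a tiny standalone illustration with plain arrays in place of SystemDS MatrixBlocks:

// Sketch: left-multiplying by a permutation matrix P reorders the rows of X.
// P has exactly one 1 per row and column; row i of P*X is the row of X that P's i-th row selects.
public class PermutationShuffleSketch {
    static double[][] multiply(double[][] a, double[][] b) {
        int n = a.length, m = b[0].length, k = b.length;
        double[][] c = new double[n][m];
        for (int i = 0; i < n; i++)
            for (int p = 0; p < k; p++)
                for (int j = 0; j < m; j++)
                    c[i][j] += a[i][p] * b[p][j];
        return c;
    }

    public static void main(String[] args) {
        double[][] X = {{1, 1}, {2, 2}, {3, 3}};          // three data rows
        double[][] P = {{0, 0, 1}, {1, 0, 0}, {0, 1, 0}}; // permutation (3, 1, 2)
        for (double[] row : multiply(P, X))
            System.out.println(java.util.Arrays.toString(row)); // prints rows 3, 1, 2 of X
    }
}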
index b0d4065ed18..76b1c32b061 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -67,24 +67,26 @@ public static Collection parameters() { {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE" , "IMBALANCED"}, {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "IMBALANCED"}, {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "IMBALANCED"}, - - // runtime balancing - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, - - // data partitioning - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"}, - - // complex balanced tests {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"} + + /* + // runtime balancing + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, + + // data partitioning + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"}, + + // balanced tests + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"} + */ }); } From 3efd864cd3fd3508fd4cb99a1b7b0db1b730581c Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Fri, 18 Dec 2020 17:58:57 +0100 Subject: [PATCH 14/16] [SYSTEMDS-2550] Refactored new federated data partitioning schemes --- ...ederatedScheme.java => BalanceToAvgFederatedScheme.java} | 2 +- .../paramserv/dp/FederatedDataPartitioner.java | 6 +++--- ...eratedScheme.java => ReplicateToMaxFederatedScheme.java} | 2 +- ...eratedScheme.java => SubsampleToMinFederatedScheme.java} | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) rename src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/{BalanceFederatedScheme.java => BalanceToAvgFederatedScheme.java} 
(98%) rename src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/{ReplicateFederatedScheme.java => ReplicateToMaxFederatedScheme.java} (98%) rename src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/{SubsampleFederatedScheme.java => SubsampleToMinFederatedScheme.java} (98%) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java similarity index 98% rename from src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java rename to src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java index 5cca9383a1e..a2377e9099f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java @@ -35,7 +35,7 @@ import java.util.List; import java.util.concurrent.Future; -public class BalanceFederatedScheme extends DataPartitionFederatedScheme { +public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme { @Override public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java index 29be7821478..d1ebb6cac5d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java @@ -36,13 +36,13 @@ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) { _scheme = new ShuffleFederatedScheme(); break; case REPLICATE_TO_MAX: - _scheme = new ReplicateFederatedScheme(); + _scheme = new ReplicateToMaxFederatedScheme(); break; case SUBSAMPLE_TO_MIN: - _scheme = new SubsampleFederatedScheme(); + _scheme = new SubsampleToMinFederatedScheme(); break; case BALANCE_TO_AVG: - _scheme = new BalanceFederatedScheme(); + _scheme = new BalanceToAvgFederatedScheme(); break; default: throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme)); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java similarity index 98% rename from src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java rename to src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java index 0fdaf25bbbf..73113fb58cf 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java @@ -34,7 +34,7 @@ import java.util.List; import java.util.concurrent.Future; -public class ReplicateFederatedScheme extends DataPartitionFederatedScheme { +public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme { @Override public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java similarity index 98% rename from src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java rename to src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java index e21a813a75c..acd3a983442 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java @@ -34,7 +34,7 @@ import java.util.List; import java.util.concurrent.Future; -public class SubsampleFederatedScheme extends DataPartitionFederatedScheme { +public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme { @Override public Result doPartitioning(MatrixObject features, MatrixObject labels) { List pFeatures = sliceFederatedMatrix(features); From 54e13dd062ec2a009cbca1172d0eab777f8f8b0a Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Fri, 18 Dec 2020 18:08:08 +0100 Subject: [PATCH 15/16] [SYSTEMDS-2550] Removed accidental wildcard import --- .../cp/ParamservBuiltinCPInstruction.java | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index ca57cb230b6..a2b8d9fbc22 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -19,6 +19,21 @@ package org.apache.sysds.runtime.instructions.cp; +import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN; +import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE; +import static org.apache.sysds.parser.Statement.PS_EPOCHS; +import static org.apache.sysds.parser.Statement.PS_FEATURES; +import static org.apache.sysds.parser.Statement.PS_FREQUENCY; +import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS; +import static org.apache.sysds.parser.Statement.PS_LABELS; +import static org.apache.sysds.parser.Statement.PS_MODE; +import static org.apache.sysds.parser.Statement.PS_MODEL; +import static org.apache.sysds.parser.Statement.PS_PARALLELISM; +import static org.apache.sysds.parser.Statement.PS_SCHEME; +import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN; +import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE; +import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING; + import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -43,6 +58,7 @@ import org.apache.sysds.parser.Statement.PSScheme; import org.apache.sysds.parser.Statement.FederatedPSScheme; import org.apache.sysds.parser.Statement.PSUpdateType; +import org.apache.sysds.parser.Statement.PSRuntimeBalancing; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.LocalVariableMap; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; @@ -66,8 +82,6 @@ import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.utils.Statistics; -import static org.apache.sysds.parser.Statement.*; - public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { private 
static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName()); From dfcc62f008badd6c50c171f1aefcb68972624849 Mon Sep 17 00:00:00 2001 From: Tobias Rieger Date: Fri, 18 Dec 2020 18:25:59 +0100 Subject: [PATCH 16/16] [SYSTEMDS-2550] Disabled FederatedParamservTest output --- .../functions/federated/paramserv/FederatedParamservTest.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 76b1c32b061..5041989e74e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -191,9 +191,7 @@ private void federatedParamserv(ExecMode mode) { "seed=" + 25)); programArgs = programArgsList.toArray(new String[0]); - // TODO: Switch back - //LOG.debug(runTest(null)); - System.out.println(runTest(null)); + LOG.debug(runTest(null)); Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst()); // shut down threads
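The final patch routes the test's script output through the commons-logging logger instead of System.out, so it only appears when debug logging is enabled for the test class. A minimal standalone example of that pattern (class name and message are made up; it only assumes commons-logging on the classpath):

// Sketch: gate verbose test output behind the debug log level instead of System.out.
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class DebugOutputSketch {
    private static final Log LOG = LogFactory.getLog(DebugOutputSketch.class.getName());

    public static void main(String[] args) {
        String testOutput = "...model and statistics dump...";
        // emitted only if the logging configuration enables DEBUG for this class
        if (LOG.isDebugEnabled())
            LOG.debug(testOutput);
    }
}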