def various algs into funcs for least squares
kunyuan827 committed Jun 3, 2020
1 parent b1040b9 commit a375fe8
Showing 2 changed files with 102 additions and 97 deletions.
195 changes: 100 additions & 95 deletions examples/pytorch_least_square.py
@@ -14,11 +14,12 @@
# ==============================================================================

import torch
import matplotlib.pyplot as plt
import argparse

import bluefog.torch as bf
from bluefog.common import topology_util

# Parser
parser = argparse.ArgumentParser(
    description="PyTorch Least-Squares Example",
@@ -31,54 +32,44 @@
"--plot-interactive", action='store_true', help="Use plt.show() to present the plot."
)
parser.add_argument(
"--method", type=int, default=0, help="0:exact diffusion. 1:gradient tracking. 2:push-DIGing"
"--method", help="this example supports exact_diffusion, gradient_tracking, and push_diging",
default='exact_diffusion'
)
args = parser.parse_args()
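
# Example invocation (a sketch, assuming Bluefog's `bfrun` launcher is installed;
# adjust the process count to your setup):
#   bfrun -np 4 python pytorch_least_square.py --method exact_diffusion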


def finalize_plot():
    plt.savefig(args.save_plot_file)
    if args.plot_interactive:
        plt.show()
    plt.close()

# # ================== Distributed gradient descent ================================
# # Calculate the solution with distributed gradient descent:
# # x^{k+1} = x^k - alpha * \sum_i A_i.T(A_i x - b_i)
# # it will be used to verify the solutions of the decentralized algorithms below.
# # ================================================================================
def distributed_grad_descent(maxite=5000, alpha=1e-2):
    x_opt = torch.zeros(n, 1).to(torch.double)

    for i in range(maxite):
        grad = A.T.mm(A.mm(x_opt) - b)              # local gradient
        grad = bf.allreduce(grad, name='gradient')  # global gradient
        x_opt = x_opt - alpha*grad

    # evaluate the convergence of distributed least-squares
    # the norm of the global gradient is expected to be 0 (optimality condition)
    global_grad_norm = torch.norm(bf.allreduce(A.T.mm(A.mm(x_opt) - b)), p=2)
    print("[DG] Rank {}: global gradient norm: {}".format(
        bf.rank(), global_grad_norm))

    # the norm of the local gradient is expected not to be close to 0
    # because each rank converges to the global solution, not its local solution
    local_grad_norm = torch.norm(A.T.mm(A.mm(x_opt) - b), p=2)
    print("[DG] Rank {}: local gradient norm: {}".format(bf.rank(), local_grad_norm))

    return x_opt

# ==================== Exact Diffusion ===========================================
# Calculate the true solution with exact diffusion recursion as follows:
@@ -87,7 +78,7 @@ def finalize_plot():
# phi^{k+1} = psi^{k+1} + w^k - psi^{k}
# w^{k+1} = neighbor_allreduce(phi^{k+1})
#
# References:
#
# [R1] K. Yuan, B. Ying, X. Zhao, and A. H. Sayed, ``Exact diffusion for distributed
# optimization and learning -- Part I: Algorithm development'', 2018. (Alg. 1)
@@ -96,35 +87,31 @@ def finalize_plot():
# [R2] Z. Li, W. Shi and M. Yan, ``A Decentralized Proximal-gradient Method with
# Network Independent Step-sizes and Separated Convergence Rates'', 2019
# ================================================================================
def exact_diffusion(w_opt, maxite=2000, alpha_ed=1e-2, use_Abar=False):

    x = torch.zeros(n, 1).to(torch.double)
    phi, psi, psi_prev = x.clone(), x.clone(), x.clone()
    mse = []

    topology = bf.load_topology()
    self_weight, neighbor_weights = topology_util.GetWeights(topology, bf.rank())

    # construct A_bar
    if use_Abar:
        self_weight = (self_weight+1)/2
        for k, v in neighbor_weights.items():
            neighbor_weights[k] = v/2
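    # note: self_weight and neighbor_weights describe this rank's row of the
    # combination matrix A; the halving above replaces it with the row of
    # A_bar = (I + A)/2, which for a symmetric A shifts all eigenvalues into
    # [0, 1] (the averaged combination matrix used in the exact diffusion
    # analysis, see [R1]).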

    for i in range(maxite):
        grad = A.T.mm(A.mm(x)-b)   # local gradient
        psi = x - alpha_ed * grad
        phi = psi + x - psi_prev
        x = bf.neighbor_allreduce(phi, self_weight, neighbor_weights, name='local variable')
        psi_prev = psi
        if bf.rank() == 0:
            mse.append(torch.norm(x - w_opt, p=2))

    return x, mse

# ======================= gradient tracking =====================================
# Calculate the true solution with gradient tracking (GT for short):
@@ -133,7 +120,7 @@ def finalize_plot():
# q^{k+1} = neighbor_allreduce(q^k) + grad(w^{k+1}) - grad(w^k)
# where q^0 = grad(w^0)
#
# References:
# [R1] A. Nedic, A. Olshevsky, and W. Shi, ``Achieving geometric convergence
# for distributed optimization over time-varying graphs'', 2017. (Alg. 1)
#
@@ -146,12 +133,10 @@ def finalize_plot():
# [R4] P. Di Lorenzo and G. Scutari, ``Next: In-network nonconvex optimization'',
# 2016
# ================================================================================
def gradient_tracking(w_opt, maxite=2000, alpha_gt=1e-2):
    x = torch.zeros(n, 1).to(torch.double)
    y = A.T.mm(A.mm(x)-b)
    grad_prev = y.clone()

    mse_gt = []
    for i in range(maxite):
        x_handle = bf.neighbor_allreduce_async(x, name='Grad.Tracking.x')
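        # the neighbor averages of x (and of y below) are requested
        # asynchronously and collected later via bf.synchronize, so the
        # communication can overlap with the local gradient computation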
@@ -162,23 +147,9 @@ def finalize_plot():
        y = bf.synchronize(y_handle) + grad - grad_prev
        grad_prev = grad
        if bf.rank() == 0:
            mse_gt.append(torch.norm(x - w_opt, p=2))

    return x, mse_gt

# ======================= Push-DIGing for directed graph =======================
# Calculate the true solution with Push-DIGing:
@@ -188,7 +159,8 @@ def finalize_plot():
# [R1] A. Nedic, A. Olshevsky, and W. Shi, ``Achieving geometric convergence
# for distributed optimization over time-varying graphs'', 2017. (Alg. 2)
# ============================================================================
def push_diging(w_opt, maxite=2000, alpha_pd=1e-2):

    bf.set_topology(topology_util.PowerTwoRingGraph(bf.size()))
    outdegree = len(bf.out_neighbor_ranks())
    indegree = len(bf.in_neighbor_ranks())
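    # Push-DIGing works on this directed topology with push-sum style
    # (column-stochastic) weights: each rank pushes 1/(2*outdegree) of w to
    # every out-neighbor and keeps the remaining half, while the last entry
    # of w carries the push-sum weight used below to de-bias x via x = w[:n]/w[-1].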
@@ -202,46 +174,79 @@

    bf.win_create(w, name="w_buff", zero_init=True)

    mse_pd = []
    for i in range(maxite):
        if i % 10 == 0:
            bf.barrier()

        w[:n] = w[:n] - alpha_pd*w[n:2*n]
        bf.win_accumulate(
            w, name="w_buff",
            dst_weights={rank: 1.0 / (outdegree*2)
                         for rank in bf.out_neighbor_ranks()},
            require_mutex=True)
        w.div_(2)
        w = bf.win_update_then_collect(name="w_buff")

        x = w[:n]/w[-1]
        grad = A.T.mm(A.mm(x)-b)
        w[n:2*n] += grad - grad_prev
        grad_prev = grad
        if bf.rank() == 0:
            mse_pd.append(torch.norm(x - w_opt, p=2))

    bf.barrier()
    w = bf.win_update_then_collect(name="w_buff")
    x = w[:n]/w[-1]

    return x, mse_pd

# ======================= Code starts here =======================
bf.init()

# The least squares problem is min_x \sum_i^n \|A_i x - b_i\|^2
# where each rank i holds its own A_i and b_i; every rank is expected to
# converge to the same global solution.

# Generate data
# b = A@x_o + ns, where ns is Gaussian noise
torch.random.manual_seed(123417 * bf.rank())
m, n = 20, 5
A = torch.randn(m, n).to(torch.double)
x_o = torch.randn(n, 1).to(torch.double)
ns = 0.1*torch.randn(m, 1).to(torch.double)
b = A.mm(x_o) + ns

# calculate the global solution w_opt via distributed gradient descent
w_opt = distributed_grad_descent()


# solve the least-squares problem with the indicated decentralized algorithm
# (gradient tracking and push-DIGing use a smaller step-size than exact diffusion)
if args.method == 'exact_diffusion':
    w, mse = exact_diffusion(w_opt)
elif args.method == 'gradient_tracking':
    w, mse = gradient_tracking(w_opt, alpha_gt=5e-3)
elif args.method == 'push_diging':
    w, mse = push_diging(w_opt, alpha_pd=5e-3)
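# note: if an unrecognized --method is given, none of the branches above runs,
# so w and mse stay undefined and the NameError handler below reports the
# supported options.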

# plot and print the result
try:
    if bf.rank() == 0:
        plt.semilogy(mse)
        finalize_plot()

    # calculate the global gradient at the returned solution
    grad = bf.allreduce(A.T.mm(A.mm(w) - b))   # global gradient

    # evaluate the convergence of the selected algorithm for least-squares
    # the norm of the global gradient is expected to be 0 (optimality condition)
    global_grad_norm = torch.norm(grad, p=2)
    print("[{}] Rank {}: global gradient norm: {}".format(
        args.method, bf.rank(), global_grad_norm))

    # the norm of the local gradient is expected not to be close to 0
    # because each rank converges to the global solution, not its local solution
    local_grad_norm = torch.norm(A.T.mm(A.mm(w) - b), p=2)
    print("[{}] Rank {}: local gradient norm: {}".format(
        args.method, bf.rank(), local_grad_norm))

except NameError:
    print('Algorithm not supported. This example only supports '
          'exact_diffusion, gradient_tracking, and push_diging.')
4 changes: 2 additions & 2 deletions examples/pytorch_logistic_regression.py
@@ -57,7 +57,7 @@ def logistic_loss_step(x_, rho, tensor_name):
# # x^{k+1} = x^k - alpha * allreduce(local_grad)
# # it will be used to verify the solution of various decentralized algorithms.
# # ================================================================================
def distributed_grad_descent(rho, maxite=5000, alpha=1e-1):
    w_opt = torch.zeros(n, 1, dtype=torch.double, requires_grad=True)

    for i in range(maxite):
@@ -266,7 +266,7 @@ def push_diging(w_opt, rho, maxite=2000, alpha_pd = 1e-1):
rho = 1e-2

# calculate the global solution w_opt via distributed gradient descent
w_opt = distributed_grad_descent(rho)

# solve the logistic regression with indicated decentralized algorithms
if args.method == 'exact_diffusion':
