In [None]:
!pip install git+https://github.com/konstmish/opt_methods.git

In [None]:
!pip install -U ray

HBM-Scaffnew (Scaffnew with heavy-ball momentum) and baseline algorithms

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import ray

###########################################
# Common settings and helper functions
###########################################
# Problem size and simulation length.
d = 5  # model dimension
num_clients = 10  # number of clients
T_rounds = 300  # communication rounds per run

# Synthetic quadratic objective: client i owns f_i(x) = 0.5 * ||x - b_i||^2,
# whose gradient is x - b_i. The average loss is minimised at the mean of the
# b_i, which is stored in b_avg.
np.random.seed(42)  # fixes the b_i draws (does not seed the `random` module)
client_b = [np.random.randn(d) for _ in range(num_clients)]
b_avg = np.stack(client_b).mean(axis=0)

def make_grad_func(b):
    """Build the gradient oracle of the quadratic f(x) = 0.5 * ||x - b||^2.

    Args:
        b: Centre (minimiser) of the client's quadratic.

    Returns:
        A callable mapping x to the gradient x - b.
    """
    def grad(x):
        return x - b

    return grad

def global_grad_norm(x_avg):
    """Return the norm of the average-loss gradient at ``x_avg``.

    For the quadratic clients above, the mean of the client gradients
    (x_avg - b_i) is x_avg - b_avg. This value is therefore also the
    distance of x_avg from the optimum.
    """
    residual = x_avg - b_avg
    return np.linalg.norm(residual)

###########################################
# 1. Scaffold (Centralized simulation)
###########################################
class Scaffold:
    """Single-process simulation of a SCAFFOLD-style control-variate method.

    Each client k keeps its own model x_k and control variate c_k. The server
    variate c_server is the mean of the c_k. One round takes one corrected
    gradient step per client:

        x_k <- x_k - gamma * (grad_k(x_k) - c_k + c_server)

    It then shifts every c_k by the change in that client's gradient across
    the step.

    NOTE(review): client models are never averaged, so the reported
    trajectory is just the mean of independent local models. The c_k update
    also appears to differ from the update rules in the SCAFFOLD paper.
    Confirm that this is the intended baseline.
    """

    def __init__(self, client_grad_funcs, x0_clients, gamma):
        self.n = len(client_grad_funcs)
        self.grad_funcs = client_grad_funcs
        self.x = [start.copy() for start in x0_clients]  # per-client models
        zero = np.zeros_like(x0_clients[0])
        self.c = [zero.copy() for _ in range(self.n)]  # per-client control variates
        self.c_server = np.mean(np.array(self.c), axis=0)
        self.gamma = gamma

    def round(self):
        """Take one corrected gradient step on every client, then refresh c."""
        old = self.x
        c_bar = self.c_server
        new = [
            old[k] - self.gamma * (self.grad_funcs[k](old[k]) - self.c[k] + c_bar)
            for k in range(self.n)
        ]
        # Shift each control variate by the change in its local gradient.
        self.c = [
            self.c[k] + (self.grad_funcs[k](new[k]) - self.grad_funcs[k](old[k]))
            for k in range(self.n)
        ]
        self.c_server = np.mean(np.array(self.c), axis=0)
        self.x = new

    def run(self, rounds):
        """Run ``rounds`` rounds; return the client-averaged model after each one."""
        history = []
        for _ in range(rounds):
            self.round()
            history.append(np.stack(self.x, axis=0).mean(axis=0))
        return history

###########################################
# 2. Scaffnew (ProxSkip for Federated Learning using Ray)
###########################################
@ray.remote
class ScaffnewClient:
    """Ray actor holding one client's Scaffnew (ProxSkip) state.

    State: the local model ``x``, the control variate ``h`` and the client's
    gradient oracle ``grad_func``.
    """

    def __init__(self, x0, h0, grad_func):
        self.grad_func = grad_func
        self.h = h0.copy()
        self.x = x0.copy()

    def local_update(self, gamma):
        """Return the candidate x - gamma * (grad(x) - h) without storing it."""
        corrected = self.grad_func(self.x) - self.h
        return self.x - gamma * corrected

    def apply_update(self, x_hat, x_comm, do_comm, gamma, p):
        """Commit the round: adopt x_comm when communicating, else x_hat.

        The control variate moves by (p / gamma) times the gap between the
        adopted model and the candidate. That gap is zero when no
        communication happened.
        """
        adopted = x_comm if do_comm else x_hat
        self.h = self.h + (p / gamma) * (adopted - x_hat)
        self.x = adopted

    def get_model(self):
        """Return the current local model."""
        return self.x

def run_scaffnew(T, gamma, p, d, client_b, seed=None):
    """Simulate Scaffnew with one Ray actor per client.

    Args:
        T: Number of iterations. Each one may or may not communicate.
        gamma: Local step size.
        p: Probability that an iteration ends with communication (averaging).
        d: Model dimension. All clients start at zeros(d) with h = 0.
        client_b: Per-client quadratic centres. One actor is created per entry.
        seed: Optional seed for the communication coin flips. ``None`` keeps
            the previous behaviour of drawing from the global ``random`` module.

    Returns:
        List of length T with the client-averaged model after each iteration.
    """
    rng = random if seed is None else random.Random(seed)
    x0 = np.zeros(d)
    h0 = np.zeros(d)
    # Size the federation from client_b instead of the global num_clients, so
    # that a mismatch cannot raise IndexError or silently drop clients.
    clients = [ScaffnewClient.remote(x0, h0, make_grad_func(b)) for b in client_b]
    traj = []
    for _ in range(T):
        x_hats = ray.get([client.local_update.remote(gamma) for client in clients])
        do_comm = rng.random() < p
        # The h_i start at zero, and each communication changes their sum by
        # (p/gamma) * sum(x_comm - x_hat_i) = 0. The plain average of the
        # x_hat_i therefore equals the average of x_hat_i - (gamma/p) * h_i.
        x_comm = np.mean(np.stack(x_hats, axis=0), axis=0) if do_comm else None
        ray.get([
            client.apply_update.remote(x_hat, x_comm, do_comm, gamma, p)
            for client, x_hat in zip(clients, x_hats)
        ])
        models = ray.get([client.get_model.remote() for client in clients])
        traj.append(np.mean(np.stack(models, axis=0), axis=0))
    return traj

###########################################
# 3. FedLin (Federated Linearized Method using Ray)
###########################################
@ray.remote
class FedLinClient:
    """Ray actor running local gradient descent for the "FedLin" baseline.

    NOTE(review): each local step is x <- x - gamma * grad(x) with no gradient
    correction, so this actor behaves identically to LocalGDClient. FedLin as
    published (Mitra et al., 2021) corrects local steps using the global
    gradient. Confirm whether that correction was intended here.
    """

    def __init__(self, x0, grad_func, local_steps):
        self.local_steps = local_steps
        self.grad_func = grad_func
        self.x = x0.copy()

    def local_update(self, gamma):
        """Take ``local_steps`` gradient steps from the current model and return it."""
        for _ in range(self.local_steps):
            self.x = self.x - gamma * self.grad_func(self.x)
        return self.x

    def set_model(self, new_x):
        """Overwrite the local model with a copy of ``new_x``."""
        self.x = new_x.copy()

    def get_model(self):
        """Return the current local model."""
        return self.x

def run_fedlin(T, gamma, local_steps, d, client_b):
    """Run the FedLin baseline: local steps on every client, then averaging.

    Args:
        T: Number of communication rounds.
        gamma: Local step size.
        local_steps: Gradient steps per client per round.
        d: Model dimension. All clients start at zeros(d).
        client_b: Per-client quadratic centres. One actor is created per entry.

    Returns:
        List of length T with the averaged model after each round.
    """
    # Size the federation from client_b instead of the global num_clients, so
    # that a mismatch cannot raise IndexError or silently drop clients.
    clients = [
        FedLinClient.remote(np.zeros(d), make_grad_func(b), local_steps)
        for b in client_b
    ]
    traj = []
    for _ in range(T):
        models = ray.get([client.local_update.remote(gamma) for client in clients])
        x_avg = np.mean(np.stack(models, axis=0), axis=0)
        # Broadcast the average and wait for every client to adopt it.
        ray.get([client.set_model.remote(x_avg) for client in clients])
        traj.append(x_avg.copy())
    return traj

###########################################
# 4. LocalGD (Standard local gradient descent using Ray)
###########################################
@ray.remote
class LocalGDClient:
    """Ray actor running uncorrected local gradient descent (LocalGD).

    Args:
        x0 (np.ndarray): Starting model. It is copied.
        grad_func: Callable returning this client's gradient at a point.
        local_steps (int): Gradient steps taken per call to local_update.
    """

    def __init__(self, x0, grad_func, local_steps):
        self.local_steps = local_steps
        self.grad_func = grad_func
        self.x = x0.copy()

    def local_update(self, gamma):
        """Apply local_steps plain gradient steps (no control variates) and return x."""
        remaining = range(self.local_steps)
        for _ in remaining:
            step = gamma * self.grad_func(self.x)
            self.x = self.x - step
        return self.x

    def set_model(self, new_x):
        """Replace the local model with a copy of new_x (e.g. the server average)."""
        self.x = new_x.copy()

    def get_model(self):
        """Return the current local model."""
        return self.x

def run_local_gd(T, gamma, local_steps, d, client_b):
    """
    Runs LocalGD for T rounds.

    Each round, every client takes `local_steps` gradient steps from the
    current shared model. The server averages the results and broadcasts
    the average back.

    Args:
        T: Number of rounds.
        gamma: Local stepsize.
        local_steps: Number of local steps per round.
        d: Model dimension. All clients start at zeros(d).
        client_b: List of b vectors, one per client. One actor is created
            per entry.
    Returns:
        traj: List of aggregated models per round.
    """
    # Size the federation from client_b instead of the global num_clients, so
    # that a mismatch cannot raise IndexError or silently drop clients.
    clients = [
        LocalGDClient.remote(np.zeros(d), make_grad_func(b), local_steps)
        for b in client_b
    ]
    traj = []
    for _ in range(T):
        models = ray.get([client.local_update.remote(gamma) for client in clients])
        x_avg = np.mean(np.stack(models, axis=0), axis=0)
        # Broadcast the average and wait for every client to adopt it.
        ray.get([client.set_model.remote(x_avg) for client in clients])
        traj.append(x_avg.copy())
    return traj
###########################################
# 5. Scaffnew with Heavy-Ball Momentum (Ray)
###########################################
@ray.remote
class ScaffnewMomentumClient:
    """Ray actor for Scaffnew with a heavy-ball momentum term.

    Besides the Scaffnew state (model ``x`` and control variate ``h``), it
    keeps the previous model ``x_prev`` so that each step can add
    beta * (x - x_prev).
    """

    def __init__(self, x0, h0, grad_func, beta):
        self.beta = beta
        self.grad_func = grad_func
        self.h = h0.copy()
        self.x_prev = x0.copy()  # x_prev == x, so momentum starts at zero
        self.x = x0.copy()

    def local_update(self, gamma):
        """Return x - gamma * (grad(x) - h) + beta * (x - x_prev) without storing it."""
        velocity = self.x - self.x_prev
        corrected_grad = self.grad_func(self.x) - self.h
        return self.x - gamma * corrected_grad + self.beta * velocity

    def apply_update(self, x_hat, x_comm, do_comm, gamma, p):
        """Adopt x_comm (if communicating) or x_hat, then update h and x_prev."""
        adopted = x_comm if do_comm else x_hat
        self.h = self.h + (p / gamma) * (adopted - x_hat)
        self.x_prev = self.x.copy()  # remember the model being replaced
        self.x = adopted

    def get_model(self):
        """Return the current local model."""
        return self.x

def run_scaffnew_momentum(T, gamma, p, beta, d, client_b, seed=None):
    """Simulate heavy-ball Scaffnew with a fixed momentum parameter.

    Args:
        T: Number of iterations. Each one may or may not communicate.
        gamma: Local step size.
        p: Probability that an iteration ends with communication.
        beta: Heavy-ball momentum coefficient.
        d: Model dimension. All clients start at zeros(d) with h = 0.
        client_b: Per-client quadratic centres. One actor is created per entry.
        seed: Optional seed for the communication coin flips. ``None`` keeps
            the previous behaviour of drawing from the global ``random`` module.

    Returns:
        List of length T with the client-averaged model after each iteration.
    """
    rng = random if seed is None else random.Random(seed)
    x0 = np.zeros(d)
    h0 = np.zeros(d)
    # Size the federation from client_b instead of the global num_clients.
    clients = [
        ScaffnewMomentumClient.remote(x0, h0, make_grad_func(b), beta)
        for b in client_b
    ]

    traj = []
    for _ in range(T):
        x_hats = ray.get([client.local_update.remote(gamma) for client in clients])
        do_comm = rng.random() < p
        x_comm = np.mean(np.stack(x_hats, axis=0), axis=0) if do_comm else None
        ray.get([
            client.apply_update.remote(x_hat, x_comm, do_comm, gamma, p)
            for client, x_hat in zip(clients, x_hats)
        ])
        models = ray.get([client.get_model.remote() for client in clients])
        traj.append(np.mean(np.stack(models, axis=0), axis=0))
    return traj
###########################################
# 6. Scaffold -momentum
###########################################
class ScaffoldMomentum:
    """Single-process Scaffold-style simulation with heavy-ball momentum.

    Uses the same per-client control-variate scheme as ``Scaffold`` and adds
    a momentum term beta * (x_k - x_prev_k) to every local step.

    NOTE(review): client models are never averaged. The reported trajectory
    is the mean of independent local models.
    """

    def __init__(self, client_grad_funcs, x0_clients, gamma, beta):
        self.n = len(client_grad_funcs)
        self.grad_funcs = client_grad_funcs
        self.gamma = gamma
        self.beta = beta
        self.x = [start.copy() for start in x0_clients]  # current local models
        self.x_prev = [start.copy() for start in x0_clients]  # models one round ago
        template = np.zeros_like(x0_clients[0])
        self.c = [template.copy() for _ in range(self.n)]  # control variates
        self.c_server = np.mean(np.array(self.c), axis=0)

    def round(self):
        """Take one heavy-ball corrected step per client, then refresh the variates."""
        previous = self.x
        updated = [
            previous[k]
            - self.gamma * (self.grad_funcs[k](previous[k]) - self.c[k] + self.c_server)
            + self.beta * (previous[k] - self.x_prev[k])
            for k in range(self.n)
        ]
        self.c = [
            self.c[k] + (self.grad_funcs[k](updated[k]) - self.grad_funcs[k](previous[k]))
            for k in range(self.n)
        ]
        self.c_server = np.mean(np.array(self.c), axis=0)
        # The models being replaced become the momentum reference.
        self.x_prev = [x_k.copy() for x_k in previous]
        self.x = updated

    def run(self, rounds):
        """Run ``rounds`` rounds; return the client-averaged model after each one."""
        history = []
        for _ in range(rounds):
            self.round()
            history.append(np.stack(self.x, axis=0).mean(axis=0))
        return history





###########################################
# Main: Run all algorithms and plot grad norm vs iteration
###########################################
# Parameters: per-method step sizes (gamma_*), the Scaffnew communication
# probability p, and the number of local steps per round for FedLin / LocalGD.
gamma_scaffold = 0.08
gamma_scaffnew = 0.08
p_scaffnew = 0.7
gamma_fedlin = 0.03
local_steps_fedlin = 1
gamma_localgd = 0.02
local_steps_localgd = 1

# Run Scaffold (centralized simulation, no Ray); every client starts at 0.
x0_clients = [np.zeros(d) for _ in range(num_clients)]
grad_funcs_scaffold = [make_grad_func(b) for b in client_b]
scaffold_solver = Scaffold(grad_funcs_scaffold, x0_clients, gamma_scaffold)
traj_scaffold = scaffold_solver.run(T_rounds)
grad_norms_scaffold = [global_grad_norm(x_avg) for x_avg in traj_scaffold]

# Run Scaffnew:
# NOTE(review): the communication coin flips inside run_scaffnew use the
# global `random` module. No visible cell seeds it (np.random.seed does not
# affect it), so the Scaffnew curves differ between executions.
traj_scaffnew = run_scaffnew(T_rounds, gamma_scaffnew, p_scaffnew, d, client_b)
grad_norms_scaffnew = [global_grad_norm(x_avg) for x_avg in traj_scaffnew]

# Run FedLin:
traj_fedlin = run_fedlin(T_rounds, gamma_fedlin, local_steps_fedlin, d, client_b)
grad_norms_fedlin = [global_grad_norm(x_avg) for x_avg in traj_fedlin]

# Run LocalGD:
traj_localgd = run_local_gd(T_rounds, gamma_localgd, local_steps_localgd, d, client_b)
grad_norms_localgd = [global_grad_norm(x_avg) for x_avg in traj_localgd]

# Run Scaffnew-Momentum:
#beta_momentum = 0.4  # Momentum parameter
beta_momentum = 0.5  # Momentum parameter (also reused by a later cell)
traj_scaffnew_momentum = run_scaffnew_momentum(
    T_rounds, gamma_scaffnew, p_scaffnew, beta_momentum, d, client_b
)
grad_norms_scaffnew_momentum = [global_grad_norm(x_avg) for x_avg in traj_scaffnew_momentum]

# Run Scaffold-Momentum (currently disabled).
# NOTE(review): the class defined above is `ScaffoldMomentum`. This snippet
# originally called `ScaffoldWithMomentum`, which is not defined anywhere and
# would raise NameError if uncommented.
#beta_momentum = 0.5  # Momentum parameter
#scaffold_momentum_solver = ScaffoldMomentum(
#    client_grad_funcs=grad_funcs_scaffold,
#    x0_clients=x0_clients,
#    gamma=gamma_scaffold,
#    beta=beta_momentum)
#traj_scaffold_momentum = scaffold_momentum_solver.run(T_rounds)
# Compute gradient norms over time
#grad_norms_scaffold_momentum = [global_grad_norm(x_avg) for x_avg in traj_scaffold_momentum]

# Plot all curves including Scaffnew-Momentum:
iterations = list(range(T_rounds))
plt.figure(figsize=(10, 6))
plt.plot(iterations, grad_norms_scaffold, marker='o', label='Scaffold')
plt.plot(iterations, grad_norms_scaffnew, marker='s', label='Scaffnew')
plt.plot(iterations, grad_norms_scaffnew_momentum, marker='x', label='Scaffnew-Momentum')
plt.plot(iterations, grad_norms_fedlin, marker='^', label='FedLin')
plt.plot(iterations, grad_norms_localgd, marker='d', label='LocalGD')
plt.xlabel('Communication Rounds ')
plt.ylabel('Global Gradient Norm')
plt.title('Gradient Norm vs Communication Rounds for Scaffold, Scaffnew, Scaffnew-Momentum, FedLin, and LocalGD')
plt.legend()
plt.grid(True)
plt.show()

# Shutdown Ray after all computations.
# NOTE(review): later cells still create Ray actors. Presumably Ray
# re-initialises automatically on the next .remote() call; confirm.
ray.shutdown()


Gradient-norm convergence for different β values (β from 1.0 down to 0.5)

In [None]:

# Sweep the momentum parameter from 1.0 down to 0.5 and record the global
# gradient norm after every round for each beta.
beta_values = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
scaffnew_momentum_results = {}

for beta in beta_values:
    print(f"Running with beta = {beta}")
    # Fixed: `gamma` is not defined at notebook level (NameError). Use the
    # Scaffnew step size, as the other momentum sweeps do.
    traj = run_scaffnew_momentum(
        T_rounds, gamma_scaffnew, p_scaffnew, beta, d, client_b
    )
    scaffnew_momentum_results[beta] = [global_grad_norm(x_avg) for x_avg in traj]

# Plot one curve per beta. The previous plt.plot(iterations, <dict>) call
# cannot draw per-beta curves from a dict of lists.
plt.figure(figsize=(10, 6))
iterations = list(range(T_rounds))
for beta in beta_values:
    plt.plot(iterations, scaffnew_momentum_results[beta], label=f'β = {beta}')

plt.xlabel('Communication Rounds')
plt.ylabel('Global Gradient Norm')
plt.title('Scaffnew with Heavy-Ball Momentum for Different β Values')
plt.legend()
plt.grid(True)
plt.show()

Gradient-norm convergence for different β values (β from 0.0 to 0.8)

In [None]:
# Sweep beta over 0.0-0.8 (beta = 0.0 removes the momentum term) and collect
# the per-round global gradient norm for each setting.
beta_values = [0.0, 0.2, 0.4, 0.6, 0.8]
scaffnew_momentum_results = {}
for beta in beta_values:
    history = run_scaffnew_momentum(T_rounds, gamma_scaffnew, p_scaffnew, beta, d, client_b)
    scaffnew_momentum_results[beta] = list(map(global_grad_norm, history))

# One curve per beta, in sweep order.
iterations = list(range(T_rounds))
plt.figure(figsize=(10, 6))
for beta in beta_values:
    plt.plot(iterations, scaffnew_momentum_results[beta],
             label=f'Scaffnew-Momentum (β={beta})')

plt.xlabel('Communication Rounds ')
plt.ylabel('Global Gradient Norm ')
plt.title('Scaffnew with Heavy-Ball Momentum for Different β Values')
plt.legend()
plt.grid(True)
plt.show()


Adaptive-β HBM-Scaffnew

In [None]:
@ray.remote
class ScaffnewMomentumClient:
    """Ray actor for heavy-ball Scaffnew whose beta and state can be reset.

    State: model ``x``, previous model ``x_prev`` (for the momentum term
    beta * (x - x_prev)), control variate ``h`` and the gradient oracle.
    The setters and the get_state/set_state pair let the adaptive-beta
    driver undo a rejected trial step exactly.
    """

    def __init__(self, x0, h0, grad_func, beta):
        self.x = x0.copy()
        self.x_prev = x0.copy()
        self.h = h0.copy()
        self.grad_func = grad_func
        self.beta = beta

    def local_update(self, gamma):
        """Return the candidate x - gamma * (grad(x) - h) + beta * (x - x_prev)."""
        grad = self.grad_func(self.x)
        momentum = self.x - self.x_prev
        return self.x - gamma * (grad - self.h) + self.beta * momentum

    def apply_update(self, x_hat, x_comm, do_comm, gamma, p):
        """Adopt x_comm (if communicating) or x_hat, then update h and x_prev."""
        new_x = x_comm if do_comm else x_hat
        self.h = self.h + (p / gamma) * (new_x - x_hat)
        self.x_prev = self.x.copy()
        self.x = new_x

    def get_model(self):
        return self.x

    def set_beta(self, new_beta):
        self.beta = new_beta

    def set_x(self, new_x):
        """Overwrite only the model. x_prev and h are left unchanged."""
        self.x = new_x.copy()

    def get_state(self):
        """Return copies of (x, x_prev, h) for a later ``set_state``."""
        return self.x.copy(), self.x_prev.copy(), self.h.copy()

    def set_state(self, state):
        """Restore an (x, x_prev, h) tuple captured by ``get_state``."""
        x, x_prev, h = state
        self.x = x.copy()
        self.x_prev = x_prev.copy()
        self.h = h.copy()


def _scaffnew_round(clients, gamma, p, rng):
    """Run one (heavy-ball) Scaffnew iteration on all actors.

    Every actor proposes x_hat. With probability p the proposals are
    averaged and broadcast; otherwise each client keeps its own proposal.
    Returns the mean of the resulting client models.
    """
    x_hats = ray.get([client.local_update.remote(gamma) for client in clients])
    do_comm = rng.random() < p
    x_comm = np.mean(np.stack(x_hats, axis=0), axis=0) if do_comm else None
    ray.get([
        client.apply_update.remote(x_hat, x_comm, do_comm, gamma, p)
        for client, x_hat in zip(clients, x_hats)
    ])
    models = ray.get([client.get_model.remote() for client in clients])
    return np.mean(np.stack(models, axis=0), axis=0)


def run_scaffnew_momentum_adaptive(T, gamma, p, d, client_b, initial_beta=1.0,
                                   decay_rate=0.065, min_beta=0.0, seed=None):
    """Heavy-ball Scaffnew whose beta is decayed greedily.

    Each round first tries beta_candidate = max(min_beta, beta - decay_rate).
    The trial is accepted, and beta lowered, if it is the first round or if
    the global gradient norm decreased. Otherwise every client is rolled back
    to its full pre-trial state and the round is redone with the current
    beta. The redo is not checked again.

    Args:
        T: Number of rounds.
        gamma: Local step size.
        p: Communication probability per round.
        d: Model dimension. Clients start at zeros(d) with h = 0.
        client_b: Per-client quadratic centres. One actor is created per entry.
        initial_beta, decay_rate, min_beta: Schedule parameters. The defaults
            match the previously hard-coded values.
        seed: Optional seed for the coin flips. ``None`` uses the global
            ``random`` module, as before.

    Returns:
        (traj, grad_norms, beta_values): the averaged model, the global
        gradient norm and the beta in effect, one entry per round.
    """
    rng = random if seed is None else random.Random(seed)
    x0 = np.zeros(d)
    h0 = np.zeros(d)
    beta = initial_beta
    clients = [
        ScaffnewMomentumClient.remote(x0, h0, make_grad_func(b), beta)
        for b in client_b
    ]

    traj, grad_norms, beta_values = [], [], []
    for t in range(T):
        # Snapshot x, x_prev AND h. Restoring only x (as before) left h and
        # x_prev modified by the discarded trial step.
        snapshots = ray.get([client.get_state.remote() for client in clients])

        beta_candidate = max(min_beta, beta - decay_rate)
        ray.get([client.set_beta.remote(beta_candidate) for client in clients])
        x_avg = _scaffnew_round(clients, gamma, p, rng)
        grad_norm = global_grad_norm(x_avg)

        if t == 0 or grad_norm < grad_norms[-1]:
            beta = beta_candidate
        else:
            # Reject the trial: roll back and redo the round with the old beta.
            ray.get([client.set_beta.remote(beta) for client in clients])
            ray.get([
                client.set_state.remote(snapshot)
                for client, snapshot in zip(clients, snapshots)
            ])
            x_avg = _scaffnew_round(clients, gamma, p, rng)
            grad_norm = global_grad_norm(x_avg)

        traj.append(x_avg.copy())
        grad_norms.append(grad_norm)
        beta_values.append(beta)

    return traj, grad_norms, beta_values
# Run adaptive-beta Scaffnew, then plot the gradient norm next to the beta schedule.
traj_adaptive, grad_norms_adaptive, beta_values = run_scaffnew_momentum_adaptive(
    T_rounds, gamma_scaffnew, p_scaffnew, d, client_b
)
iterations = list(range(T_rounds))

fig, (ax_norm, ax_beta) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: convergence of the global gradient norm.
ax_norm.plot(iterations, grad_norms_adaptive, marker='x', label='Gradient Norm')
ax_norm.set_xlabel('Communication Rounds')
ax_norm.set_ylabel('Global Gradient Norm')
ax_norm.set_title('Gradient Norm vs Iteration')
ax_norm.grid(True)
ax_norm.legend()

# Right panel: the beta value in effect at each round.
ax_beta.plot(iterations, beta_values, marker='o', color='orange', label='Momentum β')
ax_beta.set_xlabel('Communication Rounds')
ax_beta.set_ylabel('β Value')
ax_beta.set_title('Momentum Parameter (β) over Time')
ax_beta.grid(True)
ax_beta.legend()

fig.suptitle('Scaffnew with Adaptive Heavy-Ball Momentum')
fig.tight_layout()
plt.show()


Gradient-norm curves for a fixed β and for adaptive β

In [None]:
# Compare a fixed beta (beta_momentum, set in the main experiment cell) with
# the adaptive schedule on the global gradient norm. Both runs are executed
# fresh in this cell.
traj_fixed = run_scaffnew_momentum(
    T_rounds, gamma_scaffnew, p_scaffnew, beta_momentum, d, client_b
)
grad_norms_fixed = list(map(global_grad_norm, traj_fixed))

traj_adaptive, grad_norms_adaptive, beta_values = run_scaffnew_momentum_adaptive(
    T_rounds, gamma_scaffnew, p_scaffnew, d, client_b
)

iterations = list(range(T_rounds))
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(iterations, grad_norms_fixed, marker='o', label=f'Fixed β = {beta_momentum}')
ax.plot(iterations, grad_norms_adaptive, marker='x', label='Adaptive β (1→0)')
ax.set_xlabel('Communication Rounds ')
ax.set_ylabel('Global Gradient Norm')
ax.set_title('Scaffnew with Heavy-Ball Momentum: Fixed vs Adaptive β')
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()


Logistic-loss curves on w8a for several fixed β values and for adaptive β

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import ray
import urllib.request
import os
from sklearn.datasets import load_svmlight_file
from sklearn.metrics import log_loss

#############################
# 1. Download & Load Dataset
#############################
def download_and_load_dataset(dataset_name="w8a"):
    """Fetch a LIBSVM binary-classification dataset (cached on disk) and load it.

    Args:
        dataset_name: File name on the LIBSVM "binary" datasets page. It is
            also used as the local cache path, relative to the working
            directory.

    Returns:
        (X, y): a dense feature matrix, and labels mapped to {0, 1}
        (label > 0 -> 1, otherwise 0).
    """
    url = f"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/{dataset_name}"
    if not os.path.exists(dataset_name):
        # Download to a temporary name and rename only on success. An
        # interrupted download then never leaves a truncated file that the
        # exists() check above would accept as a valid cache on later runs.
        partial_path = dataset_name + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, dataset_name)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    X_sparse, y = load_svmlight_file(dataset_name)
    X = X_sparse.toarray()  # densified: memory grows with rows * features
    y = (y > 0).astype(int)  # Convert to binary {0,1}
    return X, y

#############################
# 2. Gradient Function
#############################
def make_grad_func(Xb, yb):
    """Return the gradient oracle of one client's mean logistic loss.

    Args:
        Xb: Client feature matrix (rows are samples).
        yb: Client labels in {0, 1}, one per row of Xb.

    Returns:
        Callable w -> Xb.T @ (sigmoid(Xb @ w) - yb) / len(yb).
    """
    def grad_fn(w):
        z = Xb @ w
        # sigmoid(z) == 0.5 * (1 + tanh(z / 2)). Unlike 1 / (1 + exp(-z)),
        # this form never overflows for large |z| (e.g. when a large-beta run
        # diverges), and it saturates cleanly to 0 or 1.
        preds = 0.5 * (1.0 + np.tanh(0.5 * z))
        return Xb.T @ (preds - yb) / len(yb)
    return grad_fn

#############################
# 3. Logistic Loss
#############################
def compute_logistic_loss(X, y, w):
    """Return sklearn's log_loss of weights ``w`` on (X, y).

    Class-1 probabilities are sigmoid(X @ w).
    """
    probabilities = 1 / (1 + np.exp(-(X @ w)))
    return log_loss(y, probabilities)

#############################
# 4. Ray Client
#############################
@ray.remote
class ScaffnewMomentumClient:
    """Ray actor for one client of heavy-ball Scaffnew on a logistic loss.

    State: model ``x``, previous model ``x_prev`` (for the momentum term
    beta * (x - x_prev)), control variate ``h`` and the client's gradient
    oracle. The setters let the adaptive driver change beta and snapshot or
    restore the full state.
    """

    def __init__(self, x0, h0, grad_func, beta):
        self.x = x0.copy()
        self.x_prev = x0.copy()
        self.h = h0.copy()
        self.grad_func = grad_func
        self.beta = beta

    def local_update(self, gamma):
        """Return the candidate x - gamma * (grad(x) - h) + beta * (x - x_prev)."""
        grad = self.grad_func(self.x)
        momentum = self.x - self.x_prev
        return self.x - gamma * (grad - self.h) + self.beta * momentum

    def apply_update(self, x_hat, x_comm, do_comm, gamma, p):
        """Adopt x_comm (if communicating) or x_hat, then update h and x_prev."""
        new_x = x_comm if do_comm else x_hat
        self.h = self.h + (p / gamma) * (new_x - x_hat)
        self.x_prev = self.x.copy()
        self.x = new_x

    def get_model(self):
        return self.x

    def set_beta(self, new_beta):
        self.beta = new_beta

    def set_x(self, new_x):
        """Overwrite only the model. x_prev and h are left unchanged."""
        self.x = new_x.copy()

    def get_state(self):
        """Return copies of (x, x_prev, h) for a later ``set_state``."""
        return self.x.copy(), self.x_prev.copy(), self.h.copy()

    def set_state(self, state):
        """Restore an (x, x_prev, h) tuple captured by ``get_state``."""
        x, x_prev, h = state
        self.x = x.copy()
        self.x_prev = x_prev.copy()
        self.h = h.copy()


def _scaffnew_round(clients, gamma, p, rng):
    """Run one (heavy-ball) Scaffnew iteration on all actors.

    Every actor proposes x_hat. With probability p the proposals are
    averaged and broadcast; otherwise each client keeps its own proposal.
    Returns the mean of the resulting client models.
    """
    x_hats = ray.get([client.local_update.remote(gamma) for client in clients])
    do_comm = rng.random() < p
    x_comm = np.mean(np.stack(x_hats, axis=0), axis=0) if do_comm else None
    ray.get([
        client.apply_update.remote(x_hat, x_comm, do_comm, gamma, p)
        for client, x_hat in zip(clients, x_hats)
    ])
    models = ray.get([client.get_model.remote() for client in clients])
    return np.mean(np.stack(models, axis=0), axis=0)


def _global_logistic_grad_norm(grad_funcs, weights, w):
    """Norm of the full-data logistic-loss gradient at ``w``.

    Each client gradient is a mean over that client's samples. The client
    gradients are therefore combined with weights proportional to the
    client sample counts.
    """
    full_grad = sum(weight * grad(w) for weight, grad in zip(weights, grad_funcs))
    return np.linalg.norm(full_grad)


def run_scaffnew_momentum(T, gamma, p, beta, d, client_b, seed=None):
    """Run heavy-ball Scaffnew with a fixed beta on per-client logistic losses.

    Args:
        T: Number of iterations.
        gamma: Local step size.
        p: Communication probability per iteration.
        beta: Heavy-ball momentum coefficient.
        d: Model dimension. Clients start at zeros(d) with h = 0.
        client_b: List of (Xb, yb) client datasets.
        seed: Optional seed for the coin flips. ``None`` uses the global
            ``random`` module, as before.

    Returns:
        List of length T with the client-averaged model after each iteration.
    """
    rng = random if seed is None else random.Random(seed)
    x0 = np.zeros(d)
    h0 = np.zeros(d)
    clients = [
        ScaffnewMomentumClient.remote(x0, h0, make_grad_func(Xb, yb), beta)
        for Xb, yb in client_b
    ]
    return [_scaffnew_round(clients, gamma, p, rng) for _ in range(T)]


def run_scaffnew_momentum_adaptive(T, gamma, p, d, client_b, initial_beta=1.0,
                                   decay_rate=0.065, min_beta=0.0, seed=None):
    """Heavy-ball Scaffnew with a greedily decayed beta, on logistic losses.

    Each round first tries beta_candidate = max(min_beta, beta - decay_rate).
    The trial is accepted if it is the first round or if the norm of the
    full-data gradient decreased. Otherwise every client is rolled back to
    its full pre-trial state and the round is redone with the current beta.

    Args:
        T: Number of rounds.
        gamma: Local step size.
        p: Communication probability per round.
        d: Model dimension. Clients start at zeros(d) with h = 0.
        client_b: List of (Xb, yb) client datasets.
        initial_beta, decay_rate, min_beta: Schedule parameters. The defaults
            match the previously hard-coded values.
        seed: Optional seed for the coin flips. ``None`` uses the global
            ``random`` module, as before.

    Returns:
        (traj, grad_norms, beta_values): the averaged model, the full-data
        gradient norm and the beta in effect, one entry per round.
    """
    rng = random if seed is None else random.Random(seed)
    x0 = np.zeros(d)
    h0 = np.zeros(d)
    beta = initial_beta

    grad_funcs = [make_grad_func(Xb, yb) for Xb, yb in client_b]
    sizes = np.array([len(yb) for _, yb in client_b], dtype=float)
    weights = sizes / sizes.sum()
    clients = [
        ScaffnewMomentumClient.remote(x0, h0, grad_func, beta)
        for grad_func in grad_funcs
    ]

    traj, grad_norms, beta_values = [], [], []
    for t in range(T):
        # Snapshot x, x_prev AND h. Restoring only x (as before) left h and
        # x_prev modified by the discarded trial step.
        snapshots = ray.get([client.get_state.remote() for client in clients])

        beta_candidate = max(min_beta, beta - decay_rate)
        ray.get([client.set_beta.remote(beta_candidate) for client in clients])
        x_avg = _scaffnew_round(clients, gamma, p, rng)
        # Fixed: the acceptance test previously used np.linalg.norm(x_avg),
        # the norm of the MODEL, which is not a measure of progress.
        grad_norm = _global_logistic_grad_norm(grad_funcs, weights, x_avg)

        if t == 0 or grad_norm < grad_norms[-1]:
            beta = beta_candidate
        else:
            # Reject the trial: roll back and redo the round with the old beta.
            ray.get([client.set_beta.remote(beta) for client in clients])
            ray.get([
                client.set_state.remote(snapshot)
                for client, snapshot in zip(clients, snapshots)
            ])
            x_avg = _scaffnew_round(clients, gamma, p, rng)
            grad_norm = _global_logistic_grad_norm(grad_funcs, weights, x_avg)

        traj.append(x_avg.copy())
        grad_norms.append(grad_norm)
        beta_values.append(beta)

    return traj, grad_norms, beta_values

#############################
# 7. Plot Logistic Loss
#############################
# Function to plot fixed vs adaptive beta loss curves
def plot_loss_vs_rounds(X, y, traj_fixed, traj_adaptive, betas_fixed=None):
    """Plot the logistic loss per round for fixed-beta run(s) and the adaptive run.

    Args:
        X, y: Data on which the loss is evaluated for every iterate.
        traj_fixed: Either one trajectory (a list of weight vectors) or a
            list of trajectories, one per fixed beta.
        traj_adaptive: Trajectory of the adaptive-beta run.
        betas_fixed: Optional beta values used to label the fixed runs.
    """
    multiple_runs = bool(traj_fixed) and isinstance(traj_fixed[0], list)
    fixed_runs = traj_fixed if multiple_runs else [traj_fixed]

    plt.figure(figsize=(10, 6))

    # Plot fixed beta runs. Each curve gets its own x-axis, so runs of
    # different lengths no longer break the shared `iterations` list.
    for i, traj in enumerate(fixed_runs):
        losses = [compute_logistic_loss(X, y, w) for w in traj]
        if multiple_runs:
            label = f"Fixed \u03b2={betas_fixed[i]:.2f}" if betas_fixed else f"Fixed \u03b2-{i}"
        else:
            # 0.5 presumably mirrors the default fixed beta used elsewhere; confirm.
            label = f"Fixed \u03b2={betas_fixed[0] if betas_fixed else 0.5}"
        plt.plot(range(len(losses)), losses, label=label)

    # Plot adaptive
    losses_adaptive = [compute_logistic_loss(X, y, w) for w in traj_adaptive]
    plt.plot(range(len(losses_adaptive)), losses_adaptive,
             label="Adaptive \u03b2", linestyle='--', linewidth=2)

    plt.xlabel("Communication Rounds")
    plt.ylabel("Average Logistic Loss")
    plt.title("Loss vs Communication Rounds (Fixed vs Adaptive \u03b2)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

#############################
# 8. Run Everything
#############################
# Start (or reuse) the Ray runtime for the logistic-regression experiments.
ray.init(ignore_reinit_error=True)

X, y = download_and_load_dataset("w8a")
d = X.shape[1]
num_clients = 10
T_rounds = 1000
p_scaffnew = 0.7
gamma_scaffnew = 0.08
betas = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]

# Contiguous, near-equal row blocks: client k gets the k-th slice of the
# dataset in file order (np.array_split over the row indices).
client_b = [
    (X[rows], y[rows])
    for rows in np.array_split(np.arange(X.shape[0]), num_clients)
]

# One fixed-beta trajectory per entry of `betas` (same order), then the adaptive run.
traj_fixed_all = [
    run_scaffnew_momentum(T_rounds, gamma_scaffnew, p_scaffnew, beta, d, client_b)
    for beta in betas
]
traj_adaptive, _, _ = run_scaffnew_momentum_adaptive(T_rounds, gamma_scaffnew, p_scaffnew, d, client_b)

plot_loss_vs_rounds(X, y, traj_fixed_all, traj_adaptive, betas_fixed=betas)
