In [22]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np


In [23]:
def load_data():
    # Load the MNIST dataset
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Flatten and normalize input images
    X_train = X_train.reshape((X_train.shape[0], 28 * 28)) / 255.0
    X_test = X_test.reshape((X_test.shape[0], 28 * 28)) / 255.0
    # Convert the labels to one-hot encoding
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)
    return X_train, y_train, X_test, y_test

In [24]:
def create_model():
    # Create a sequential model
    model = Sequential()
    # Add a fully connected layer
    model.add(Dense(128, activation='relu', input_shape=(28 * 28,)))
    # Add the output layer
    model.add(Dense(10, activation='softmax'))
    # Compile the model with optimizer
    model.compile(optimizer=SGD(learning_rate=0.01), loss='categorical_crossentropy', metrics=['accuracy'])
    return model


In [25]:
X_train, y_train, X_test, y_test = load_data()

In [26]:
model = create_model()



In [96]:
workers = 10
rank = 0

old_weights = model.get_weights()
for i in range(100):
  print(model.train_on_batch(x=X_train[rank::workers], y=y_train[rank::workers], return_dict=True))
new_weights = model.get_weights()

print(old_weights[0].shape)
print(old_weights[1].shape)



{'loss': 0.5305126905441284, 'accuracy': 0.8650000095367432}
{'loss': 0.5302083492279053, 'accuracy': 0.8651666641235352}
{'loss': 0.5299046635627747, 'accuracy': 0.8651666641235352}
{'loss': 0.5296016931533813, 'accuracy': 0.8653333187103271}
{'loss': 0.5292994379997253, 'accuracy': 0.8651666641235352}
{'loss': 0.5289977788925171, 'accuracy': 0.8651666641235352}
{'loss': 0.5286967754364014, 'accuracy': 0.8651666641235352}
{'loss': 0.5283964276313782, 'accuracy': 0.8653333187103271}
{'loss': 0.5280967354774475, 'accuracy': 0.8653333187103271}
{'loss': 0.5277977585792542, 'accuracy': 0.8653333187103271}
{'loss': 0.5274993777275085, 'accuracy': 0.8654999732971191}
{'loss': 0.5272017121315002, 'accuracy': 0.8654999732971191}
{'loss': 0.5269046425819397, 'accuracy': 0.8654999732971191}
{'loss': 0.5266082286834717, 'accuracy': 0.8654999732971191}
{'loss': 0.5263125896453857, 'accuracy': 0.8656666874885559}
{'loss': 0.5260173082351685, 'accuracy': 0.8658333420753479}
{'loss': 0.5257228612899