In [None]:
import tensorflow as tf
from mpi4py import MPI
import numpy as np

# MPI world communicator plus this process's position within it.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Every rank loads the full MNIST dataset; sharding happens later in train().
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values into [0, 1], append a trailing channel axis, and store
# the result as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")


def create_model():
    """Build and return the (uncompiled) CNN classifier.

    The network is a single 3x3 convolution with 32 filters and ReLU on a
    28x28x1 input, a 2x2 max-pool, a flatten, and a 10-way softmax layer.
    """
    layers = tf.keras.layers

    conv = layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
    pool = layers.MaxPooling2D((2, 2))
    flatten = layers.Flatten()
    classifier = layers.Dense(10, activation='softmax')

    return tf.keras.models.Sequential([conv, pool, flatten, classifier])

def train(model, x_train, y_train, rank, size):
    """Run one local epoch on this rank's shard, then synchronise the model.

    The training set is split into ``size`` contiguous shards; the last rank
    additionally takes any remainder. Each rank fits ``model`` on its own
    shard for one epoch, after which the weights are averaged across all
    ranks so that every process continues from the same parameters.

    Uses the module-level ``comm`` communicator for the collective calls, so
    every rank must call this function the same number of times.

    Args:
        model: Keras model to train in place. It is compiled here (Adam,
            sparse categorical cross-entropy, accuracy) only if it has no
            optimizer yet.
        x_train: Full training inputs, indexable along the first axis.
        y_train: Full training labels, aligned with ``x_train``.
        rank: Index of this process within the communicator.
        size: Total number of processes.

    Returns:
        Training accuracy on the local shard, averaged over all ranks.
    """
    # Split data across nodes
    n = len(x_train)
    chunk_size = n // size
    start = rank * chunk_size
    end = n if rank == size - 1 else (rank + 1) * chunk_size

    x_train_chunk = x_train[start:end]
    y_train_chunk = y_train[start:end]

    # Compile only once. Re-compiling on every call would create a brand-new
    # Adam instance each epoch and throw away its accumulated state.
    # NOTE(review): an already-compiled model passed in by the caller keeps
    # its own optimizer/loss/metrics — confirm that is acceptable.
    if getattr(model, "optimizer", None) is None:
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

    model.fit(x_train_chunk, y_train_chunk, epochs=1, batch_size=32, verbose=0)

    # Average the parameters across ranks; without this each process would
    # keep training an unrelated, independently initialised model.
    if size > 1:
        averaged_weights = [
            comm.allreduce(weights, op=MPI.SUM) / size
            for weights in model.get_weights()
        ]
        model.set_weights(averaged_weights)

    # Score the synchronised model on the local shard, then average the
    # per-rank accuracies.
    train_loss, train_acc = model.evaluate(x_train_chunk, y_train_chunk, verbose=0)

    train_acc = comm.allreduce(train_acc, op=MPI.SUM) / size
    return train_acc


model = create_model()
epochs = 3

for epoch in range(epochs):
    # One training pass; train() returns the accuracy averaged over ranks.
    train_acc = train(model, x_train, y_train, rank, size)

    # Every rank scores the full test set, then the accuracies are averaged.
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    test_acc = comm.allreduce(test_acc, op=MPI.SUM) / size

    # Only the root rank reports progress.
    if rank != 0:
        continue
    print(f"Epoch {epoch + 1}: Train accuracy = {train_acc:.4f}, Test accuracy = {test_acc:.4f}")


Epoch 1: Train accuracy = 0.9756, Test accuracy = 0.9725
Epoch 2: Train accuracy = 0.9845, Test accuracy = 0.9797
Epoch 3: Train accuracy = 0.9872, Test accuracy = 0.9790
