In [5]:
from NeuralNetwork import NeuralNetwork, train_nn, performance, split
from _functions import *
import numpy as np
from time import time
import shutil

In [6]:
# Experiment: for a growing number of clients `n`, split the MNIST training
# set into `n` subsets, train one client model per subset starting from the
# server model, merge the client weights into a single model, and append the
# resulting scores to split_scores.txt.

# Number of input, hidden and output nodes, plus the learning rate.
# These are identical for every client and every split size, so they are
# defined once here instead of being re-assigned on every loop iteration.
input_nodes = 784
hidden_nodes = 200
output_nodes = 10
learning_rate = 0.1

n = 353
inc = 1
while n < 600:
    print(n)

    # Split the dataset into n subsets
    split("MNIST/mnist_train.csv", split_amount = n)

    # Timing starts after the split, so it covers training and merging only.
    start = time()

    ###
    ### Train all the clients on the datasets
    ###

    dataset_count = n
    for i in range(0, dataset_count):

        # Create the neural network object
        nn = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

        # Every client starts from the current server model
        nn.loadModel("models/server")

        # Train this client on its own subset
        # (presumably "dataset<i>.csv" is written by split() above — TODO confirm)
        train_nn(nn, data_set = "dataset" + str(i) + ".csv", epochs = 5)

        # Save the model
        # NOTE(review): every client saves to the same "models/mnist" path; this
        # only yields one model per client if saveModel() makes the directory
        # name unique — confirm against NeuralNetwork.saveModel.
        nn.saveModel("models/mnist")

    ###
    ### Combine all the models
    ###

    # Get a list of all the client's models (everything in models/ but the server)
    client_models = get_dirlist("models")
    client_models.remove("server")
    if not client_models:
        # Fail with a clear message instead of an IndexError / ZeroDivisionError below.
        raise RuntimeError("No client models found in 'models/' to merge")

    # Seed the merged ("server") weights with the first client's model.
    serverWIH = np.load("models/" + client_models[0] + "/wih.npy")
    serverWHO = np.load("models/" + client_models[0] + "/who.npy")

    # Performance of the merged model after each client has been folded in.
    client_perfs = []

    # Combine/Merge all the models
    for merged_count, model in enumerate(client_models, start=1):

        clientWIH = np.load("models/" + model + "/wih.npy")
        clientWHO = np.load("models/" + model + "/who.npy")

        # Incremental running mean: after k clients the server weights are the
        # equal-weight average of those k clients. Repeatedly taking
        # mean([server, client]) instead halves the weight of every earlier
        # client at each step, so with hundreds of clients only the last few
        # would influence the result. For merged_count == 1 this leaves the
        # seed (the first client's weights) unchanged.
        serverWIH = serverWIH + (clientWIH - serverWIH) / merged_count
        serverWHO = serverWHO + (clientWHO - serverWHO) / merged_count

        # Evaluate the merged weights so far in a fresh network object
        nn = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
        nn.update_info(serverWIH, serverWHO)

        client_perfs.append(performance(nn))

    totalAcc = sum(client_perfs)
    # Same text format as before: each value followed by ", " (trailing separator kept).
    client_accuracies = "".join(str(perf) + ", " for perf in client_perfs)

    performance_of_batch = round(totalAcc / len(client_models), ndigits=3)
    time_of_batch = str(round((time() - start) / len(client_models), ndigits = 3))

    stats = (
        "========\nSplit size: " + str(n)
        + "\nAccuracy: " + str(performance_of_batch)
        + "%\nClients Accuracy: " + client_accuracies
        + "\nTime: " + time_of_batch
        + " seconds (estimate per client)\n========\n\n"
    )
    with open("split_scores.txt", "a") as splitscores:
        splitscores.write(stats)

    ###
    ### Cleaning up the models (deleting them)
    ###
    for model in client_models:
        shutil.rmtree("models/" + model)

    # The step between split sizes grows by one each time n lands on a multiple of 10.
    if n % 10 == 0:
        inc += 1
    n += inc

353
354
355
356
357
358
359
360
362
364
366
368
370
373
376
379
382
385
388
391
394
397
400
404
408
412
416
420
425
430
436
442
448
454
460
467
474
481
488
495
502
509
516
523
530
538
546
554
562
570
579
588
597
